diff --git a/core/config.go b/core/config.go index d218df6..08c8bd6 100644 --- a/core/config.go +++ b/core/config.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "path/filepath" "regexp" + "time" "github.com/op/go-logging" "gopkg.in/yaml.v2" @@ -28,6 +29,7 @@ type ConfigInterface interface { GetDBConfig() DatabaseConfig GetAWSConfig() ConfigAWS GetTransportInfo() InfoInterface + GetHTTPClientConfig() *HTTPClientConfig GetUpdateInterval() int IsDebug() bool } @@ -41,15 +43,16 @@ type InfoInterface interface { // Config struct type Config struct { - Version string `yaml:"version"` - LogLevel logging.Level `yaml:"log_level"` - Database DatabaseConfig `yaml:"database"` - SentryDSN string `yaml:"sentry_dsn"` - HTTPServer HTTPServerConfig `yaml:"http_server"` - Debug bool `yaml:"debug"` - UpdateInterval int `yaml:"update_interval"` - ConfigAWS ConfigAWS `yaml:"config_aws"` - TransportInfo Info `yaml:"transport_info"` + Version string `yaml:"version"` + LogLevel logging.Level `yaml:"log_level"` + Database DatabaseConfig `yaml:"database"` + SentryDSN string `yaml:"sentry_dsn"` + HTTPServer HTTPServerConfig `yaml:"http_server"` + Debug bool `yaml:"debug"` + UpdateInterval int `yaml:"update_interval"` + ConfigAWS ConfigAWS `yaml:"config_aws"` + TransportInfo Info `yaml:"transport_info"` + HTTPClientConfig *HTTPClientConfig `yaml:"http_client"` } // Info struct @@ -79,6 +82,14 @@ type DatabaseConfig struct { ConnectionLifetime int `yaml:"connection_lifetime"` } +// HTTPClientConfig struct +type HTTPClientConfig struct { + Timeout time.Duration `yaml:"timeout"` + SSLVerification bool `yaml:"ssl_verification"` + MockAddress string `yaml:"mock_address"` + MockedDomains []string `yaml:"mocked_domains"` +} + // HTTPServerConfig struct type HTTPServerConfig struct { Host string `yaml:"host"` @@ -168,6 +179,11 @@ func (c Config) GetUpdateInterval() int { return c.UpdateInterval } +// GetHTTPClientConfig returns http client config +func (c Config) GetHTTPClientConfig() *HTTPClientConfig { + return c.HTTPClientConfig +} + // GetName transport name func (t Info) GetName() string { return t.Name diff --git a/core/config_test.go b/core/config_test.go index abe0e2d..9c796f2 100644 --- a/core/config_test.go +++ b/core/config_test.go @@ -20,7 +20,7 @@ type ConfigTest struct { data []byte } -func (c *ConfigTest) SetupTest() { +func (c *ConfigTest) SetupSuite() { c.data = []byte(` version: 3.2.1 @@ -102,7 +102,7 @@ func (c *ConfigTest) Test_GetConfigAWS() { assert.Equal(c.T(), "image/jpeg", c.config.GetAWSConfig().ContentType) } -func (c *ConfigTest) TearDownTest() { +func (c *ConfigTest) TearDownSuite() { _ = os.Remove(testConfigFile) } diff --git a/core/engine.go b/core/engine.go index b09abd2..ae98a07 100644 --- a/core/engine.go +++ b/core/engine.go @@ -2,6 +2,7 @@ package core import ( "html/template" + "net/http" "github.com/gin-gonic/gin" "github.com/gobuffalo/packr/v2" @@ -15,6 +16,7 @@ type Engine struct { Sentry Utils ginEngine *gin.Engine + httpClient *http.Client Logger *logging.Logger Config ConfigInterface LogFormatter logging.Formatter @@ -124,6 +126,28 @@ func (e *Engine) Router() *gin.Engine { return e.ginEngine } +// BuildHTTPClient builds HTTP client with provided configuration +func (e *Engine) BuildHTTPClient(replaceDefault ...bool) *Engine { + if e.Config.GetHTTPClientConfig() != nil { + if client, err := NewHTTPClientBuilder().FromEngine(e).Build(replaceDefault...); err != nil { + panic(err) + } else { + e.httpClient = client + } + } + + return e +} + +// HTTPClient returns inner http client or default http client +func (e *Engine) HTTPClient() *http.Client { + if e.httpClient == nil { + return http.DefaultClient + } else { + return e.httpClient + } +} + // ConfigureRouter will call provided callback with current gin.Engine, or panic if engine is not present func (e *Engine) ConfigureRouter(callback func(*gin.Engine)) *Engine { callback(e.Router()) diff --git a/core/engine_test.go b/core/engine_test.go index 0d33b43..244b472 100644 --- a/core/engine_test.go +++ b/core/engine_test.go @@ -19,7 +19,7 @@ type EngineTest struct { engine *Engine } -func (e *EngineTest) SetupTest() { +func (e *EngineTest) SetupSuite() { var ( db *sql.DB err error diff --git a/core/http_client_builder.go b/core/http_client_builder.go new file mode 100644 index 0000000..fe9f514 --- /dev/null +++ b/core/http_client_builder.go @@ -0,0 +1,227 @@ +package core + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "time" + + "github.com/pkg/errors" +) + +var ( + DefaultClient = http.DefaultClient + DefaultTransport = http.DefaultTransport +) + +// HTTPClientBuilder builds http client with mocks (if necessary) and timeout +type HTTPClientBuilder struct { + httpClient *http.Client + httpTransport *http.Transport + dialer *net.Dialer + engine *Engine + built bool + logging bool + timeout time.Duration + mockAddress string + mockHost string + mockPort string + mockedDomains []string +} + +// NewHTTPClientBuilder returns HTTPClientBuilder with default values +func NewHTTPClientBuilder() *HTTPClientBuilder { + return &HTTPClientBuilder{ + built: false, + httpClient: &http.Client{}, + httpTransport: &http.Transport{}, + timeout: 30 * time.Second, + mockAddress: "", + mockedDomains: []string{}, + logging: false, + } +} + +// SetTimeout sets timeout for http client +func (b *HTTPClientBuilder) SetTimeout(timeout time.Duration) *HTTPClientBuilder { + timeout = timeout * time.Second + b.timeout = timeout + b.httpClient.Timeout = timeout + return b +} + +// SetMockAddress sets mock address +func (b *HTTPClientBuilder) SetMockAddress(address string) *HTTPClientBuilder { + b.mockAddress = address + return b +} + +// AddMockedDomain adds new mocked domain +func (b *HTTPClientBuilder) AddMockedDomain(domain string) *HTTPClientBuilder { + b.mockedDomains = append(b.mockedDomains, domain) + return b +} + +// SetMockedDomains sets mocked domains from slice +func (b *HTTPClientBuilder) SetMockedDomains(domains []string) *HTTPClientBuilder { + b.mockedDomains = domains + return b +} + +// DisableSSLVerification disables SSL certificates verification in client +func (b *HTTPClientBuilder) DisableSSLVerification() *HTTPClientBuilder { + b.logf("WARNING: SSL verification is now disabled, don't use this parameter in production!") + + b.httpTransport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + return b +} + +// EnableLogging enables logging in mocks +func (b *HTTPClientBuilder) EnableLogging() *HTTPClientBuilder { + b.logging = true + return b +} + +// FromConfig fulfills mock configuration from HTTPClientConfig +func (b *HTTPClientBuilder) FromConfig(config *HTTPClientConfig) *HTTPClientBuilder { + if config == nil { + return b + } + + if config.MockAddress != "" { + b.mockAddress = config.MockAddress + b.mockedDomains = config.MockedDomains + } + + if !config.SSLVerification { + b.DisableSSLVerification() + } + + if config.Timeout > 0 { + b.SetTimeout(config.Timeout) + } + + return b +} + +// FromEngine fulfills mock configuration from ConfigInterface inside Engine +func (b *HTTPClientBuilder) FromEngine(engine *Engine) *HTTPClientBuilder { + b.engine = engine + b.logging = engine.Config.IsDebug() + return b.FromConfig(engine.Config.GetHTTPClientConfig()) +} + +// buildDialer initializes dialer with provided timeout +func (b *HTTPClientBuilder) buildDialer() *HTTPClientBuilder { + b.dialer = &net.Dialer{ + Timeout: b.timeout, + KeepAlive: b.timeout, + } + + return b +} + +// parseAddress parses address and returns error in case of error (port is necessary) +func (b *HTTPClientBuilder) parseAddress() error { + if host, port, err := net.SplitHostPort(b.mockAddress); err == nil { + b.mockHost = host + b.mockPort = port + return nil + } else { + return errors.Errorf("cannot split host and port: %s", err.Error()) + } +} + +// buildMocks builds mocks for http client +func (b *HTTPClientBuilder) buildMocks() error { + if b.dialer == nil { + return errors.New("dialer must be built first") + } + + if b.mockHost != "" && b.mockPort != "" && len(b.mockedDomains) > 0 { + b.logf("Mock address is \"%s\"\n", net.JoinHostPort(b.mockHost, b.mockPort)) + b.logf("Mocked domains: ") + + for _, domain := range b.mockedDomains { + b.logf(" - %s\n", domain) + } + + b.httpTransport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) { + if host, port, err := net.SplitHostPort(addr); err != nil { + return b.dialer.DialContext(ctx, network, addr) + } else { + for _, mock := range b.mockedDomains { + if mock == host { + oldAddr := addr + + if b.mockPort == "0" { + addr = net.JoinHostPort(b.mockHost, port) + } else { + addr = net.JoinHostPort(b.mockHost, b.mockPort) + } + + b.logf("Mocking \"%s\" with \"%s\"\n", oldAddr, addr) + } + } + } + + return b.dialer.DialContext(ctx, network, addr) + } + } + + return nil +} + +// logf prints logs via Engine or via fmt.Printf +func (b *HTTPClientBuilder) logf(format string, args ...interface{}) { + if b.logging { + if b.engine != nil && b.engine.Logger != nil { + b.engine.Logger.Infof(format, args...) + } else { + fmt.Printf(format, args...) + } + } +} + +// ReplaceDefault replaces default client and transport with generated ones +func (b *HTTPClientBuilder) ReplaceDefault() *HTTPClientBuilder { + if b.built { + http.DefaultClient = b.httpClient + http.DefaultTransport = b.httpTransport + } + + return b +} + +// RestoreDefault restores default client and transport after replacement +func (b *HTTPClientBuilder) RestoreDefault() *HTTPClientBuilder { + http.DefaultClient = DefaultClient + http.DefaultTransport = DefaultTransport + + return b +} + +// Build builds client, pass true to replace http.DefaultClient with generated one +func (b *HTTPClientBuilder) Build(replaceDefault ...bool) (*http.Client, error) { + if err := b.buildDialer().parseAddress(); err != nil { + return nil, err + } + + if err := b.buildMocks(); err != nil { + return nil, err + } + + b.built = true + b.httpClient.Transport = b.httpTransport + + if len(replaceDefault) > 0 && replaceDefault[0] { + b.ReplaceDefault() + } + + return b.httpClient, nil +} diff --git a/core/http_client_builder_test.go b/core/http_client_builder_test.go new file mode 100644 index 0000000..7e5c5cb --- /dev/null +++ b/core/http_client_builder_test.go @@ -0,0 +1,126 @@ +package core + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type HTTPClientBuilderTest struct { + suite.Suite + builder *HTTPClientBuilder +} + +func (t *HTTPClientBuilderTest) SetupSuite() { + t.builder = NewHTTPClientBuilder() +} + +func (t *HTTPClientBuilderTest) Test_SetTimeout() { + t.builder.SetTimeout(90) + + assert.Equal(t.T(), 90, t.builder.timeout) + assert.Equal(t.T(), 90, t.builder.httpClient.Timeout) +} + +func (t *HTTPClientBuilderTest) Test_SetMockAddress() { + addr := "http://mock.local:3004" + t.builder.SetMockAddress(addr) + + assert.Equal(t.T(), addr, t.builder.mockAddress) +} + +func (t *HTTPClientBuilderTest) Test_AddMockedDomain() { + domain := "example.com" + t.builder.AddMockedDomain(domain) + + assert.NotEmpty(t.T(), t.builder.mockedDomains) + assert.Equal(t.T(), domain, t.builder.mockedDomains[0]) +} + +func (t *HTTPClientBuilderTest) Test_SetMockedDomains() { + domains := []string{"example.com"} + t.builder.SetMockedDomains(domains) + + assert.NotEmpty(t.T(), t.builder.mockedDomains) + assert.Equal(t.T(), domains[0], t.builder.mockedDomains[0]) +} + +func (t *HTTPClientBuilderTest) Test_DisableSSLVerification() { + t.builder.DisableSSLVerification() + + assert.True(t.T(), t.builder.httpTransport.TLSClientConfig.InsecureSkipVerify) +} + +func (t *HTTPClientBuilderTest) Test_FromConfig() { + config := &HTTPClientConfig{ + SSLVerification: true, + MockAddress: "http://anothermock.local:3004", + MockedDomains: []string{"example.gov"}, + } + + t.builder.FromConfig(config) + + assert.Equal(t.T(), !config.SSLVerification, t.builder.httpTransport.TLSClientConfig.InsecureSkipVerify) + assert.Equal(t.T(), config.MockAddress, t.builder.mockAddress) + assert.Equal(t.T(), config.MockedDomains[0], t.builder.mockedDomains[0]) + assert.Equal(t.T(), config.Timeout*time.Second, t.builder.timeout) + assert.Equal(t.T(), config.Timeout*time.Second, t.builder.httpClient.Timeout) +} + +func (t *HTTPClientBuilderTest) Test_FromEngine() { + engine := &Engine{ + Config: Config{ + HTTPClientConfig: &HTTPClientConfig{ + SSLVerification: true, + MockAddress: "http://anothermock.local:3004", + MockedDomains: []string{"example.gov"}, + }, + Debug: false, + }, + } + + assert.Equal(t.T(), engine, t.builder.engine) +} + +func (t *HTTPClientBuilderTest) Test_buildDialer() { + t.builder.buildDialer() + + assert.NotNil(t.T(), t.builder.dialer) +} + +func (t *HTTPClientBuilderTest) Test_parseAddress() { + assert.NoError(t.T(), t.builder.parseAddress()) +} + +func (t *HTTPClientBuilderTest) Test_buildMocks() { + assert.NoError(t.T(), t.builder.buildMocks()) +} + +func (t *HTTPClientBuilderTest) Test_logf() { + defer func() { + assert.Nil(t.T(), recover()) + }() + + t.builder.logf("test %s", "string") +} + +func (t *HTTPClientBuilderTest) Test_Build() { + client, err := t.builder.Build(true) + + assert.NoError(t.T(), err) + assert.NotNil(t.T(), client) + assert.Equal(t.T(), client, http.DefaultClient) +} + +func (t *HTTPClientBuilderTest) Test_RestoreDefault() { + t.builder.RestoreDefault() + + assert.NotEqual(t.T(), http.DefaultClient, t.builder.httpClient) +} + +func Test_HTTPClientBuilder(t *testing.T) { + suite.Run(t, new(HTTPClientBuilderTest)) +} diff --git a/core/localizer_test.go b/core/localizer_test.go index e2e5251..b90151b 100644 --- a/core/localizer_test.go +++ b/core/localizer_test.go @@ -23,7 +23,7 @@ type LocalizerTest struct { localizer *Localizer } -func (l *LocalizerTest) SetupTest() { +func (l *LocalizerTest) SetupSuite() { if _, err := os.Stat(testTranslationsDir); err != nil && os.IsNotExist(err) { err := os.Mkdir(testTranslationsDir, os.ModePerm) require.Nil(l.T(), err) @@ -78,7 +78,7 @@ func (l *LocalizerTest) Test_BadRequestLocalized() { assert.Equal(l.T(), "Test message", resp.(ErrorResponse).Error) } -func (l *LocalizerTest) TearDownTest() { +func (l *LocalizerTest) TearDownSuite() { err := os.RemoveAll(testTranslationsDir) require.Nil(l.T(), err) } diff --git a/core/sentry_test.go b/core/sentry_test.go index 5f197e1..f0c87f3 100644 --- a/core/sentry_test.go +++ b/core/sentry_test.go @@ -24,7 +24,7 @@ type SentryTest struct { scalarTags *SentryTaggedScalar } -func (s *SentryTest) SetupTest() { +func (s *SentryTest) SetupSuite() { s.structTags = NewTaggedStruct(SampleStruct{}, "struct", map[string]string{"fake": "prop"}) s.scalarTags = NewTaggedScalar("", "scalar", "Scalar") require.Equal(s.T(), "struct", s.structTags.GetContextKey()) diff --git a/core/template_test.go b/core/template_test.go index 05e6e9f..c32dec7 100644 --- a/core/template_test.go +++ b/core/template_test.go @@ -23,7 +23,7 @@ type TemplateTest struct { renderer Renderer } -func (t *TemplateTest) SetupTest() { +func (t *TemplateTest) SetupSuite() { if _, err := os.Stat(testTemplatesDir); err != nil && os.IsNotExist(err) { err := os.Mkdir(testTemplatesDir, os.ModePerm) require.Nil(t.T(), err) diff --git a/core/utils_test.go b/core/utils_test.go index bcda9c7..b93dffb 100644 --- a/core/utils_test.go +++ b/core/utils_test.go @@ -29,7 +29,7 @@ func mgClient() *v1.MgClient { return v1.New(testMGURL, "token") } -func (u *UtilsTest) SetupTest() { +func (u *UtilsTest) SetupSuite() { logger := NewLogger("code", logging.DEBUG, DefaultLogFormatter()) awsConfig := ConfigAWS{ AccessKeyID: "access key id (will be removed)",