add proxy support to the HTTP client built by library

This commit is contained in:
Pavel 2024-06-06 14:36:24 +03:00 committed by GitHub
commit b6d2ee5d7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 24 deletions

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,6 +17,14 @@ import (
"github.com/retailcrm/mg-transport-core/v2/core/logger" "github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
const (
defaultDialerTimeout = 30 * time.Second
defaultIdleConnTimeout = 90 * time.Second
defaultTLSHandshakeTimeout = 10 * time.Second
defaultExpectContinueTimeout = 1 * time.Second
defaultMaxIdleConns = 100
)
// DefaultClient stores original http.DefaultClient. // DefaultClient stores original http.DefaultClient.
var DefaultClient = http.DefaultClient var DefaultClient = http.DefaultClient
@ -24,28 +33,29 @@ var DefaultTransport = http.DefaultTransport
// HTTPClientBuilder builds http client with mocks (if necessary) and timeout. // HTTPClientBuilder builds http client with mocks (if necessary) and timeout.
// Example: // Example:
// // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
// client, err := NewHTTPClientBuilder().
// SetTimeout(10).
// SetMockAddress("api_mock:3004").
// AddMockedDomain("google.com").
// SetSSLVerification(false).
// Build()
// //
// if err != nil { // // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
// fmt.Print(err) // client, err := NewHTTPClientBuilder().
// } // SetTimeout(10).
// SetMockAddress("api_mock:3004").
// AddMockedDomain("google.com").
// SetSSLVerification(false).
// Build()
// //
// // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate) // if err != nil {
// if resp, err := client.Get("https://google.com"); err == nil { // fmt.Print(err)
// if data, err := ioutil.ReadAll(resp.Body); err == nil { // }
// fmt.Printf("Data: %s", string(data)) //
// } else { // // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate)
// fmt.Print(err) // if resp, err := client.Get("https://google.com"); err == nil {
// } // if data, err := ioutil.ReadAll(resp.Body); err == nil {
// } else { // fmt.Printf("Data: %s", string(data))
// fmt.Print(err) // } else {
// } // fmt.Print(err)
// }
// } else {
// fmt.Print(err)
// }
type HTTPClientBuilder struct { type HTTPClientBuilder struct {
logger logger.Logger logger logger.Logger
httpClient *http.Client httpClient *http.Client
@ -64,11 +74,22 @@ type HTTPClientBuilder struct {
// NewHTTPClientBuilder returns HTTPClientBuilder with default values. // NewHTTPClientBuilder returns HTTPClientBuilder with default values.
func NewHTTPClientBuilder() *HTTPClientBuilder { func NewHTTPClientBuilder() *HTTPClientBuilder {
return &HTTPClientBuilder{ return &HTTPClientBuilder{
built: false, built: false,
httpClient: &http.Client{}, httpClient: &http.Client{},
httpTransport: &http.Transport{}, httpTransport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: defaultDialerTimeout,
KeepAlive: defaultDialerTimeout,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: defaultMaxIdleConns,
IdleConnTimeout: defaultIdleConnTimeout,
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
ExpectContinueTimeout: defaultExpectContinueTimeout,
},
tlsVersion: tls.VersionTLS12, tlsVersion: tls.VersionTLS12,
timeout: 30 * time.Second, timeout: defaultDialerTimeout,
mockAddress: "", mockAddress: "",
mockedDomains: []string{}, mockedDomains: []string{},
logging: false, logging: false,
@ -147,6 +168,11 @@ func (b *HTTPClientBuilder) SetLogging(flag bool) *HTTPClientBuilder {
return b return b
} }
func (b *HTTPClientBuilder) SetProxy(proxy func(*http.Request) (*url.URL, error)) *HTTPClientBuilder {
b.httpTransport.Proxy = proxy
return b
}
// FromConfig fulfills mock configuration from HTTPClientConfig. // FromConfig fulfills mock configuration from HTTPClientConfig.
func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder { func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder {
if config == nil { if config == nil {
@ -212,6 +238,7 @@ func (b *HTTPClientBuilder) buildMocks() error {
b.logf(" - %s\n", domain) b.logf(" - %s\n", domain)
} }
b.httpTransport.Proxy = nil
b.httpTransport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) { b.httpTransport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
var ( var (
host string host string

View File

@ -92,6 +92,13 @@ func (t *HTTPClientBuilderTest) Test_SetCertPool() {
assert.Equal(t.T(), pool, t.builder.httpTransport.TLSClientConfig.RootCAs) assert.Equal(t.T(), pool, t.builder.httpTransport.TLSClientConfig.RootCAs)
} }
func (t *HTTPClientBuilderTest) Test_SetProxy() {
t.builder.SetProxy(nil)
assert.Nil(t.T(), t.builder.httpTransport.Proxy)
t.builder.SetProxy(http.ProxyFromEnvironment)
assert.NotNil(t.T(), t.builder.httpTransport.Proxy)
}
func (t *HTTPClientBuilderTest) Test_FromConfigNil() { func (t *HTTPClientBuilderTest) Test_FromConfigNil() {
defer func() { defer func() {
assert.Nil(t.T(), recover()) assert.Nil(t.T(), recover())
@ -161,6 +168,7 @@ func (t *HTTPClientBuilderTest) Test_Build() {
assert.NoError(t.T(), err) assert.NoError(t.T(), err)
assert.NotNil(t.T(), client) assert.NotNil(t.T(), client)
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
assert.Equal(t.T(), client, http.DefaultClient) assert.Equal(t.T(), client, http.DefaultClient)
assert.Equal(t.T(), timeout*time.Second, client.Timeout) assert.Equal(t.T(), timeout*time.Second, client.Timeout)
assert.Equal(t.T(), pool, client.Transport.(*http.Transport).TLSClientConfig.RootCAs) assert.Equal(t.T(), pool, client.Transport.(*http.Transport).TLSClientConfig.RootCAs)
@ -290,6 +298,7 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
SetSSLVerification(false). SetSSLVerification(false).
Build() Build()
require.NoError(t.T(), err, "cannot build client") require.NoError(t.T(), err, "cannot build client")
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
resp, err := client.Get(mockProto + mockDomainAddr) resp, err := client.Get(mockProto + mockDomainAddr)
if err != nil && strings.Contains(err.Error(), "connection refused") { if err != nil && strings.Contains(err.Error(), "connection refused") {
@ -314,6 +323,7 @@ func (t *HTTPClientBuilderTest) Test_UseTLS10() {
t.Require().NotNil(client.Transport) t.Require().NotNil(client.Transport)
t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig) t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig)
t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion) t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion)
t.Assert().NotNil(client.Transport.(*http.Transport).Proxy)
} }
// taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go // taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go