add proxy support to the HTTP client built by library

This change restores default *http.Transport behavior for the client built by the library.
This commit is contained in:
Pavel 2024-06-06 14:06:15 +03:00
parent 584c5ed306
commit 655b86abb0
2 changed files with 52 additions and 23 deletions

View File

@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/pkg/errors"
@ -24,28 +25,29 @@ var DefaultTransport = http.DefaultTransport
// HTTPClientBuilder builds http client with mocks (if necessary) and timeout.
// Example:
// // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
// client, err := NewHTTPClientBuilder().
// SetTimeout(10).
// SetMockAddress("api_mock:3004").
// AddMockedDomain("google.com").
// SetSSLVerification(false).
// Build()
//
// if err != nil {
// fmt.Print(err)
// }
// // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
// client, err := NewHTTPClientBuilder().
// SetTimeout(10).
// SetMockAddress("api_mock:3004").
// AddMockedDomain("google.com").
// SetSSLVerification(false).
// Build()
//
// // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate)
// if resp, err := client.Get("https://google.com"); err == nil {
// if data, err := ioutil.ReadAll(resp.Body); err == nil {
// fmt.Printf("Data: %s", string(data))
// } else {
// fmt.Print(err)
// }
// } else {
// fmt.Print(err)
// }
// if err != nil {
// fmt.Print(err)
// }
//
// // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate)
// if resp, err := client.Get("https://google.com"); err == nil {
// if data, err := ioutil.ReadAll(resp.Body); err == nil {
// fmt.Printf("Data: %s", string(data))
// } else {
// fmt.Print(err)
// }
// } else {
// fmt.Print(err)
// }
type HTTPClientBuilder struct {
logger logger.Logger
httpClient *http.Client
@ -64,9 +66,20 @@ type HTTPClientBuilder struct {
// NewHTTPClientBuilder returns HTTPClientBuilder with default values.
func NewHTTPClientBuilder() *HTTPClientBuilder {
return &HTTPClientBuilder{
built: false,
httpClient: &http.Client{},
httpTransport: &http.Transport{},
built: false,
httpClient: &http.Client{},
httpTransport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
tlsVersion: tls.VersionTLS12,
timeout: 30 * time.Second,
mockAddress: "",
@ -147,6 +160,11 @@ func (b *HTTPClientBuilder) SetLogging(flag bool) *HTTPClientBuilder {
return b
}
func (b *HTTPClientBuilder) SetProxy(proxy func(*http.Request) (*url.URL, error)) *HTTPClientBuilder {
b.httpTransport.Proxy = proxy
return b
}
// FromConfig fulfills mock configuration from HTTPClientConfig.
func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder {
if config == nil {
@ -212,6 +230,7 @@ func (b *HTTPClientBuilder) buildMocks() error {
b.logf(" - %s\n", domain)
}
b.httpTransport.Proxy = nil
b.httpTransport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
var (
host string

View File

@ -92,6 +92,13 @@ func (t *HTTPClientBuilderTest) Test_SetCertPool() {
assert.Equal(t.T(), pool, t.builder.httpTransport.TLSClientConfig.RootCAs)
}
func (t *HTTPClientBuilderTest) Test_SetProxy() {
t.builder.SetProxy(nil)
assert.Nil(t.T(), t.builder.httpTransport.Proxy)
t.builder.SetProxy(http.ProxyFromEnvironment)
assert.NotNil(t.T(), t.builder.httpTransport.Proxy)
}
func (t *HTTPClientBuilderTest) Test_FromConfigNil() {
defer func() {
assert.Nil(t.T(), recover())
@ -161,6 +168,7 @@ func (t *HTTPClientBuilderTest) Test_Build() {
assert.NoError(t.T(), err)
assert.NotNil(t.T(), client)
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
assert.Equal(t.T(), client, http.DefaultClient)
assert.Equal(t.T(), timeout*time.Second, client.Timeout)
assert.Equal(t.T(), pool, client.Transport.(*http.Transport).TLSClientConfig.RootCAs)
@ -290,6 +298,7 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
SetSSLVerification(false).
Build()
require.NoError(t.T(), err, "cannot build client")
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
resp, err := client.Get(mockProto + mockDomainAddr)
if err != nil && strings.Contains(err.Error(), "connection refused") {
@ -314,6 +323,7 @@ func (t *HTTPClientBuilderTest) Test_UseTLS10() {
t.Require().NotNil(client.Transport)
t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig)
t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion)
t.Assert().NotNil(client.Transport.(*http.Transport).Proxy)
}
// taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go