mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2024-11-21 20:56:04 +03:00
add proxy support to the HTTP client built by library
This commit is contained in:
commit
b6d2ee5d7c
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -16,6 +17,14 @@ import (
|
||||
"github.com/retailcrm/mg-transport-core/v2/core/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDialerTimeout = 30 * time.Second
|
||||
defaultIdleConnTimeout = 90 * time.Second
|
||||
defaultTLSHandshakeTimeout = 10 * time.Second
|
||||
defaultExpectContinueTimeout = 1 * time.Second
|
||||
defaultMaxIdleConns = 100
|
||||
)
|
||||
|
||||
// DefaultClient stores original http.DefaultClient.
|
||||
var DefaultClient = http.DefaultClient
|
||||
|
||||
@ -24,28 +33,29 @@ var DefaultTransport = http.DefaultTransport
|
||||
|
||||
// HTTPClientBuilder builds http client with mocks (if necessary) and timeout.
|
||||
// Example:
|
||||
// // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
|
||||
// client, err := NewHTTPClientBuilder().
|
||||
// SetTimeout(10).
|
||||
// SetMockAddress("api_mock:3004").
|
||||
// AddMockedDomain("google.com").
|
||||
// SetSSLVerification(false).
|
||||
// Build()
|
||||
//
|
||||
// if err != nil {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
// // Build HTTP client with timeout = 10 sec, without SSL certificates verification and with mocked google.com
|
||||
// client, err := NewHTTPClientBuilder().
|
||||
// SetTimeout(10).
|
||||
// SetMockAddress("api_mock:3004").
|
||||
// AddMockedDomain("google.com").
|
||||
// SetSSLVerification(false).
|
||||
// Build()
|
||||
//
|
||||
// // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate)
|
||||
// if resp, err := client.Get("https://google.com"); err == nil {
|
||||
// if data, err := ioutil.ReadAll(resp.Body); err == nil {
|
||||
// fmt.Printf("Data: %s", string(data))
|
||||
// } else {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
// } else {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
// if err != nil {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
//
|
||||
// // Actual response will be returned from "api_mock:3004" (it should provide any ssl certificate)
|
||||
// if resp, err := client.Get("https://google.com"); err == nil {
|
||||
// if data, err := ioutil.ReadAll(resp.Body); err == nil {
|
||||
// fmt.Printf("Data: %s", string(data))
|
||||
// } else {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
// } else {
|
||||
// fmt.Print(err)
|
||||
// }
|
||||
type HTTPClientBuilder struct {
|
||||
logger logger.Logger
|
||||
httpClient *http.Client
|
||||
@ -64,11 +74,22 @@ type HTTPClientBuilder struct {
|
||||
// NewHTTPClientBuilder returns HTTPClientBuilder with default values.
|
||||
func NewHTTPClientBuilder() *HTTPClientBuilder {
|
||||
return &HTTPClientBuilder{
|
||||
built: false,
|
||||
httpClient: &http.Client{},
|
||||
httpTransport: &http.Transport{},
|
||||
built: false,
|
||||
httpClient: &http.Client{},
|
||||
httpTransport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: defaultDialerTimeout,
|
||||
KeepAlive: defaultDialerTimeout,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: defaultMaxIdleConns,
|
||||
IdleConnTimeout: defaultIdleConnTimeout,
|
||||
TLSHandshakeTimeout: defaultTLSHandshakeTimeout,
|
||||
ExpectContinueTimeout: defaultExpectContinueTimeout,
|
||||
},
|
||||
tlsVersion: tls.VersionTLS12,
|
||||
timeout: 30 * time.Second,
|
||||
timeout: defaultDialerTimeout,
|
||||
mockAddress: "",
|
||||
mockedDomains: []string{},
|
||||
logging: false,
|
||||
@ -147,6 +168,11 @@ func (b *HTTPClientBuilder) SetLogging(flag bool) *HTTPClientBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *HTTPClientBuilder) SetProxy(proxy func(*http.Request) (*url.URL, error)) *HTTPClientBuilder {
|
||||
b.httpTransport.Proxy = proxy
|
||||
return b
|
||||
}
|
||||
|
||||
// FromConfig fulfills mock configuration from HTTPClientConfig.
|
||||
func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder {
|
||||
if config == nil {
|
||||
@ -212,6 +238,7 @@ func (b *HTTPClientBuilder) buildMocks() error {
|
||||
b.logf(" - %s\n", domain)
|
||||
}
|
||||
|
||||
b.httpTransport.Proxy = nil
|
||||
b.httpTransport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
|
||||
var (
|
||||
host string
|
||||
|
@ -92,6 +92,13 @@ func (t *HTTPClientBuilderTest) Test_SetCertPool() {
|
||||
assert.Equal(t.T(), pool, t.builder.httpTransport.TLSClientConfig.RootCAs)
|
||||
}
|
||||
|
||||
func (t *HTTPClientBuilderTest) Test_SetProxy() {
|
||||
t.builder.SetProxy(nil)
|
||||
assert.Nil(t.T(), t.builder.httpTransport.Proxy)
|
||||
t.builder.SetProxy(http.ProxyFromEnvironment)
|
||||
assert.NotNil(t.T(), t.builder.httpTransport.Proxy)
|
||||
}
|
||||
|
||||
func (t *HTTPClientBuilderTest) Test_FromConfigNil() {
|
||||
defer func() {
|
||||
assert.Nil(t.T(), recover())
|
||||
@ -161,6 +168,7 @@ func (t *HTTPClientBuilderTest) Test_Build() {
|
||||
|
||||
assert.NoError(t.T(), err)
|
||||
assert.NotNil(t.T(), client)
|
||||
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
|
||||
assert.Equal(t.T(), client, http.DefaultClient)
|
||||
assert.Equal(t.T(), timeout*time.Second, client.Timeout)
|
||||
assert.Equal(t.T(), pool, client.Transport.(*http.Transport).TLSClientConfig.RootCAs)
|
||||
@ -290,6 +298,7 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
|
||||
SetSSLVerification(false).
|
||||
Build()
|
||||
require.NoError(t.T(), err, "cannot build client")
|
||||
assert.Nil(t.T(), client.Transport.(*http.Transport).Proxy)
|
||||
|
||||
resp, err := client.Get(mockProto + mockDomainAddr)
|
||||
if err != nil && strings.Contains(err.Error(), "connection refused") {
|
||||
@ -314,6 +323,7 @@ func (t *HTTPClientBuilderTest) Test_UseTLS10() {
|
||||
t.Require().NotNil(client.Transport)
|
||||
t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig)
|
||||
t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion)
|
||||
t.Assert().NotNil(client.Transport.(*http.Transport).Proxy)
|
||||
}
|
||||
|
||||
// taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go
|
||||
|
Loading…
Reference in New Issue
Block a user