diff --git a/go.mod b/go.mod index 927ab2a..72525c2 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,12 @@ go 1.22 require ( github.com/google/go-querystring v1.0.0 - github.com/maypok86/otter v1.0.0 github.com/stretchr/testify v1.8.1 gopkg.in/h2non/gock.v1 v1.1.2 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dolthub/maphash v0.1.0 // indirect - github.com/dolthub/swiss v0.2.1 // indirect - github.com/gammazero/deque v0.2.1 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index e5cbe5e..fd64c2d 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= -github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= -github.com/dolthub/swiss v0.2.1 h1:gs2osYs5SJkAaH5/ggVJqXQxRXtWshF6uE0lgR/Y3Gw= -github.com/dolthub/swiss v0.2.1/go.mod h1:8AhKZZ1HK7g18j7v7k6c5cYIGEZJcPn0ARsai8cUrh0= -github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= -github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= -github.com/maypok86/otter v1.0.0 h1:nP13eaFQrfRQHD1vxEgdlqR9gLHvfW2VcS0hFitglIY= -github.com/maypok86/otter v1.0.0/go.mod h1:koSPT30yWtqMNrFohaywMlgSHCuUg6IVqeDerwIM/Mg= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/v1/client.go b/v1/client.go index afa6c47..273dc44 100644 --- a/v1/client.go +++ b/v1/client.go @@ -35,6 +35,12 @@ func (c *MgClient) WithLogger(logger BasicLogger) *MgClient { return c } +// WithLimiter sets the provided limiter instance into the Client. +func (c *MgClient) WithLimiter(limiter *TokensBucket) *MgClient { + c.limiter = limiter + return c +} + // writeLog writes a message to the log. func (c *MgClient) writeLog(format string, v ...interface{}) { if c.logger != nil { diff --git a/v1/rate_limit.go b/v1/rate_limit.go new file mode 100644 index 0000000..3d4fb8b --- /dev/null +++ b/v1/rate_limit.go @@ -0,0 +1,81 @@ +package v1 + +import ( + "runtime" + "sync" + "sync/atomic" + "time" +) + +type token struct { + rps uint32 + lastUse time.Time +} + +type TokensBucket struct { + maxRPS uint32 + mux sync.Mutex + tokens map[string]*token + unusedTokenTime time.Duration + checkTokenTime time.Duration + cancel atomic.Bool +} + +func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) *TokensBucket { + bucket := &TokensBucket{ + maxRPS: maxRPS, + tokens: map[string]*token{}, + unusedTokenTime: unusedTokenTime, + checkTokenTime: checkTokenTime, + } + + go bucket.deleteUnusedToken() + runtime.SetFinalizer(bucket, destructBasket) + return bucket +} + +func (m *TokensBucket) Obtain(id string) { + m.mux.Lock() + defer m.mux.Unlock() + + if _, ok := m.tokens[id]; !ok { + m.tokens[id] = &token{ + lastUse: time.Now(), + rps: 1, + } + return + } + + sleepTime := time.Second - time.Since(m.tokens[id].lastUse) + if sleepTime < 0 { + m.tokens[id].lastUse = time.Now() + m.tokens[id].rps = 0 + } else if m.tokens[id].rps >= m.maxRPS { + time.Sleep(sleepTime) + m.tokens[id].lastUse = time.Now() + m.tokens[id].rps = 0 + } + m.tokens[id].rps++ +} + +func destructBasket(m *TokensBucket) { + m.cancel.Store(true) +} + +func (m *TokensBucket) deleteUnusedToken() { + for { + if m.cancel.Load() { + return + } + m.mux.Lock() + + for id, token := range m.tokens { + if time.Since(token.lastUse) >= m.unusedTokenTime { + delete(m.tokens, id) + } + } + m.mux.Unlock() + + time.Sleep(m.checkTokenTime) + } +} diff --git a/v1/request.go b/v1/request.go index 74f2c7a..74bc0d1 100644 --- a/v1/request.go +++ b/v1/request.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "strings" - "time" ) const MaxRPS = 100 @@ -53,6 +52,12 @@ func (c *MgClient) DeleteRequest(url string, parameters []byte) ([]byte, int, er ) } +func (c *MgClient) WaitForRateLimit() { + if c.limiter != nil && c.Token != "" { + c.limiter.Obtain(c.Token) + } +} + func makeRequest(reqType, url string, buf io.Reader, c *MgClient) ([]byte, int, error) { var res []byte req, err := http.NewRequest(reqType, url, buf) @@ -63,22 +68,9 @@ func makeRequest(reqType, url string, buf io.Reader, c *MgClient) ([]byte, int, req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Transport-Token", c.Token) - defer c.mux.Unlock() - c.mux.Lock() - attempt := 0 tryAgain: - sleepTime := time.Second - time.Since(c.lastTime) - if sleepTime < 0 { - c.lastTime = time.Now() - c.rps = 0 - } else if c.rps == MaxRPS { - time.Sleep(sleepTime) - c.lastTime = time.Now() - c.rps = 0 - } - c.rps++ - + c.WaitForRateLimit() if c.Debug { if strings.Contains(url, "/files/upload") { c.writeLog("MG TRANSPORT API Request: %s %s %s [file data]", reqType, url, c.Token) diff --git a/v1/storage.go b/v1/storage.go deleted file mode 100644 index 55495ab..0000000 --- a/v1/storage.go +++ /dev/null @@ -1,45 +0,0 @@ -package v1 - -import ( - "errors" - "time" - - "github.com/maypok86/otter" -) - -const mgClientCacheTTL = time.Hour * 1 - -var ErrNegativeCapacity = errors.New("capacity cannot be less than 1") - -type MGClientPool struct { - cache *otter.CacheWithVariableTTL[string, *MgClient] -} - -// NewMGClientPool initializes the client cache. -func NewMGClientPool(capacity int) (*MGClientPool, error) { - if capacity <= 0 { - return nil, ErrNegativeCapacity - } - - cache, _ := otter.MustBuilder[string, *MgClient](capacity).WithVariableTTL().Build() - return &MGClientPool{cache: &cache}, nil -} - -func (m *MGClientPool) Get(token string, url string) *MgClient { - if client, ok := m.cache.Get(token); ok { - return client - } - - client := New(url, token) - m.cache.Set(token, client, mgClientCacheTTL) - - return client -} - -func (m *MGClientPool) Remove(token string) { - m.cache.Delete(token) -} - -func (m *MGClientPool) Close() { - m.cache.Close() -} diff --git a/v1/storage_test.go b/v1/storage_test.go deleted file mode 100644 index ca41c28..0000000 --- a/v1/storage_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package v1 - -import ( - "testing" - - "github.com/stretchr/testify/suite" -) - -type StorageTest struct { - suite.Suite -} - -func TestStorage(t *testing.T) { - suite.Run(t, new(StorageTest)) -} - -func (t *StorageTest) Test_MGClientPool() { - clientPool, err := NewMGClientPool(1) - t.Assert().NoError(err) - - client := clientPool.Get("test_token", "test_url") - t.Assert().Equal("test_url", client.URL) - - clientPool.Remove("test_token") - clientPool.Close() -} - -func (t *StorageTest) Test_NegativeCapacity() { - _, err := NewMGClientPool(-1) - t.Assert().Equal(ErrNegativeCapacity.Error(), err.Error()) -} diff --git a/v1/types.go b/v1/types.go index d24385a..a03adcc 100644 --- a/v1/types.go +++ b/v1/types.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "sync" "time" ) @@ -79,14 +78,12 @@ const ( // MgClient type. type MgClient struct { - URL string `json:"url"` - Token string `json:"token"` - Debug bool `json:"debug"` - httpClient *http.Client `json:"-"` - logger BasicLogger `json:"-"` - mux sync.Mutex `json:"-"` - lastTime time.Time `json:"-"` - rps int `json:"-"` + URL string `json:"url"` + Token string `json:"token"` + Debug bool `json:"debug"` + httpClient *http.Client `json:"-"` + logger BasicLogger `json:"-"` + limiter *TokensBucket `json:"-"` } // Channel type.