diff --git a/go.mod b/go.mod index 72525c2..3d72865 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/google/go-querystring v1.0.0 + github.com/jonboulle/clockwork v0.4.0 github.com/stretchr/testify v1.8.1 gopkg.in/h2non/gock.v1 v1.1.2 ) @@ -12,5 +13,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fd64c2d..4ebeec6 100644 --- a/go.sum +++ b/go.sum @@ -5,12 +5,15 @@ github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASu github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= diff --git a/v1/rate_limit.go b/v1/rate_limit.go index 3d4fb8b..48e99f0 100644 --- a/v1/rate_limit.go +++ b/v1/rate_limit.go @@ -8,25 +8,25 @@ import ( ) type token struct { - rps uint32 - lastUse time.Time + rps atomic.Uint32 + lastUse atomic.Value } type TokensBucket struct { maxRPS uint32 - mux sync.Mutex - tokens map[string]*token + tokens sync.Map unusedTokenTime time.Duration checkTokenTime time.Duration cancel atomic.Bool + sleep sleeper } func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) *TokensBucket { bucket := &TokensBucket{ maxRPS: maxRPS, - tokens: map[string]*token{}, unusedTokenTime: unusedTokenTime, checkTokenTime: checkTokenTime, + sleep: realSleeper{}, } go bucket.deleteUnusedToken() @@ -35,27 +35,26 @@ func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duratio } func (m *TokensBucket) Obtain(id string) { - m.mux.Lock() - defer m.mux.Unlock() - - if _, ok := m.tokens[id]; !ok { - m.tokens[id] = &token{ - lastUse: time.Now(), - rps: 1, - } + val, ok := m.tokens.Load(id) + if !ok { + token := &token{} + token.lastUse.Store(time.Now()) + token.rps.Store(1) + m.tokens.Store(id, token) return } - sleepTime := time.Second - time.Since(m.tokens[id].lastUse) - if sleepTime < 0 { - m.tokens[id].lastUse = time.Now() - m.tokens[id].rps = 0 - } else if m.tokens[id].rps >= m.maxRPS { - time.Sleep(sleepTime) - m.tokens[id].lastUse = time.Now() - m.tokens[id].rps = 0 + token := val.(*token) + sleepTime := time.Second - time.Since(token.lastUse.Load().(time.Time)) + if sleepTime <= 0 { + token.lastUse.Store(time.Now()) + token.rps.Store(0) + } else if token.rps.Load() >= m.maxRPS { + m.sleep.Sleep(sleepTime) + token.lastUse.Store(time.Now()) + token.rps.Store(0) } - m.tokens[id].rps++ + token.rps.Add(1) } func destructBasket(m *TokensBucket) { @@ -67,15 +66,25 @@ func (m *TokensBucket) deleteUnusedToken() { if m.cancel.Load() { return } - m.mux.Lock() - for id, token := range m.tokens { - if time.Since(token.lastUse) >= m.unusedTokenTime { - delete(m.tokens, id) + m.tokens.Range(func(key, value any) bool { + id, token := key.(string), value.(*token) + if time.Since(token.lastUse.Load().(time.Time)) >= m.unusedTokenTime { + m.tokens.Delete(id) } - } - m.mux.Unlock() + return false + }) - time.Sleep(m.checkTokenTime) + m.sleep.Sleep(m.checkTokenTime) } } + +type sleeper interface { + Sleep(time.Duration) +} + +type realSleeper struct{} + +func (s realSleeper) Sleep(d time.Duration) { + time.Sleep(d) +} diff --git a/v1/rate_limit_test.go b/v1/rate_limit_test.go new file mode 100644 index 0000000..7e9ab72 --- /dev/null +++ b/v1/rate_limit_test.go @@ -0,0 +1,83 @@ +package v1 + +import ( + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/suite" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +type TokensBucketTest struct { + suite.Suite +} + +func TestTokensBucket(t *testing.T) { + suite.Run(t, new(TokensBucketTest)) +} + +func (t *TokensBucketTest) Test_NewTokensBucket() { + t.Assert().NotNil(NewTokensBucket(10, time.Hour, time.Hour)) +} + +func (t *TokensBucketTest) new( + maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket { + bucket := &TokensBucket{ + maxRPS: maxRPS, + unusedTokenTime: unusedTokenTime, + checkTokenTime: checkTokenTime, + sleep: sleeper, + } + runtime.SetFinalizer(bucket, destructBasket) + return bucket +} + +func (t *TokensBucketTest) Test_Obtain_NoThrottle() { + tb := t.new(100, time.Hour, time.Minute, &realSleeper{}) + start := time.Now() + for i := 0; i < 100; i++ { + tb.Obtain("a") + } + t.Assert().True(time.Since(start) < time.Second) // check that rate limiter did not perform throttle. +} + +func (t *TokensBucketTest) Test_Obtain_Sleep() { + clock := &fakeSleeper{} + tb := t.new(100, time.Hour, time.Minute, clock) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + for i := 0; i < 301; i++ { + tb.Obtain("a") + } + wg.Done() + }() + + wg.Wait() + t.Assert().Equal(3, int(clock.total.Load())) +} + +func (t *TokensBucketTest) Test_Obtain_AddRPS() { + clock := clockwork.NewFakeClock() + tb := t.new(100, time.Hour, time.Minute, clock) + go tb.deleteUnusedToken() + tb.Obtain("a") + clock.Advance(time.Minute * 2) + + item, found := tb.tokens.Load("a") + t.Require().True(found) + t.Assert().Equal(1, int(item.(*token).rps.Load())) + tb.Obtain("a") + t.Assert().Equal(2, int(item.(*token).rps.Load())) +} + +type fakeSleeper struct { + total atomic.Uint32 +} + +func (s *fakeSleeper) Sleep(time.Duration) { + s.total.Add(1) +}