
optimize rate limiter for heavier load

Pavel 2024-12-06 19:03:33 +03:00
parent 0312ddcdd2
commit 5fb9f0f895
2 changed files with 94 additions and 51 deletions
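
For orientation before the diff: the limiter's public surface is just NewTokensBucket and Obtain. A minimal usage sketch, written as if it lived inside the same package (this page does not show the import path, and the snippet itself is not part of the commit):

	// Hypothetical caller: allow up to 100 Obtain calls per key per
	// one-second window; calls beyond that budget block until the window rolls.
	func exampleUsage() {
		limiter := NewTokensBucket(100, time.Hour, time.Minute)
		for i := 0; i < 250; i++ {
			limiter.Obtain("client-42") // sleeps once 100 hits land in the same second
		}
	}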


@@ -1,100 +1,132 @@
 package v1
 
 import (
+	"hash/fnv"
 	"runtime"
 	"sync"
 	"sync/atomic"
 	"time"
 )
 
-// NoopLimiter implements Limiter but doesn't limit anything.
 var NoopLimiter Limiter = &noopLimiter{}
 
 type token struct {
-	rps     atomic.Uint32
-	lastUse atomic.Value
+	rps     uint32
+	lastUse int64 // Unix timestamp in nanoseconds
 }
 
-// Limiter implements some form of rate limiting.
+// Limiter interface for rate-limiting.
 type Limiter interface {
-	// Obtain the right to send a request. Should lock the execution if current goroutine needs to wait.
-	Obtain(string)
+	Obtain(id string)
 }
 
-// TokensBucket implements basic Limiter with fixed window and fixed amount of tokens per window.
+// TokensBucket implements a sharded rate limiter with fixed window and tokens.
 type TokensBucket struct {
 	maxRPS          uint32
-	tokens          sync.Map
-	unusedTokenTime time.Duration
+	unusedTokenTime int64 // in nanoseconds
 	checkTokenTime  time.Duration
+	shards          []*tokenShard
+	shardCount      uint32
 	cancel          atomic.Bool
 	sleep           sleeper
 }
 
-// NewTokensBucket constructs TokensBucket with provided parameters.
+type tokenShard struct {
+	tokens map[string]*token
+	mu     sync.Mutex
+}
+
+// NewTokensBucket creates a sharded token bucket limiter.
 func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) Limiter {
+	shardCount := uint32(runtime.NumCPU() * 2) // Use double the CPU count for sharding
+	shards := make([]*tokenShard, shardCount)
+	for i := range shards {
+		shards[i] = &tokenShard{tokens: make(map[string]*token)}
+	}
 	bucket := &TokensBucket{
 		maxRPS:          maxRPS,
-		unusedTokenTime: unusedTokenTime,
+		unusedTokenTime: unusedTokenTime.Nanoseconds(),
 		checkTokenTime:  checkTokenTime,
+		shards:          shards,
+		shardCount:      shardCount,
 		sleep:           realSleeper{},
 	}
-	go bucket.deleteUnusedToken()
-	runtime.SetFinalizer(bucket, destructBasket)
+	go bucket.cleanupRoutine()
+	runtime.SetFinalizer(bucket, destructBucket)
 	return bucket
 }
 
-// Obtain request hit. Will throttle RPS.
 func (m *TokensBucket) Obtain(id string) {
-	val, ok := m.tokens.Load(id)
-	if !ok {
-		token := &token{}
-		token.lastUse.Store(time.Now())
-		token.rps.Store(1)
-		m.tokens.Store(id, token)
+	shard := m.getShard(id)
+
+	shard.mu.Lock()
+	defer shard.mu.Unlock()
+
+	item, exists := shard.tokens[id]
+	now := time.Now().UnixNano()
+	if !exists {
+		shard.tokens[id] = &token{
+			rps:     1,
+			lastUse: now,
+		}
 		return
 	}
 
-	token := val.(*token)
-	sleepTime := time.Second - time.Since(token.lastUse.Load().(time.Time))
+	sleepTime := int64(time.Second) - (now - item.lastUse)
 	if sleepTime <= 0 {
-		token.lastUse.Store(time.Now())
-		token.rps.Store(0)
-	} else if token.rps.Load() >= m.maxRPS {
-		m.sleep.Sleep(sleepTime)
-		token.lastUse.Store(time.Now())
-		token.rps.Store(0)
+		item.lastUse = now
+		atomic.StoreUint32(&item.rps, 1)
+	} else if atomic.LoadUint32(&item.rps) >= m.maxRPS {
+		m.sleep.Sleep(time.Duration(sleepTime))
+		item.lastUse = time.Now().UnixNano()
+		atomic.StoreUint32(&item.rps, 1)
+	} else {
+		atomic.AddUint32(&item.rps, 1)
 	}
-
-	token.rps.Add(1)
 }
 
-func destructBasket(m *TokensBucket) {
-	m.cancel.Store(true)
+func (m *TokensBucket) getShard(id string) *tokenShard {
+	hash := fnv.New32a()
+	_, _ = hash.Write([]byte(id))
+	return m.shards[hash.Sum32()%m.shardCount]
 }
 
-func (m *TokensBucket) deleteUnusedToken() {
+func (m *TokensBucket) cleanupRoutine() {
+	ticker := time.NewTicker(m.checkTokenTime)
+	defer ticker.Stop()
 	for {
-		if m.cancel.Load() {
-			return
-		}
-
-		m.tokens.Range(func(key, value any) bool {
-			id, token := key.(string), value.(*token)
-			if time.Since(token.lastUse.Load().(time.Time)) >= m.unusedTokenTime {
-				m.tokens.Delete(id)
+		select {
+		case <-ticker.C:
+			if m.cancel.Load() {
+				return
 			}
-			return false
-		})
-
-		m.sleep.Sleep(m.checkTokenTime)
+			now := time.Now().UnixNano()
+			for _, shard := range m.shards {
+				shard.mu.Lock()
+				for id, token := range shard.tokens {
+					if now-token.lastUse >= m.unusedTokenTime {
+						delete(shard.tokens, id)
+					}
+				}
+				shard.mu.Unlock()
+			}
+		}
 	}
 }
 
+func destructBucket(m *TokensBucket) {
+	m.cancel.Store(true)
+}
+
 type noopLimiter struct{}
 
 func (l *noopLimiter) Obtain(string) {}
 
-// sleeper sleeps. This thing is necessary for tests.
 type sleeper interface {
 	Sleep(time.Duration)
 }
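
The heart of the change replaces one global sync.Map with a fixed set of shards selected by an FNV-1a hash of the key, each shard owning a plain map behind its own mutex, so goroutines working on different keys rarely contend on the same lock. A self-contained sketch of that pattern under illustrative names (shardedCounter is not from the commit):

	package main

	import (
		"fmt"
		"hash/fnv"
		"runtime"
		"sync"
	)

	// shard pairs a map with its own lock; goroutines touching
	// different shards never block each other.
	type shard struct {
		mu     sync.Mutex
		counts map[string]int
	}

	type shardedCounter struct {
		shards []*shard
	}

	func newShardedCounter() *shardedCounter {
		n := runtime.NumCPU() * 2 // same heuristic the commit uses
		s := &shardedCounter{shards: make([]*shard, n)}
		for i := range s.shards {
			s.shards[i] = &shard{counts: make(map[string]int)}
		}
		return s
	}

	// shardFor picks a shard by FNV-1a hash of the key, mirroring getShard above.
	func (s *shardedCounter) shardFor(key string) *shard {
		h := fnv.New32a()
		_, _ = h.Write([]byte(key))
		return s.shards[h.Sum32()%uint32(len(s.shards))]
	}

	func (s *shardedCounter) inc(key string) {
		sh := s.shardFor(key)
		sh.mu.Lock()
		sh.counts[key]++
		sh.mu.Unlock()
	}

	func main() {
		c := newShardedCounter()
		var wg sync.WaitGroup
		for i := 0; i < 8; i++ {
			wg.Add(1)
			go func(id int) {
				defer wg.Done()
				for j := 0; j < 1000; j++ {
					c.inc(fmt.Sprintf("client-%d", id))
				}
			}(i)
		}
		wg.Wait()
		fmt.Println(c.shardFor("client-0").counts["client-0"]) // prints 1000
	}

FNV-1a is a reasonable fit for shard selection: it is cheap, dependency-free, and spreads short string keys evenly across a small shard count.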


@@ -24,13 +24,22 @@ func (t *TokensBucketTest) Test_NewTokensBucket() {
 
 func (t *TokensBucketTest) new(
 	maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket {
+	shardCount := uint32(runtime.NumCPU() * 2) // Use double the CPU count for sharding
+	shards := make([]*tokenShard, shardCount)
+	for i := range shards {
+		shards[i] = &tokenShard{tokens: make(map[string]*token)}
+	}
 	bucket := &TokensBucket{
 		maxRPS:          maxRPS,
-		unusedTokenTime: unusedTokenTime,
+		unusedTokenTime: unusedTokenTime.Nanoseconds(),
 		checkTokenTime:  checkTokenTime,
+		shards:          shards,
+		shardCount:      shardCount,
 		sleep:           sleeper,
 	}
-	runtime.SetFinalizer(bucket, destructBasket)
+	runtime.SetFinalizer(bucket, destructBucket)
 	return bucket
 }
@@ -46,12 +55,14 @@ func (t *TokensBucketTest) Test_Obtain_NoThrottle() {
 func (t *TokensBucketTest) Test_Obtain_Sleep() {
 	clock := &fakeSleeper{}
 	tb := t.new(100, time.Hour, time.Minute, clock)
+	_, exists := tb.getShard("w").tokens["w"]
+	t.Require().False(exists)
 
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
 		for i := 0; i < 301; i++ {
-			tb.Obtain("a")
+			tb.Obtain("w")
 		}
 		wg.Done()
 	}()
@@ -63,15 +74,15 @@ func (t *TokensBucketTest) Test_Obtain_Sleep() {
 func (t *TokensBucketTest) Test_Obtain_AddRPS() {
 	clock := clockwork.NewFakeClock()
 	tb := t.new(100, time.Hour, time.Minute, clock)
-	go tb.deleteUnusedToken()
+	go tb.cleanupRoutine()
 
 	tb.Obtain("a")
 	clock.Advance(time.Minute * 2)
-	item, found := tb.tokens.Load("a")
+	item, found := tb.getShard("a").tokens["a"]
 	t.Require().True(found)
-	t.Assert().Equal(1, int(item.(*token).rps.Load()))
+	t.Assert().Equal(1, int(item.rps))
 
 	tb.Obtain("a")
-	t.Assert().Equal(2, int(item.(*token).rps.Load()))
+	t.Assert().Equal(2, int(item.rps))
 }
 
 type fakeSleeper struct {
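
These tests can drive the limiter without real waiting because all sleeping goes through the package's small sleeper interface. A self-contained sketch of that test seam; fakeSleeper here is a guess at the shape of the repo's own type, which is only partially visible above:

	package main

	import (
		"fmt"
		"time"
	)

	// sleeper abstracts time.Sleep so tests can observe sleeps instead of waiting.
	type sleeper interface {
		Sleep(d time.Duration)
	}

	type realSleeper struct{}

	func (realSleeper) Sleep(d time.Duration) { time.Sleep(d) }

	// fakeSleeper records requested sleeps and returns immediately.
	type fakeSleeper struct {
		calls []time.Duration
	}

	func (f *fakeSleeper) Sleep(d time.Duration) { f.calls = append(f.calls, d) }

	// throttle sleeps via the injected sleeper; production code passes
	// realSleeper{}, tests pass *fakeSleeper and assert on recorded durations.
	func throttle(s sleeper, d time.Duration) { s.Sleep(d) }

	func main() {
		f := &fakeSleeper{}
		throttle(f, time.Second)
		fmt.Println(len(f.calls), f.calls[0]) // 1 1s
	}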