optimize rate limiter for heavier load
This commit is contained in:
parent
0312ddcdd2
commit
5fb9f0f895
120
v1/rate_limit.go
120
v1/rate_limit.go
@ -1,100 +1,132 @@
|
|||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"hash/fnv"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NoopLimiter implements Limiter but doesn't limit anything.
|
|
||||||
var NoopLimiter Limiter = &noopLimiter{}
|
var NoopLimiter Limiter = &noopLimiter{}
|
||||||
|
|
||||||
// token holds the per-id state for one fixed one-second window.
// All fields are guarded by the owning tokenShard's mutex.
type token struct {
	rps     uint32 // hits counted in the current window
	lastUse int64  // window start / last touch, Unix timestamp in nanoseconds
}
||||||
|
|
||||||
// Limiter interface for rate-limiting.
type Limiter interface {
	// Obtain the right to send a request. Should lock the execution if
	// the current goroutine needs to wait.
	Obtain(string)
}
||||||
// TokensBucket implements basic Limiter with fixed window and fixed amount of tokens per window.
|
// TokensBucket implements a sharded rate limiter with fixed window and tokens.
|
||||||
type TokensBucket struct {
|
type TokensBucket struct {
|
||||||
maxRPS uint32
|
maxRPS uint32
|
||||||
tokens sync.Map
|
unusedTokenTime int64 // in nanoseconds
|
||||||
unusedTokenTime time.Duration
|
|
||||||
checkTokenTime time.Duration
|
checkTokenTime time.Duration
|
||||||
|
shards []*tokenShard
|
||||||
|
shardCount uint32
|
||||||
cancel atomic.Bool
|
cancel atomic.Bool
|
||||||
sleep sleeper
|
sleep sleeper
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTokensBucket constructs TokensBucket with provided parameters.
|
type tokenShard struct {
|
||||||
|
tokens map[string]*token
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokensBucket creates a sharded token bucket limiter.
|
||||||
func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) Limiter {
|
func NewTokensBucket(maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration) Limiter {
|
||||||
|
shardCount := uint32(runtime.NumCPU() * 2) // Use double the CPU count for sharding
|
||||||
|
shards := make([]*tokenShard, shardCount)
|
||||||
|
for i := range shards {
|
||||||
|
shards[i] = &tokenShard{tokens: make(map[string]*token)}
|
||||||
|
}
|
||||||
|
|
||||||
bucket := &TokensBucket{
|
bucket := &TokensBucket{
|
||||||
maxRPS: maxRPS,
|
maxRPS: maxRPS,
|
||||||
unusedTokenTime: unusedTokenTime,
|
unusedTokenTime: unusedTokenTime.Nanoseconds(),
|
||||||
checkTokenTime: checkTokenTime,
|
checkTokenTime: checkTokenTime,
|
||||||
|
shards: shards,
|
||||||
|
shardCount: shardCount,
|
||||||
sleep: realSleeper{},
|
sleep: realSleeper{},
|
||||||
}
|
}
|
||||||
|
|
||||||
go bucket.deleteUnusedToken()
|
go bucket.cleanupRoutine()
|
||||||
runtime.SetFinalizer(bucket, destructBasket)
|
runtime.SetFinalizer(bucket, destructBucket)
|
||||||
return bucket
|
return bucket
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Obtain request hit. Will throttle RPS.
|
||||||
func (m *TokensBucket) Obtain(id string) {
|
func (m *TokensBucket) Obtain(id string) {
|
||||||
val, ok := m.tokens.Load(id)
|
shard := m.getShard(id)
|
||||||
if !ok {
|
|
||||||
token := &token{}
|
shard.mu.Lock()
|
||||||
token.lastUse.Store(time.Now())
|
defer shard.mu.Unlock()
|
||||||
token.rps.Store(1)
|
|
||||||
m.tokens.Store(id, token)
|
item, exists := shard.tokens[id]
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
shard.tokens[id] = &token{
|
||||||
|
rps: 1,
|
||||||
|
lastUse: now,
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token := val.(*token)
|
sleepTime := int64(time.Second) - (now - item.lastUse)
|
||||||
sleepTime := time.Second - time.Since(token.lastUse.Load().(time.Time))
|
|
||||||
if sleepTime <= 0 {
|
if sleepTime <= 0 {
|
||||||
token.lastUse.Store(time.Now())
|
item.lastUse = now
|
||||||
token.rps.Store(0)
|
atomic.StoreUint32(&item.rps, 1)
|
||||||
} else if token.rps.Load() >= m.maxRPS {
|
} else if atomic.LoadUint32(&item.rps) >= m.maxRPS {
|
||||||
m.sleep.Sleep(sleepTime)
|
m.sleep.Sleep(time.Duration(sleepTime))
|
||||||
token.lastUse.Store(time.Now())
|
item.lastUse = time.Now().UnixNano()
|
||||||
token.rps.Store(0)
|
atomic.StoreUint32(&item.rps, 1)
|
||||||
|
} else {
|
||||||
|
atomic.AddUint32(&item.rps, 1)
|
||||||
}
|
}
|
||||||
token.rps.Add(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func destructBasket(m *TokensBucket) {
|
func (m *TokensBucket) getShard(id string) *tokenShard {
|
||||||
m.cancel.Store(true)
|
hash := fnv.New32a()
|
||||||
|
_, _ = hash.Write([]byte(id))
|
||||||
|
return m.shards[hash.Sum32()%m.shardCount]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TokensBucket) deleteUnusedToken() {
|
func (m *TokensBucket) cleanupRoutine() {
|
||||||
|
ticker := time.NewTicker(m.checkTokenTime)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if m.cancel.Load() {
|
select {
|
||||||
return
|
case <-ticker.C:
|
||||||
}
|
if m.cancel.Load() {
|
||||||
|
return
|
||||||
m.tokens.Range(func(key, value any) bool {
|
|
||||||
id, token := key.(string), value.(*token)
|
|
||||||
if time.Since(token.lastUse.Load().(time.Time)) >= m.unusedTokenTime {
|
|
||||||
m.tokens.Delete(id)
|
|
||||||
}
|
}
|
||||||
return false
|
now := time.Now().UnixNano()
|
||||||
})
|
for _, shard := range m.shards {
|
||||||
|
shard.mu.Lock()
|
||||||
m.sleep.Sleep(m.checkTokenTime)
|
for id, token := range shard.tokens {
|
||||||
|
if now-token.lastUse >= m.unusedTokenTime {
|
||||||
|
delete(shard.tokens, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
shard.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func destructBucket(m *TokensBucket) {
|
||||||
|
m.cancel.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
// noopLimiter performs no rate limiting at all; Obtain returns immediately.
type noopLimiter struct{}

// Obtain is a no-op.
func (l *noopLimiter) Obtain(string) {}
|
||||||
// sleeper sleeps. This thing is necessary for tests.
type sleeper interface {
	Sleep(time.Duration)
}
||||||
|
@ -24,13 +24,22 @@ func (t *TokensBucketTest) Test_NewTokensBucket() {
|
|||||||
|
|
||||||
func (t *TokensBucketTest) new(
|
func (t *TokensBucketTest) new(
|
||||||
maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket {
|
maxRPS uint32, unusedTokenTime, checkTokenTime time.Duration, sleeper sleeper) *TokensBucket {
|
||||||
|
shardCount := uint32(runtime.NumCPU() * 2) // Use double the CPU count for sharding
|
||||||
|
shards := make([]*tokenShard, shardCount)
|
||||||
|
for i := range shards {
|
||||||
|
shards[i] = &tokenShard{tokens: make(map[string]*token)}
|
||||||
|
}
|
||||||
|
|
||||||
bucket := &TokensBucket{
|
bucket := &TokensBucket{
|
||||||
maxRPS: maxRPS,
|
maxRPS: maxRPS,
|
||||||
unusedTokenTime: unusedTokenTime,
|
unusedTokenTime: unusedTokenTime.Nanoseconds(),
|
||||||
checkTokenTime: checkTokenTime,
|
checkTokenTime: checkTokenTime,
|
||||||
|
shards: shards,
|
||||||
|
shardCount: shardCount,
|
||||||
sleep: sleeper,
|
sleep: sleeper,
|
||||||
}
|
}
|
||||||
runtime.SetFinalizer(bucket, destructBasket)
|
|
||||||
|
runtime.SetFinalizer(bucket, destructBucket)
|
||||||
return bucket
|
return bucket
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,12 +55,14 @@ func (t *TokensBucketTest) Test_Obtain_NoThrottle() {
|
|||||||
func (t *TokensBucketTest) Test_Obtain_Sleep() {
|
func (t *TokensBucketTest) Test_Obtain_Sleep() {
|
||||||
clock := &fakeSleeper{}
|
clock := &fakeSleeper{}
|
||||||
tb := t.new(100, time.Hour, time.Minute, clock)
|
tb := t.new(100, time.Hour, time.Minute, clock)
|
||||||
|
_, exists := tb.getShard("w").tokens["w"]
|
||||||
|
t.Require().False(exists)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < 301; i++ {
|
for i := 0; i < 301; i++ {
|
||||||
tb.Obtain("a")
|
tb.Obtain("w")
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
@ -63,15 +74,15 @@ func (t *TokensBucketTest) Test_Obtain_Sleep() {
|
|||||||
func (t *TokensBucketTest) Test_Obtain_AddRPS() {
|
func (t *TokensBucketTest) Test_Obtain_AddRPS() {
|
||||||
clock := clockwork.NewFakeClock()
|
clock := clockwork.NewFakeClock()
|
||||||
tb := t.new(100, time.Hour, time.Minute, clock)
|
tb := t.new(100, time.Hour, time.Minute, clock)
|
||||||
go tb.deleteUnusedToken()
|
go tb.cleanupRoutine()
|
||||||
tb.Obtain("a")
|
tb.Obtain("a")
|
||||||
clock.Advance(time.Minute * 2)
|
clock.Advance(time.Minute * 2)
|
||||||
|
|
||||||
item, found := tb.tokens.Load("a")
|
item, found := tb.getShard("a").tokens["a"]
|
||||||
t.Require().True(found)
|
t.Require().True(found)
|
||||||
t.Assert().Equal(1, int(item.(*token).rps.Load()))
|
t.Assert().Equal(1, int(item.rps))
|
||||||
tb.Obtain("a")
|
tb.Obtain("a")
|
||||||
t.Assert().Equal(2, int(item.(*token).rps.Load()))
|
t.Assert().Equal(2, int(item.rps))
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeSleeper struct {
|
type fakeSleeper struct {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user