hmn/src/discord/ratelimiting.go

package discord

import (
	"context"
	"errors"
	"math"
	"net/http"
	"strconv"
	"sync"
	"time"

	"git.handmade.network/hmn/hmn/src/logging"
	"git.handmade.network/hmn/hmn/src/utils"
)
var limiterLog = logging.GlobalLogger().With().
	Str("module", "discord").
	Str("discord actor", "rate limiter").
	Logger()

var buckets sync.Map      // map[route name]bucket name
var rateLimiters sync.Map // map[bucket name]*restRateLimiter
var limiterInitMutex sync.Mutex
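
// A restRateLimiter is a simple token bucket for one Discord rate limit
// bucket: every send takes a token out of requests, and a refiller goroutine
// periodically tops requests back up based on the latest headers it was told
// about via refills.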
type restRateLimiter struct {
	requests chan struct{}
	refills  chan rateLimiterRefill
}

type rateLimiterRefill struct {
	resetAfter  time.Duration
	maxRequests int
}
/*
Whenever we send a request, we must sleep until this time
(if it is in the future, of course). This is a quick and
dirty way to pause all sending in case of a global rate
limit.

I could put a mutex on this but I don't think it's actually
a problem to have race conditions here. Just set it when
you get throttled. EZ.
*/
var globalRateLimitTime time.Time

type rateLimitHeaders struct {
	Bucket     string
	Limit      int
	Remaining  int
	ResetAfter time.Duration
}
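
// parseRateLimitHeaders reads Discord's X-RateLimit-* headers off a response.
// Headers that are absent leave their zero values in place; a header that is
// present but unparseable makes the whole result unusable, so we return false.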
func parseRateLimitHeaders(header http.Header) (rateLimitHeaders, bool) {
	var err error

	bucket := header.Get("X-RateLimit-Bucket")
	var limit int
	var remaining int
	var resetAfter time.Duration

	limitStr := header.Get("X-RateLimit-Limit")
	if limitStr != "" {
		limit, err = strconv.Atoi(limitStr)
		if err != nil {
			limiterLog.Error().
				Err(err).
				Str("value", limitStr).
				Msg("failed to parse X-RateLimit-Limit header")
			return rateLimitHeaders{}, false
		}
	}

	remainingStr := header.Get("X-RateLimit-Remaining")
	if remainingStr != "" {
		remaining, err = strconv.Atoi(remainingStr)
		if err != nil {
			limiterLog.Error().
				Err(err).
				Str("value", remainingStr).
				Msg("failed to parse X-RateLimit-Remaining header")
			return rateLimitHeaders{}, false
		}
	}

	resetAfterStr := header.Get("X-RateLimit-Reset-After")
	if resetAfterStr != "" {
		resetAfterSeconds, err := strconv.ParseFloat(resetAfterStr, 64)
		if err != nil {
			limiterLog.Error().
				Err(err).
				Str("value", resetAfterStr).
				Msg("failed to parse X-RateLimit-Reset-After header")
			return rateLimitHeaders{}, false
		}
		resetAfter = time.Duration(math.Ceil(resetAfterSeconds)) * time.Second
	}

	return rateLimitHeaders{
		Bucket:     bucket,
		Limit:      limit,
		Remaining:  remaining,
		ResetAfter: resetAfter,
	}, true
}
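
// createLimiter records which bucket this route belongs to and, if we haven't
// seen the bucket before, creates its limiter, pre-fills it with the remaining
// requests reported by Discord, and starts its refiller goroutine.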
func createLimiter(headers rateLimitHeaders, routeName string) {
	limiterInitMutex.Lock()
	defer limiterInitMutex.Unlock()

	buckets.Store(routeName, headers.Bucket)
	ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{
		requests: make(chan struct{}, 200), // presumably this is big enough to handle bursts
		refills:  make(chan rateLimiterRefill),
	})
	if !loaded {
		limiter := ilimiter.(*restRateLimiter)
		log := limiterLog.With().Str("bucket", headers.Bucket).Logger()

		// Pre-fill the limiter with remaining requests
	prefillloop:
		for i := 0; i < headers.Remaining; i++ {
			select {
			case limiter.requests <- struct{}{}:
			default:
				log.Warn().
					Int("remaining", headers.Remaining).
					Msg("rate limiting channel was too small; you should increase the default capacity")
				break prefillloop
			}
		}

		/*
			Start the refiller for this bucket. It waits for a request to tell
			it when to next reset the rate limit, and how full to fill the bucket.
			It then sleeps and refills the bucket, just like it should :)
		*/
		go func() {
			for {
				// Wake up on the first request after refilling
				refill := <-limiter.refills

				// Sleep for the remainder of the bucket's time
				time.Sleep(refill.resetAfter)

				// drain the bucket
			drainloop:
				for {
					select {
					case <-limiter.requests:
					default:
						break drainloop
					}
				}

				// refill it with the max number of requests
			refillloop:
				for i := 0; i < refill.maxRequests; i++ {
					select {
					case limiter.requests <- struct{}{}:
					default:
						log.Warn().
							Int("maxRequests", refill.maxRequests).
							Msg("rate limiting channel was too small; you should increase the default capacity")
						break refillloop
					}
				}

				// And then we wait again to hear about our next
				// bucket's worth of requests.
			}
		}()

		// Tell the refiller about its first refill
		limiter.refills <- rateLimiterRefill{
			resetAfter:  headers.ResetAfter,
			maxRequests: headers.Limit,
		}
	}
}
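
// update passes the latest reset time and bucket size along to the refiller
// goroutine, without blocking if the refiller is busy sleeping.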
func (l *restRateLimiter) update(headers rateLimitHeaders) {
	refill := rateLimiterRefill{
		resetAfter:  headers.ResetAfter,
		maxRequests: headers.Limit,
	}

	/*
		Tell the refiller about this request. If the refiller is already
		busy sleeping, this will have no effect, which is what we want.
		(It's already sleeping for as long as it needs to.)
	*/
	select {
	case l.refills <- refill:
	default:
	}
}
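
// doWithRateLimiting performs a Discord API request for the named route,
// first waiting out any global rate limit and then taking a token from the
// route's bucket (if we know which bucket that is yet). If Discord still
// responds with a 429, it sleeps as instructed and retries; getReq is called
// again on every attempt so each retry gets a fresh request and context.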
func doWithRateLimiting(ctx context.Context, routeName string, getReq func(ctx context.Context) *http.Request) (*http.Response, error) {
	var bucket string
	ibucket, ok := buckets.Load(routeName)
	if ok {
		bucket = ibucket.(string)
	}

	for {
		var limiter *restRateLimiter
		if bucket != "" {
			ilimiter, ok := rateLimiters.Load(bucket)
			if ok {
				limiter = ilimiter.(*restRateLimiter)
			}
		}

		if globalRateLimitTime.After(time.Now()) {
			// oh boy, global rate limit, pause until the coast is clear
			err := utils.SleepContext(ctx, globalRateLimitTime.Sub(time.Now())+1*time.Second)
			if err != nil {
				return nil, err
			}
		}

		if limiter != nil {
			select {
			case <-limiter.requests:
			case <-ctx.Done():
				return nil, errors.New("request interrupted during rate limiting")
			}
		}

		res, err := httpClient.Do(getReq(ctx))
		if err != nil {
			return nil, err
		}

		headers, headersOk := parseRateLimitHeaders(res.Header)
		if headersOk {
			if limiter == nil || headers.Bucket != bucket {
				createLimiter(headers, routeName)
			} else {
				limiter.update(headers)
			}
		}

		if res.StatusCode == 429 {
			if res.Header.Get("X-RateLimit-Global") != "" {
				// globally rate limited
				logging.ExtractLogger(ctx).Warn().Msg("got globally rate limited by Discord")

				retryAfter, err := strconv.Atoi(res.Header.Get("Retry-After"))
				if err == nil {
					globalRateLimitTime = time.Now().Add(time.Duration(retryAfter) * time.Second)
				} else {
					// well this is bad, just sleep for 60 seconds and pray that it's long enough
					logging.ExtractLogger(ctx).Warn().
						Err(err).
						Msg("got globally rate limited but couldn't determine how long to wait")
					globalRateLimitTime = time.Now().Add(60 * time.Second)
				}
			} else {
				// locally rate limited
				/*
					Despite our best efforts, we ended up rate limited anyway.
					Simply wait the amount of time Discord asks, and then try
					again. On the next go-around, hopefully we'll either succeed
					or have a rate limiter initialized and ready to go.
				*/
				logging.ExtractLogger(ctx).Warn().Msg("got rate limited by Discord")
				if headersOk {
					err := utils.SleepContext(ctx, headers.ResetAfter)
					if err != nil {
						return nil, err
					}
				} else {
					logging.ExtractLogger(ctx).Warn().Msg("got rate limited, but didn't have the headers??")
					err := utils.SleepContext(ctx, 1*time.Second)
					if err != nil {
						return nil, err
					}
				}
			}
			continue
		}

		return res, nil
	}
}
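
// Below is a minimal sketch of how a caller might use doWithRateLimiting. It
// is illustrative only: the route name, URL, API version, and the
// sendDiscordRequestExample function are hypothetical (the package's real
// REST wrappers live elsewhere), and the Authorization value is a placeholder
// for the configured bot token.
func sendDiscordRequestExample(ctx context.Context, channelID string) (*http.Response, error) {
	return doWithRateLimiting(ctx, "GetChannel", func(ctx context.Context) *http.Request {
		// Build a fresh request on every attempt so each retry gets a clean
		// body and the current context.
		req, err := http.NewRequestWithContext(ctx, http.MethodGet,
			"https://discord.com/api/v10/channels/"+channelID, nil)
		if err != nil {
			panic(err) // method and URL are fixed, so this should never happen
		}
		req.Header.Set("Authorization", "Bot <bot token goes here>")
		return req
	})
}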