285 lines
7.1 KiB
Go
285 lines
7.1 KiB
Go
|
package discord
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"math"
|
||
|
"net/http"
|
||
|
"strconv"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"git.handmade.network/hmn/hmn/src/logging"
|
||
|
"git.handmade.network/hmn/hmn/src/utils"
|
||
|
)
|
||
|
|
||
|
var limiterLog = logging.GlobalLogger().With().
|
||
|
Str("module", "discord").
|
||
|
Str("discord actor", "rate limiter").
|
||
|
Logger()
|
||
|
|
||
|
// Shared rate limiter registries.
var (
	// buckets maps a route name (string) to the Discord bucket name (string)
	// that the route was last reported to belong to.
	buckets sync.Map

	// rateLimiters maps a bucket name (string) to that bucket's
	// *restRateLimiter.
	rateLimiters sync.Map

	// limiterInitMutex serializes createLimiter so that recording a route's
	// bucket and creating the bucket's limiter happen as one step.
	limiterInitMutex sync.Mutex
)
|
||
|
|
||
|
type (
	// restRateLimiter limits requests for one Discord rate limit bucket.
	// Each value buffered in requests is permission to send one request;
	// senders receive from it before every request.
	restRateLimiter struct {
		requests chan struct{}
		// refills hands the next window's timing to this limiter's refiller
		// goroutine.
		refills chan rateLimiterRefill
	}

	// rateLimiterRefill describes one rate limit window: how long to wait
	// before it resets, and how many request permits to put back afterwards.
	rateLimiterRefill struct {
		resetAfter  time.Duration
		maxRequests int
	}
)
|
||
|
|
||
|
// globalRateLimitTime is the instant until which every outgoing request must
// wait. It is set when Discord reports a global rate limit, and each request
// checks it before sending and sleeps while it is still in the future. This
// is a quick way to pause all sending at once.
//
// It is deliberately read and written without a mutex: whoever gets globally
// throttled simply overwrites it.
// NOTE(review): this is still a data race under the Go memory model
// (time.Time is multi-word) and `go test -race` will report it.
var globalRateLimitTime time.Time

// rateLimitHeaders is the rate limit information Discord attaches to a REST
// response. Each field is left at its zero value when its header is absent.
type rateLimitHeaders struct {
	Bucket     string        // value of X-RateLimit-Bucket
	Limit      int           // value of X-RateLimit-Limit; used as the refill size
	Remaining  int           // value of X-RateLimit-Remaining
	ResetAfter time.Duration // X-RateLimit-Reset-After, rounded up to whole seconds
}
|
||
|
|
||
|
func parseRateLimitHeaders(header http.Header) (rateLimitHeaders, bool) {
|
||
|
var err error
|
||
|
|
||
|
bucket := header.Get("X-RateLimit-Bucket")
|
||
|
var limit int
|
||
|
var remaining int
|
||
|
var resetAfter time.Duration
|
||
|
|
||
|
limitStr := header.Get("X-RateLimit-Limit")
|
||
|
if limitStr != "" {
|
||
|
limit, err = strconv.Atoi(limitStr)
|
||
|
if err != nil {
|
||
|
limiterLog.Error().
|
||
|
Err(err).
|
||
|
Str("value", limitStr).
|
||
|
Msg("failed to parse X-RateLimit-Limit header")
|
||
|
return rateLimitHeaders{}, false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
remainingStr := header.Get("X-RateLimit-Remaining")
|
||
|
if remainingStr != "" {
|
||
|
remaining, err = strconv.Atoi(remainingStr)
|
||
|
if err != nil {
|
||
|
limiterLog.Error().
|
||
|
Err(err).
|
||
|
Str("value", remainingStr).
|
||
|
Msg("failed to parse X-RateLimit-Remaining header")
|
||
|
return rateLimitHeaders{}, false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
resetAfterStr := header.Get("X-RateLimit-Reset-After")
|
||
|
if resetAfterStr != "" {
|
||
|
resetAfterSeconds, err := strconv.ParseFloat(resetAfterStr, 64)
|
||
|
if err != nil {
|
||
|
limiterLog.Error().
|
||
|
Err(err).
|
||
|
Str("value", resetAfterStr).
|
||
|
Msg("failed to parse X-RateLimit-Reset-After header")
|
||
|
return rateLimitHeaders{}, false
|
||
|
}
|
||
|
resetAfter = time.Duration(math.Ceil(resetAfterSeconds)) * time.Second
|
||
|
}
|
||
|
|
||
|
return rateLimitHeaders{
|
||
|
Bucket: bucket,
|
||
|
Limit: limit,
|
||
|
Remaining: remaining,
|
||
|
ResetAfter: resetAfter,
|
||
|
}, true
|
||
|
}
|
||
|
|
||
|
func createLimiter(headers rateLimitHeaders, routeName string) {
|
||
|
limiterInitMutex.Lock()
|
||
|
defer limiterInitMutex.Unlock()
|
||
|
|
||
|
buckets.Store(routeName, headers.Bucket)
|
||
|
ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{
|
||
|
requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts
|
||
|
refills: make(chan rateLimiterRefill),
|
||
|
})
|
||
|
if !loaded {
|
||
|
limiter := ilimiter.(*restRateLimiter)
|
||
|
|
||
|
log := limiterLog.With().Str("bucket", headers.Bucket).Logger()
|
||
|
|
||
|
prefillloop:
|
||
|
// Pre-fill the limiter with remaining requests
|
||
|
for i := 0; i < headers.Remaining; i++ {
|
||
|
select {
|
||
|
case limiter.requests <- struct{}{}:
|
||
|
default:
|
||
|
log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity")
|
||
|
break prefillloop
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/*
|
||
|
Start the refiller for this bucket. It waits for a request to tell
|
||
|
it when to next reset the rate limit, and how full to fill the bucket.
|
||
|
It then sleeps and refills the bucket, just like it should :)
|
||
|
*/
|
||
|
go func() {
|
||
|
for {
|
||
|
// Wake up on the first request after refilling
|
||
|
refill := <-limiter.refills
|
||
|
|
||
|
// Sleep for the remainder of the bucket's time
|
||
|
time.Sleep(refill.resetAfter)
|
||
|
|
||
|
drainloop:
|
||
|
// drain the bucket
|
||
|
for {
|
||
|
select {
|
||
|
case <-limiter.requests:
|
||
|
default:
|
||
|
break drainloop
|
||
|
}
|
||
|
}
|
||
|
|
||
|
refillloop:
|
||
|
// refill it with the max number of requests
|
||
|
for i := 0; i < refill.maxRequests; i++ {
|
||
|
select {
|
||
|
case limiter.requests <- struct{}{}:
|
||
|
default:
|
||
|
log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity")
|
||
|
break refillloop
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// And then we wait again to hear about our next
|
||
|
// bucket's worth of requests.
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
// Tell the refiller about its first refill
|
||
|
limiter.refills <- rateLimiterRefill{
|
||
|
resetAfter: headers.ResetAfter,
|
||
|
maxRequests: headers.Limit,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (l *restRateLimiter) update(headers rateLimitHeaders) {
|
||
|
refill := rateLimiterRefill{
|
||
|
resetAfter: headers.ResetAfter,
|
||
|
maxRequests: headers.Limit,
|
||
|
}
|
||
|
|
||
|
/*
|
||
|
Tell the refiller about this request. If the refiller is already
|
||
|
busy sleeping, this will have no effect, which is what we want.
|
||
|
(It's already sleeping for as long as it needs to.)
|
||
|
*/
|
||
|
select {
|
||
|
case l.refills <- refill:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func doWithRateLimiting(ctx context.Context, routeName string, getReq func(ctx context.Context) *http.Request) (*http.Response, error) {
|
||
|
var bucket string
|
||
|
ibucket, ok := buckets.Load(routeName)
|
||
|
if ok {
|
||
|
bucket = ibucket.(string)
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
var limiter *restRateLimiter
|
||
|
if bucket != "" {
|
||
|
ilimiter, ok := rateLimiters.Load(bucket)
|
||
|
if ok {
|
||
|
limiter = ilimiter.(*restRateLimiter)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if globalRateLimitTime.After(time.Now()) {
|
||
|
// oh boy, global rate limit, pause until the coast is clear
|
||
|
err := utils.SleepContext(ctx, globalRateLimitTime.Sub(time.Now())+1*time.Second)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if limiter != nil {
|
||
|
select {
|
||
|
case <-limiter.requests:
|
||
|
case <-ctx.Done():
|
||
|
return nil, errors.New("request interrupted during rate limiting")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
res, err := httpClient.Do(getReq(ctx))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
headers, headersOk := parseRateLimitHeaders(res.Header)
|
||
|
if headersOk {
|
||
|
if limiter == nil || headers.Bucket != bucket {
|
||
|
createLimiter(headers, routeName)
|
||
|
} else {
|
||
|
limiter.update(headers)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if res.StatusCode == 429 {
|
||
|
if res.Header.Get("X-RateLimit-Global") != "" {
|
||
|
// globally rate limited
|
||
|
logging.ExtractLogger(ctx).Warn().Msg("got globally rate limited by Discord")
|
||
|
retryAfter, err := strconv.Atoi(res.Header.Get("Retry-After"))
|
||
|
if err == nil {
|
||
|
globalRateLimitTime = time.Now().Add(time.Duration(retryAfter) * time.Second)
|
||
|
} else {
|
||
|
// well this is bad, just sleep for 60 seconds and pray that it's long enough
|
||
|
logging.ExtractLogger(ctx).Warn().
|
||
|
Err(err).
|
||
|
Msg("got globally rate limited but couldn't determine how long to wait")
|
||
|
globalRateLimitTime = time.Now().Add(60 * time.Second)
|
||
|
}
|
||
|
} else {
|
||
|
// locally rate limited
|
||
|
|
||
|
/*
|
||
|
Despite our best efforts, we ended up rate limited anyway.
|
||
|
Simply wait the amount of time Discord asks, and then try
|
||
|
again. On the next go-around, hopefully we'll either succeed
|
||
|
or have a rate limiter initialized and ready to go.
|
||
|
*/
|
||
|
logging.ExtractLogger(ctx).Warn().Msg("got rate limited by Discord")
|
||
|
if headersOk {
|
||
|
err := utils.SleepContext(ctx, headers.ResetAfter)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
} else {
|
||
|
logging.ExtractLogger(ctx).Warn().Msg("got rate limited, but didn't have the headers??")
|
||
|
err := utils.SleepContext(ctx, 1*time.Second)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
return res, nil
|
||
|
}
|
||
|
}
|