Add Discord integration

Clean up several TODOs

Implement the full disconnect / resume flow

Detect zombied connections and restart

Implement the random delay on reconnect

Implement message sending!!

(with a goofy feedback loop on the echo bot)

Fix the feedback loop in the echo bot

Clean up the Discord gateway code

Many things are methods now to reduce the amount of explicit plumbing.
Connection handling should be a little more robust, and we have an
actual error handling strategy now.

Allow sending multiple Discord messages at once

Delete irrelevant tests

uhh, start rate limiting

Add per-route rate limiting

Add global rate limit handling

Handle context cancellation in Discord REST code

Allow changing buckets per route

Add the showcase rejection bot

Add library bot
This commit is contained in:
Ben Visness 2021-08-06 18:23:51 -05:00
parent c3c5968512
commit 38a1188be7
14 changed files with 1672 additions and 2 deletions

2
go.mod
View File

@ -10,11 +10,13 @@ require (
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8
github.com/go-stack/stack v1.8.0 github.com/go-stack/stack v1.8.0
github.com/google/uuid v1.2.0 github.com/google/uuid v1.2.0
github.com/gorilla/websocket v1.4.2
github.com/huandu/xstrings v1.3.2 // indirect github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgconn v1.8.0 github.com/jackc/pgconn v1.8.0
github.com/jackc/pgtype v1.6.2 github.com/jackc/pgtype v1.6.2
github.com/jackc/pgx/v4 v4.10.1 github.com/jackc/pgx/v4 v4.10.1
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.1.1 // indirect github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/rs/zerolog v1.21.0 github.com/rs/zerolog v1.21.0
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3

3
go.sum
View File

@ -101,6 +101,7 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
@ -185,6 +186,8 @@ github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv
github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94=
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=

View File

@ -25,6 +25,7 @@ type HMNConfig struct {
Auth AuthConfig Auth AuthConfig
Email EmailConfig Email EmailConfig
DigitalOcean DigitalOceanConfig DigitalOcean DigitalOceanConfig
Discord DiscordConfig
} }
type PostgresConfig struct { type PostgresConfig struct {
@ -62,6 +63,14 @@ type EmailConfig struct {
OverrideRecipientEmail string OverrideRecipientEmail string
} }
type DiscordConfig struct {
BotToken string
BotUserID string
ShowcaseChannelID string
LibraryChannelID string
}
func (info PostgresConfig) DSN() string { func (info PostgresConfig) DSN() string {
return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName) return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName)
} }

643
src/discord/gateway.go Normal file
View File

@ -0,0 +1,643 @@
package discord
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net"
"runtime"
"sync"
"time"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/utils"
"github.com/gorilla/websocket"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jpillora/backoff"
)
func RunDiscordBot(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{} {
log := logging.ExtractLogger(ctx).With().Str("module", "discord").Logger()
ctx = logging.AttachLoggerToContext(&log, ctx)
if config.Config.Discord.BotToken == "" {
log.Warn().Msg("No Discord bot token was provided, so the Discord bot cannot run.")
done := make(chan struct{}, 1)
done <- struct{}{}
return done
}
done := make(chan struct{})
go func() {
defer func() {
log.Debug().Msg("shut down Discord bot")
done <- struct{}{}
}()
boff := backoff.Backoff{
Min: 1 * time.Second,
Max: 5 * time.Minute,
}
for {
select {
case <-ctx.Done():
return
default:
}
func() {
log.Info().Msg("Connecting to the Discord gateway")
bot := newBotInstance(dbConn)
err := bot.Run(ctx)
if err != nil {
dur := boff.Duration()
log.Error().
Err(err).
Dur("retrying after", dur).
Msg("failed to run Discord bot")
timer := time.NewTimer(dur)
select {
case <-ctx.Done():
case <-timer.C:
}
return
}
select {
case <-ctx.Done():
return
default:
}
// This delay satisfies the 1 to 5 second delay Discord
// wants on reconnects, and seems fine to do every time.
delay := time.Duration(int64(time.Second) + rand.Int63n(int64(time.Second*4)))
log.Info().Dur("delay", delay).Msg("Reconnecting to Discord")
time.Sleep(delay)
boff.Reset()
}()
}
}()
return done
}
var outgoingMessagesReady = make(chan struct{}, 1)
type discordBotInstance struct {
conn *websocket.Conn
dbConn *pgxpool.Pool
heartbeatIntervalMs int
forceHeartbeat chan struct{}
/*
Every time we send a heartbeat, we set this variable to false.
Whenever we ack a heartbeat, we set this variable to true.
If we try to send a heartbeat but the previous one was not
acked, then we close the connection and try to reconnect.
*/
didAckHeartbeat bool
/*
All goroutines should call this when they exit, to ensure that
the other goroutines shut down as well.
*/
cancel context.CancelFunc
wg sync.WaitGroup
}
func newBotInstance(dbConn *pgxpool.Pool) *discordBotInstance {
return &discordBotInstance{
dbConn: dbConn,
forceHeartbeat: make(chan struct{}),
didAckHeartbeat: true,
}
}
/*
Runs a bot instance to completion. It will start up a gateway connection and return when the
connection is closed. It only returns an error when something unexpected occurs; if so, you should
do exponential backoff before reconnecting. Otherwise you can reconnect right away.
*/
func (bot *discordBotInstance) Run(ctx context.Context) (err error) {
defer utils.RecoverPanicAsError(&err)
ctx, bot.cancel = context.WithCancel(ctx)
defer bot.cancel()
err = bot.connect(ctx)
if err != nil {
return oops.New(err, "failed to connect to Discord gateway")
}
defer bot.conn.Close()
bot.wg.Add(1)
go bot.doSender(ctx)
// Wait for child goroutines to exit (they will do so when context is canceled). This ensures
// that nothing is in the middle of sending. Then close the connection, so that this goroutine
// can finish as well.
go func() {
bot.wg.Wait()
bot.conn.Close()
}()
for {
msg, err := bot.receiveGatewayMessage(ctx)
if err != nil {
// TODO: Are there other kinds of connection close events that we need to handle? Probably?
if errors.Is(err, net.ErrClosed) {
// If the connection is closed, that's our cue to shut down the bot. Any errors
// related to the closure will have been logged elsewhere anyway.
return nil
} else {
return oops.New(err, "failed to receive message from the gateway")
}
}
// Update the sequence number in the db
if msg.SequenceNumber != nil {
_, err = bot.dbConn.Exec(ctx, `UPDATE discord_session SET sequence_number = $1`, *msg.SequenceNumber)
if err != nil {
return oops.New(err, "failed to save latest sequence number")
}
}
switch msg.Opcode {
case OpcodeDispatch:
// Just a normal event
err := bot.processEventMsg(ctx, msg)
if err != nil {
return oops.New(err, "failed to process gateway event")
}
case OpcodeHeartbeat:
bot.forceHeartbeat <- struct{}{}
case OpcodeHeartbeatACK:
bot.didAckHeartbeat = true
case OpcodeReconnect:
logging.ExtractLogger(ctx).Info().Msg("Discord asked us to reconnect to the gateway")
return nil
case OpcodeInvalidSession:
// We tried to resume but the session was invalid.
// Delete the session and reconnect from scratch again.
_, err := bot.dbConn.Exec(ctx, `DELETE FROM discord_session`)
if err != nil {
return oops.New(err, "failed to delete invalid session")
}
return nil
}
}
}
/*
The connection process in short:
- Gateway sends Hello, asking the client to heartbeat on some interval
- Client sends Identify and starts heartbeat process
- Gateway sends Ready, client is now connected to gateway
Or, if we have an existing session:
- Gateway sends Hello, asking the client to heartbeat on some interval
- Client sends Resume and starts heartbeat process
- Gateway sends all missed events followed by a RESUMED event, or an Invalid Session if the
session is ded
Note that some events probably won't be received until the Guild Create message is received.
It's a little annoying to handle resumes since we want to handle the missed messages as if we were
receiving them in real time. But we're kind of in a different state from from when we're normally
receiving messages, because we are expecting a RESUMED event at the end, and the first message we
receive might be an Invalid Session. So, unfortunately, we just have to handle the Invalid Session
and RESUMED messages in our main message receiving loop instead of here.
(Discord could have prevented this if they send a "Resume ACK" message before replaying events.
That way, we could receive exactly one message after sending Resume, either a Resume ACK or an
Invalid Session, and from there it would be crystal clear what to do. Alas!)
*/
func (bot *discordBotInstance) connect(ctx context.Context) error {
res, err := GetGatewayBot(ctx)
if err != nil {
return oops.New(err, "failed to get gateway URL")
}
conn, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s/?v=9&encoding=json", res.URL), nil)
if err != nil {
return oops.New(err, "failed to connect to the Discord gateway")
}
bot.conn = conn
helloMessage, err := bot.receiveGatewayMessage(ctx)
if err != nil {
return oops.New(err, "failed to read Hello message")
}
if helloMessage.Opcode != OpcodeHello {
return oops.New(nil, "expected a Hello (opcode %d), but got opcode %d", OpcodeHello, helloMessage.Opcode)
}
helloData := HelloFromMap(helloMessage.Data)
bot.heartbeatIntervalMs = helloData.HeartbeatIntervalMs
// Now that the gateway has said hello, we need to establish a new session, either resuming
// an old one or starting a new one.
shouldResume := true
isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`)
if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) {
// No session yet! Just identify and get on with it
shouldResume = false
} else {
return oops.New(err, "failed to get current session from database")
}
}
if shouldResume {
// Reconnect to the previous session
session := isession.(*models.DiscordSession)
err := bot.sendGatewayMessage(ctx, GatewayMessage{
Opcode: OpcodeResume,
Data: Resume{
Token: config.Config.Discord.BotToken,
SessionID: session.ID,
SequenceNumber: session.SequenceNumber,
},
})
if err != nil {
return oops.New(err, "failed to send Resume message")
}
return nil
} else {
// Start a new session
err := bot.sendGatewayMessage(ctx, GatewayMessage{
Opcode: OpcodeIdentify,
Data: Identify{
Token: config.Config.Discord.BotToken,
Properties: IdentifyConnectionProperties{
OS: runtime.GOOS,
Browser: BotName,
Device: BotName,
},
Intents: IntentGuilds | IntentGuildMessages,
},
})
if err != nil {
return oops.New(err, "failed to send Identify message")
}
readyMessage, err := bot.receiveGatewayMessage(ctx)
if err != nil {
return oops.New(err, "failed to read Ready message")
}
if readyMessage.Opcode != OpcodeDispatch {
return oops.New(err, "expected a READY event, but got a message with opcode %d", readyMessage.Opcode)
}
if *readyMessage.EventName != "READY" {
return oops.New(err, "expected a READY event, but got a %s event", *readyMessage.EventName)
}
readyData := ReadyFromMap(readyMessage.Data)
_, err = bot.dbConn.Exec(ctx,
`
INSERT INTO discord_session (session_id, sequence_number)
VALUES ($1, $2)
ON CONFLICT (pk) DO UPDATE
SET session_id = $1, sequence_number = $2
`,
readyData.SessionID,
*readyMessage.SequenceNumber,
)
if err != nil {
return oops.New(err, "failed to save new bot session in the database")
}
}
return nil
}
/*
Sends outgoing gateway messages and channel messages. Handles heartbeats. This function should be
run as its own goroutine.
*/
func (bot *discordBotInstance) doSender(ctx context.Context) {
defer bot.wg.Done()
defer bot.cancel()
log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender").Logger()
ctx = logging.AttachLoggerToContext(&log, ctx)
defer log.Info().Msg("shutting down Discord sender")
/*
The first heartbeat is supposed to occur at a random time within
the first heartbeat interval.
https://discord.com/developers/docs/topics/gateway#heartbeating
*/
dur := time.Duration(bot.heartbeatIntervalMs) * time.Millisecond
firstDelay := time.NewTimer(time.Duration(rand.Int63n(int64(dur))))
heartbeatTicker := &time.Ticker{} // this will start never ticking, and get initialized after the first heartbeat
// Returns false if the heartbeat failed
sendHeartbeat := func() bool {
if !bot.didAckHeartbeat {
log.Error().Msg("did not receive a heartbeat ACK in between heartbeats")
return false
}
bot.didAckHeartbeat = false
latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`)
if err != nil {
log.Error().Err(err).Msg("failed to fetch latest sequence number from the db")
return false
}
err = bot.sendGatewayMessage(ctx, GatewayMessage{
Opcode: OpcodeHeartbeat,
Data: latestSequenceNumber,
})
if err != nil {
log.Error().Err(err).Msg("failed to send heartbeat")
return false
}
return true
}
/*
Start a goroutine to fetch outgoing messages from the db. We do this in a separate goroutine
to ensure that issues talking to the database don't prevent us from sending heartbeats.
*/
messages := make(chan *models.DiscordOutgoingMessage)
bot.wg.Add(1)
go func(ctx context.Context) {
defer bot.wg.Done()
defer bot.cancel()
log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender db reader").Logger()
ctx = logging.AttachLoggerToContext(&log, ctx)
defer log.Info().Msg("stopping db reader")
// We will poll the database just in case the notification mechanism doesn't work.
ticker := time.NewTicker(time.Second * 5)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-outgoingMessagesReady:
}
func() {
tx, err := bot.dbConn.Begin(ctx)
if err != nil {
log.Error().Err(err).Msg("failed to start transaction")
return
}
defer tx.Rollback(ctx)
itMessages, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, `
SELECT $columns
FROM discord_outgoingmessages
ORDER BY id ASC
`)
if err != nil {
log.Error().Err(err).Msg("failed to fetch outgoing Discord messages")
return
}
msgs := itMessages.ToSlice()
for _, imsg := range msgs {
msg := imsg.(*models.DiscordOutgoingMessage)
if time.Now().After(msg.ExpiresAt) {
continue
}
messages <- msg
}
/*
NOTE(ben): Doing this in a transaction means that we will only delete the
messages that we originally fetched. At least, as long as the database's
isolation level is Read Committed, which is the default.
https://www.postgresql.org/docs/current/transaction-iso.html
*/
_, err = tx.Exec(ctx, `DELETE FROM discord_outgoingmessages`)
if err != nil {
log.Error().Err(err).Msg("failed to delete outgoing messages")
return
}
err = tx.Commit(ctx)
if err != nil {
log.Error().Err(err).Msg("failed to read and delete outgoing messages")
return
}
if len(msgs) > 0 {
log.Debug().Int("num messages", len(msgs)).Msg("Sent and deleted outgoing messages")
}
}()
}
}(ctx)
/*
Whenever we want to send a gateway message, we must receive a value from
this channel first. A goroutine continuously fills the channel at a rate
that respects Discord's gateway rate limit.
Don't use this for heartbeats; heartbeats should go out immediately.
Don't forget that the server can request a heartbeat at any time.
See the docs for more details. The capacity of this channel is chosen to
always leave us overhead for heartbeats and other shenanigans.
https://discord.com/developers/docs/topics/gateway#rate-limiting
*/
rateLimiter := make(chan struct{}, 100)
go func() {
for {
rateLimiter <- struct{}{}
time.Sleep(500 * time.Millisecond)
}
}()
/*
NOTE(ben): This rate limiter is actually not used right now
because we're not actually sending any meaningful gateway
messages. But in the future, if we end up sending presence
updates or other gateway commands, we need to make sure to
put this limiter on all of those outgoing commands.
*/
for {
select {
case <-ctx.Done():
return
case <-firstDelay.C:
if ok := sendHeartbeat(); !ok {
return
}
heartbeatTicker = time.NewTicker(dur)
case <-heartbeatTicker.C:
if ok := sendHeartbeat(); !ok {
return
}
case <-bot.forceHeartbeat:
if ok := sendHeartbeat(); !ok {
return
}
heartbeatTicker.Reset(dur)
case msg := <-messages:
_, err := CreateMessage(ctx, msg.ChannelID, msg.PayloadJSON)
if err != nil {
log.Error().Err(err).Msg("failed to send Discord message")
}
}
}
}
func (bot *discordBotInstance) receiveGatewayMessage(ctx context.Context) (*GatewayMessage, error) {
_, msgBytes, err := bot.conn.ReadMessage()
if err != nil {
return nil, err
}
var msg GatewayMessage
err = json.Unmarshal(msgBytes, &msg)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord gateway message")
}
logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("received gateway message")
return &msg, nil
}
func (bot *discordBotInstance) sendGatewayMessage(ctx context.Context, msg GatewayMessage) error {
logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("sending gateway message")
return bot.conn.WriteMessage(websocket.TextMessage, msg.ToJSON())
}
/*
Processes a single event message from Discord. If this returns an error, it means something has
really gone wrong, bad enough that the connection should be shut down. Otherwise it will just log
any errors that occur.
*/
func (bot *discordBotInstance) processEventMsg(ctx context.Context, msg *GatewayMessage) error {
if msg.Opcode != OpcodeDispatch {
panic(fmt.Sprintf("processEventMsg must only be used on Dispatch messages (opcode %d). Validate this before you call this function.", OpcodeDispatch))
}
switch *msg.EventName {
case "RESUMED":
// Nothing to do, but at least we can log something
logging.ExtractLogger(ctx).Info().Msg("Finished resuming gateway session")
case "MESSAGE_CREATE":
newMessage := MessageFromMap(msg.Data)
err := bot.messageCreateOrUpdate(ctx, &newMessage)
if err != nil {
return oops.New(err, "error on new message")
}
case "MESSAGE_UPDATE":
newMessage := MessageFromMap(msg.Data)
err := bot.messageCreateOrUpdate(ctx, &newMessage)
if err != nil {
return oops.New(err, "error on updated message")
}
}
return nil
}
func (bot *discordBotInstance) messageCreateOrUpdate(ctx context.Context, msg *Message) error {
if msg.Author != nil && msg.Author.ID == config.Config.Discord.BotUserID {
// Don't process your own messages
return nil
}
if msg.ChannelID == config.Config.Discord.ShowcaseChannelID {
err := bot.processShowcaseMsg(ctx, msg)
if err != nil {
return oops.New(err, "failed to process showcase message")
}
return nil
}
if msg.ChannelID == config.Config.Discord.LibraryChannelID {
err := bot.processLibraryMsg(ctx, msg)
if err != nil {
return oops.New(err, "failed to process library message")
}
return nil
}
return nil
}
type MessageToSend struct {
ChannelID string
Req CreateMessageRequest
ExpiresAt time.Time
}
func SendMessages(
ctx context.Context,
conn *pgxpool.Pool,
msgs ...MessageToSend,
) error {
tx, err := conn.Begin(ctx)
if err != nil {
return oops.New(err, "failed to start transaction")
}
defer tx.Rollback(ctx)
for _, msg := range msgs {
if msg.ExpiresAt.IsZero() {
msg.ExpiresAt = time.Now().Add(30 * time.Second)
}
reqBytes, err := json.Marshal(msg.Req)
if err != nil {
return oops.New(err, "failed to marshal Discord message to JSON")
}
_, err = tx.Exec(ctx,
`
INSERT INTO discord_outgoingmessages (channel_id, payload_json, expires_at)
VALUES ($1, $2, $3)
`,
msg.ChannelID,
string(reqBytes),
msg.ExpiresAt,
)
if err != nil {
return oops.New(err, "failed to save outgoing Discord message to the database")
}
}
err = tx.Commit(ctx)
if err != nil {
return oops.New(err, "failed to commit outgoing Discord messages")
}
// Notify the sender that messages are ready to go
select {
case outgoingMessagesReady <- struct{}{}:
default:
}
return nil
}

292
src/discord/payloads.go Normal file
View File

@ -0,0 +1,292 @@
package discord
import (
"encoding/json"
)
type Opcode int
// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes
// NOTE(ben): I'm not using iota because 5 is missing
const (
OpcodeDispatch Opcode = 0
OpcodeHeartbeat Opcode = 1
OpcodeIdentify Opcode = 2
OpcodePresenceUpdate Opcode = 3
OpcodeVoiceStateUpdate Opcode = 4
OpcodeResume Opcode = 6
OpcodeReconnect Opcode = 7
OpcodeRequestGuildMembers Opcode = 8
OpcodeInvalidSession Opcode = 9
OpcodeHello Opcode = 10
OpcodeHeartbeatACK Opcode = 11
)
type Intent int
// https://discord.com/developers/docs/topics/gateway#list-of-intents
// NOTE(ben): I'm not using iota because the opcode thing made me paranoid
const (
IntentGuilds Intent = 1 << 0
IntentGuildMembers Intent = 1 << 1
IntentGuildBans Intent = 1 << 2
IntentGuildEmojisAndStickers Intent = 1 << 3
IntentGuildIntegrations Intent = 1 << 4
IntentGuildWebhooks Intent = 1 << 5
IntentGuildInvites Intent = 1 << 6
IntentGuildVoiceStates Intent = 1 << 7
IntentGuildPresences Intent = 1 << 8
IntentGuildMessages Intent = 1 << 9
IntentGuildMessageReactions Intent = 1 << 10
IntentGuildMessageTyping Intent = 1 << 11
IntentDirectMessages Intent = 1 << 12
IntentDirectMessageReactions Intent = 1 << 13
IntentDirectMessageTyping Intent = 1 << 14
)
type GatewayMessage struct {
Opcode Opcode `json:"op"`
Data interface{} `json:"d"`
SequenceNumber *int `json:"s,omitempty"`
EventName *string `json:"t,omitempty"`
}
func (m *GatewayMessage) ToJSON() []byte {
mBytes, err := json.Marshal(m)
if err != nil {
panic(err)
}
// TODO: check if the payload is too big, either here or where we actually send
// https://discord.com/developers/docs/topics/gateway#sending-payloads
return mBytes
}
type Hello struct {
HeartbeatIntervalMs int `json:"heartbeat_interval"`
}
func HelloFromMap(m interface{}) Hello {
// TODO: This should probably have some error handling, right?
return Hello{
HeartbeatIntervalMs: int(m.(map[string]interface{})["heartbeat_interval"].(float64)),
}
}
type Identify struct {
Token string `json:"token"`
Properties IdentifyConnectionProperties `json:"properties"`
Intents Intent `json:"intents"`
}
type IdentifyConnectionProperties struct {
OS string `json:"$os"`
Browser string `json:"$browser"`
Device string `json:"$device"`
}
type Ready struct {
GatewayVersion int `json:"v"`
User User `json:"user"`
SessionID string `json:"session_id"`
}
func ReadyFromMap(m interface{}) Ready {
mmap := m.(map[string]interface{})
return Ready{
GatewayVersion: int(mmap["v"].(float64)),
User: UserFromMap(mmap["user"]),
SessionID: mmap["session_id"].(string),
}
}
type Resume struct {
Token string `json:"token"`
SessionID string `json:"session_id"`
SequenceNumber int `json:"seq"`
}
type ChannelType int
const (
ChannelTypeGuildext ChannelType = 0
ChannelTypeDM ChannelType = 1
ChannelTypeGuildVoice ChannelType = 2
ChannelTypeGroupDM ChannelType = 3
ChannelTypeGuildCategory ChannelType = 4
ChannelTypeGuildNews ChannelType = 5
ChannelTypeGuildStore ChannelType = 6
ChannelTypeGuildNewsThread ChannelType = 10
ChannelTypeGuildPublicThread ChannelType = 11
ChannelTypeGuildPrivateThread ChannelType = 12
ChannelTypeGuildStageVoice ChannelType = 13
)
type Channel struct {
ID string `json:"id"`
Type ChannelType `json:"type"`
GuildID string `json:"guild_id"`
Name string `json:"name"`
Receipients []User `json:"recipients"`
OwnerID User `json:"owner_id"`
ParentID *string `json:"parent_id"`
}
type MessageType int
const (
MessageTypeDefault MessageType = 0
MessageTypeRecipientAdd MessageType = 1
MessageTypeRecipientRemove MessageType = 2
MessageTypeCall MessageType = 3
MessageTypeChannelNameChange MessageType = 4
MessageTypeChannelIconChange MessageType = 5
MessageTypeChannelPinnedMessage MessageType = 6
MessageTypeGuildMemberJoin MessageType = 7
MessageTypeUserPremiumGuildSubscription MessageType = 8
MessageTypeUserPremiumGuildSubscriptionTier1 MessageType = 9
MessageTypeUserPremiumGuildSubscriptionTier2 MessageType = 10
MessageTypeUserPremiumGuildSubscriptionTier3 MessageType = 11
MessageTypeChannelFollowAdd MessageType = 12
MessageTypeGuildDiscoveryDisqualified MessageType = 14
MessageTypeGuildDiscoveryRequalified MessageType = 15
MessageTypeGuildDiscoveryGracePeriodInitialWarning MessageType = 16
MessageTypeGuildDiscoveryGracePeriodFinalWarning MessageType = 17
MessageTypeThreadCreated MessageType = 18
MessageTypeReply MessageType = 19
MessageTypeApplicationCommand MessageType = 20
MessageTypeThreadStarterMessage MessageType = 21
MessageTypeGuildInviteReminder MessageType = 22
)
// https://discord.com/developers/docs/resources/channel#message-object
type Message struct {
ID string `json:"id"`
ChannelID string `json:"channel_id"`
Content string `json:"content"`
Author *User `json:"author"` // note that this may not be an actual valid user (see the docs)
// TODO: Author info
// TODO: Timestamp parsing, yay
Type MessageType `json:"type"`
Attachments []Attachment `json:"attachments"`
originalMap map[string]interface{}
}
func MessageFromMap(m interface{}) Message {
/*
Some gateway events, like MESSAGE_UPDATE, do not contain the
entire message body. So we need to be defensive on all fields here,
except the most basic identifying information.
*/
mmap := m.(map[string]interface{})
msg := Message{
ID: mmap["id"].(string),
ChannelID: mmap["channel_id"].(string),
Content: maybeString(mmap, "content"),
Type: MessageType(maybeInt(mmap, "type")),
originalMap: mmap,
}
if author, ok := mmap["author"]; ok {
u := UserFromMap(author)
msg.Author = &u
}
if iattachments, ok := mmap["attachments"]; ok {
attachments := iattachments.([]interface{})
for _, iattachment := range attachments {
msg.Attachments = append(msg.Attachments, AttachmentFromMap(iattachment))
}
}
return msg
}
// https://discord.com/developers/docs/resources/user#user-object
type User struct {
ID string `json:"id"`
Username string `json:"username"`
Discriminator string `json:"discriminator"`
IsBot bool `json:"bot"`
}
func UserFromMap(m interface{}) User {
mmap := m.(map[string]interface{})
u := User{
ID: mmap["id"].(string),
Username: mmap["username"].(string),
Discriminator: mmap["discriminator"].(string),
}
if isBot, ok := mmap["bot"]; ok {
u.IsBot = isBot.(bool)
}
return u
}
type Attachment struct {
ID string `json:"id"`
Filename string `json:"filename"`
ContentType string `json:"content_type"`
Size int `json:"size"`
Url string `json:"url"`
ProxyUrl string `json:"proxy_url"`
Height *int `json:"height"`
Width *int `json:"width"`
}
func AttachmentFromMap(m interface{}) Attachment {
mmap := m.(map[string]interface{})
a := Attachment{
ID: mmap["id"].(string),
Filename: mmap["filename"].(string),
ContentType: maybeString(mmap, "content_type"),
Size: int(mmap["size"].(float64)),
Url: mmap["url"].(string),
ProxyUrl: mmap["proxy_url"].(string),
Height: maybeIntP(mmap, "height"),
Width: maybeIntP(mmap, "width"),
}
return a
}
func maybeString(m map[string]interface{}, k string) string {
val, ok := m[k]
if !ok {
return ""
}
return val.(string)
}
func maybeInt(m map[string]interface{}, k string) int {
val, ok := m[k]
if !ok {
return 0
}
return int(val.(float64))
}
func maybeIntP(m map[string]interface{}, k string) *int {
val, ok := m[k]
if !ok {
return nil
}
intval := int(val.(float64))
return &intval
}

284
src/discord/ratelimiting.go Normal file
View File

@ -0,0 +1,284 @@
package discord
import (
"context"
"errors"
"math"
"net/http"
"strconv"
"sync"
"time"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/utils"
)
var limiterLog = logging.GlobalLogger().With().
Str("module", "discord").
Str("discord actor", "rate limiter").
Logger()
var buckets sync.Map // map[route name]bucket name
var rateLimiters sync.Map // map[bucket name]*restRateLimiter
var limiterInitMutex sync.Mutex
type restRateLimiter struct {
requests chan struct{}
refills chan rateLimiterRefill
}
type rateLimiterRefill struct {
resetAfter time.Duration
maxRequests int
}
/*
Whenever we send a request, we must sleep until this time
(if it is in the future, of course). This is a quick and
dirty way to pause all sending in case of a global rate
limit.
I could put a mutex on this but I don't think it's actually
a problem to have race conditions here. Just set it when
you get throttled. EZ.
*/
var globalRateLimitTime time.Time
type rateLimitHeaders struct {
Bucket string
Limit int
Remaining int
ResetAfter time.Duration
}
func parseRateLimitHeaders(header http.Header) (rateLimitHeaders, bool) {
var err error
bucket := header.Get("X-RateLimit-Bucket")
var limit int
var remaining int
var resetAfter time.Duration
limitStr := header.Get("X-RateLimit-Limit")
if limitStr != "" {
limit, err = strconv.Atoi(limitStr)
if err != nil {
limiterLog.Error().
Err(err).
Str("value", limitStr).
Msg("failed to parse X-RateLimit-Limit header")
return rateLimitHeaders{}, false
}
}
remainingStr := header.Get("X-RateLimit-Remaining")
if remainingStr != "" {
remaining, err = strconv.Atoi(remainingStr)
if err != nil {
limiterLog.Error().
Err(err).
Str("value", remainingStr).
Msg("failed to parse X-RateLimit-Remaining header")
return rateLimitHeaders{}, false
}
}
resetAfterStr := header.Get("X-RateLimit-Reset-After")
if resetAfterStr != "" {
resetAfterSeconds, err := strconv.ParseFloat(resetAfterStr, 64)
if err != nil {
limiterLog.Error().
Err(err).
Str("value", resetAfterStr).
Msg("failed to parse X-RateLimit-Reset-After header")
return rateLimitHeaders{}, false
}
resetAfter = time.Duration(math.Ceil(resetAfterSeconds)) * time.Second
}
return rateLimitHeaders{
Bucket: bucket,
Limit: limit,
Remaining: remaining,
ResetAfter: resetAfter,
}, true
}
func createLimiter(headers rateLimitHeaders, routeName string) {
limiterInitMutex.Lock()
defer limiterInitMutex.Unlock()
buckets.Store(routeName, headers.Bucket)
ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{
requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts
refills: make(chan rateLimiterRefill),
})
if !loaded {
limiter := ilimiter.(*restRateLimiter)
log := limiterLog.With().Str("bucket", headers.Bucket).Logger()
prefillloop:
// Pre-fill the limiter with remaining requests
for i := 0; i < headers.Remaining; i++ {
select {
case limiter.requests <- struct{}{}:
default:
log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity")
break prefillloop
}
}
/*
Start the refiller for this bucket. It waits for a request to tell
it when to next reset the rate limit, and how full to fill the bucket.
It then sleeps and refills the bucket, just like it should :)
*/
go func() {
for {
// Wake up on the first request after refilling
refill := <-limiter.refills
// Sleep for the remainder of the bucket's time
time.Sleep(refill.resetAfter)
drainloop:
// drain the bucket
for {
select {
case <-limiter.requests:
default:
break drainloop
}
}
refillloop:
// refill it with the max number of requests
for i := 0; i < refill.maxRequests; i++ {
select {
case limiter.requests <- struct{}{}:
default:
log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity")
break refillloop
}
}
// And then we wait again to hear about our next
// bucket's worth of requests.
}
}()
// Tell the refiller about its first refill
limiter.refills <- rateLimiterRefill{
resetAfter: headers.ResetAfter,
maxRequests: headers.Limit,
}
}
}
func (l *restRateLimiter) update(headers rateLimitHeaders) {
refill := rateLimiterRefill{
resetAfter: headers.ResetAfter,
maxRequests: headers.Limit,
}
/*
Tell the refiller about this request. If the refiller is already
busy sleeping, this will have no effect, which is what we want.
(It's already sleeping for as long as it needs to.)
*/
select {
case l.refills <- refill:
default:
}
}
func doWithRateLimiting(ctx context.Context, routeName string, getReq func(ctx context.Context) *http.Request) (*http.Response, error) {
var bucket string
ibucket, ok := buckets.Load(routeName)
if ok {
bucket = ibucket.(string)
}
for {
var limiter *restRateLimiter
if bucket != "" {
ilimiter, ok := rateLimiters.Load(bucket)
if ok {
limiter = ilimiter.(*restRateLimiter)
}
}
if globalRateLimitTime.After(time.Now()) {
// oh boy, global rate limit, pause until the coast is clear
err := utils.SleepContext(ctx, globalRateLimitTime.Sub(time.Now())+1*time.Second)
if err != nil {
return nil, err
}
}
if limiter != nil {
select {
case <-limiter.requests:
case <-ctx.Done():
return nil, errors.New("request interrupted during rate limiting")
}
}
res, err := httpClient.Do(getReq(ctx))
if err != nil {
return nil, err
}
headers, headersOk := parseRateLimitHeaders(res.Header)
if headersOk {
if limiter == nil || headers.Bucket != bucket {
createLimiter(headers, routeName)
} else {
limiter.update(headers)
}
}
if res.StatusCode == 429 {
if res.Header.Get("X-RateLimit-Global") != "" {
// globally rate limited
logging.ExtractLogger(ctx).Warn().Msg("got globally rate limited by Discord")
retryAfter, err := strconv.Atoi(res.Header.Get("Retry-After"))
if err == nil {
globalRateLimitTime = time.Now().Add(time.Duration(retryAfter) * time.Second)
} else {
// well this is bad, just sleep for 60 seconds and pray that it's long enough
logging.ExtractLogger(ctx).Warn().
Err(err).
Msg("got globally rate limited but couldn't determine how long to wait")
globalRateLimitTime = time.Now().Add(60 * time.Second)
}
} else {
// locally rate limited
/*
Despite our best efforts, we ended up rate limited anyway.
Simply wait the amount of time Discord asks, and then try
again. On the next go-around, hopefully we'll either succeed
or have a rate limiter initialized and ready to go.
*/
logging.ExtractLogger(ctx).Warn().Msg("got rate limited by Discord")
if headersOk {
err := utils.SleepContext(ctx, headers.ResetAfter)
if err != nil {
return nil, err
}
} else {
logging.ExtractLogger(ctx).Warn().Msg("got rate limited, but didn't have the headers??")
err := utils.SleepContext(ctx, 1*time.Second)
if err != nil {
return nil, err
}
}
}
continue
}
return res, nil
}
}

182
src/discord/rest.go Normal file
View File

@ -0,0 +1,182 @@
package discord
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httputil"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/oops"
)
const (
BotName = "HandmadeNetwork"
BaseUrl = "https://discord.com/api/v9"
UserAgentURL = "https://handmade.network/"
UserAgentVersion = "1.0"
)
var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersion)
var httpClient = &http.Client{}
func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request {
var bodyReader io.Reader
if body != nil {
bodyReader = bytes.NewBuffer(body)
}
req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader)
if err != nil {
panic(err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bot %s", config.Config.Discord.BotToken))
req.Header.Add("User-Agent", UserAgent)
return req
}
type GetGatewayBotResponse struct {
URL string `json:"url"`
// We don't care about shards or session limit stuff; we will never hit those limits
}
func GetGatewayBot(ctx context.Context) (*GetGatewayBotResponse, error) {
const name = "Get Gateway Bot"
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
return makeRequest(ctx, http.MethodGet, "/gateway/bot", nil)
})
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != 200 {
logErrorResponse(ctx, name, res, "")
return nil, oops.New(nil, "received error from Discord")
}
body, err := io.ReadAll(res.Body)
if err != nil {
panic(err)
}
var result GetGatewayBotResponse
err = json.Unmarshal(body, &result)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord response")
}
return &result, nil
}
type CreateMessageRequest struct {
Content string `json:"content"`
}
func CreateMessage(ctx context.Context, channelID string, payloadJSON string) (*Message, error) {
const name = "Create Message"
path := fmt.Sprintf("/channels/%s/messages", channelID)
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
req := makeRequest(ctx, http.MethodPost, path, []byte(payloadJSON))
req.Header.Add("Content-Type", "application/json")
return req
})
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
logErrorResponse(ctx, name, res, "")
return nil, oops.New(nil, "received error from Discord")
}
// Maybe in the future we could more nicely handle errors like "bad channel",
// but honestly what are the odds that we mess that up...
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
panic(err)
}
var msg Message
err = json.Unmarshal(bodyBytes, &msg)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord message")
}
return &msg, nil
}
func DeleteMessage(ctx context.Context, channelID string, messageID string) error {
const name = "Delete Message"
path := fmt.Sprintf("/channels/%s/messages/%s", channelID, messageID)
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
return makeRequest(ctx, http.MethodDelete, path, nil)
})
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
logErrorResponse(ctx, name, res, "")
return oops.New(nil, "got unexpected status code when deleting message")
}
return nil
}
func CreateDM(ctx context.Context, recipientID string) (*Channel, error) {
const name = "Create DM"
path := "/users/@me/channels"
body := []byte(fmt.Sprintf(`{"recipient_id":"%s"}`, recipientID))
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
req := makeRequest(ctx, http.MethodPost, path, body)
req.Header.Add("Content-Type", "application/json")
return req
})
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
logErrorResponse(ctx, name, res, "")
return nil, oops.New(nil, "received error from Discord")
}
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
panic(err)
}
var channel Channel
err = json.Unmarshal(bodyBytes, &channel)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord channel")
}
return &channel, nil
}
func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) {
dump, err := httputil.DumpResponse(res, true)
if err != nil {
panic(err)
}
logging.ExtractLogger(ctx).Error().Str("name", name).Msg(msg)
fmt.Println(string(dump))
}

115
src/discord/showcase.go Normal file
View File

@ -0,0 +1,115 @@
package discord
import (
"context"
"net/url"
"regexp"
"strings"
"git.handmade.network/hmn/hmn/src/oops"
)
var reDiscordMessageLink = regexp.MustCompile(`https?://.+?(\s|$)`)
func (bot *discordBotInstance) processShowcaseMsg(ctx context.Context, msg *Message) error {
switch msg.Type {
case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand:
default:
return nil
}
hasGoodContent := true
if originalMessageHasField(msg, "content") && !messageHasLinks(msg.Content) {
hasGoodContent = false
}
hasGoodAttachments := true
if originalMessageHasField(msg, "attachments") && len(msg.Attachments) == 0 {
hasGoodAttachments = false
}
if !hasGoodContent && !hasGoodAttachments {
err := DeleteMessage(ctx, msg.ChannelID, msg.ID)
if err != nil {
return oops.New(err, "failed to delete message")
}
if msg.Author != nil && !msg.Author.IsBot {
channel, err := CreateDM(ctx, msg.Author.ID)
if err != nil {
return oops.New(err, "failed to create DM channel")
}
err = SendMessages(ctx, bot.dbConn, MessageToSend{
ChannelID: channel.ID,
Req: CreateMessageRequest{
Content: "Posts in #project-showcase are required to have either an image/video or a link. Discuss showcase content in #projects.",
},
})
if err != nil {
return oops.New(err, "failed to send showcase warning message")
}
}
}
return nil
}
func (bot *discordBotInstance) processLibraryMsg(ctx context.Context, msg *Message) error {
switch msg.Type {
case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand:
default:
return nil
}
if !originalMessageHasField(msg, "content") {
return nil
}
if !messageHasLinks(msg.Content) {
err := DeleteMessage(ctx, msg.ChannelID, msg.ID)
if err != nil {
return oops.New(err, "failed to delete message")
}
if msg.Author != nil && !msg.Author.IsBot {
channel, err := CreateDM(ctx, msg.Author.ID)
if err != nil {
return oops.New(err, "failed to create DM channel")
}
err = SendMessages(ctx, bot.dbConn, MessageToSend{
ChannelID: channel.ID,
Req: CreateMessageRequest{
Content: "Posts in #the-library are required to have a link. Discuss library content in other relevant channels.",
},
})
if err != nil {
return oops.New(err, "failed to send showcase warning message")
}
}
}
return nil
}
func messageHasLinks(content string) bool {
links := reDiscordMessageLink.FindAllString(content, -1)
for _, link := range links {
_, err := url.Parse(strings.TrimSpace(link))
if err == nil {
return true
}
}
return false
}
func originalMessageHasField(msg *Message, field string) bool {
if msg.originalMap == nil {
return false
}
_, ok := msg.originalMap[field]
return ok
}

View File

@ -1,6 +1,7 @@
package logging package logging
import ( import (
"context"
"encoding/json" "encoding/json"
"os" "os"
"sort" "sort"
@ -16,7 +17,7 @@ import (
func init() { func init() {
zerolog.ErrorStackMarshaler = oops.ZerologStackMarshaler zerolog.ErrorStackMarshaler = oops.ZerologStackMarshaler
log.Logger = log.Output(NewPrettyZerologWriter()) log.Logger = log.Output(NewPrettyZerologWriter()).With().Stack().Logger()
zerolog.SetGlobalLevel(config.Config.LogLevel) zerolog.SetGlobalLevel(config.Config.LogLevel)
} }
@ -203,3 +204,17 @@ func LogPanicValue(logger *zerolog.Logger, val interface{}, msg string) {
Msg(msg) Msg(msg)
} }
} }
const LoggerContextKey = "logger"
func AttachLoggerToContext(logger *zerolog.Logger, ctx context.Context) context.Context {
return context.WithValue(ctx, LoggerContextKey, logger)
}
func ExtractLogger(ctx context.Context) *zerolog.Logger {
ilogger := ctx.Value(LoggerContextKey)
if ilogger == nil {
return GlobalLogger()
}
return ilogger.(*zerolog.Logger)
}

View File

@ -0,0 +1,61 @@
package migrations
import (
"context"
"time"
"git.handmade.network/hmn/hmn/src/migration/types"
"git.handmade.network/hmn/hmn/src/oops"
"github.com/jackc/pgx/v4"
)
func init() {
registerMigration(AddDiscordBotTables{})
}
type AddDiscordBotTables struct{}
func (m AddDiscordBotTables) Version() types.MigrationVersion {
return types.MigrationVersion(time.Date(2021, 8, 7, 18, 53, 30, 0, time.UTC))
}
func (m AddDiscordBotTables) Name() string {
return "AddDiscordBotTables"
}
func (m AddDiscordBotTables) Description() string {
return "Add tables for Discord bot sessions and messages"
}
func (m AddDiscordBotTables) Up(ctx context.Context, tx pgx.Tx) error {
_, err := tx.Exec(ctx, `
CREATE TABLE discord_session (
pk INT NOT NULL DEFAULT 1337 PRIMARY KEY, -- this should always be set to 1337 to ensure that we only have one row :)
session_id VARCHAR(255) NOT NULL,
sequence_number INT NOT NULL,
CONSTRAINT only_one_session CHECK (pk = 1337)
);
`)
if err != nil {
return oops.New(err, "failed to create discord session table")
}
_, err = tx.Exec(ctx, `
CREATE TABLE discord_outgoingmessages (
id SERIAL NOT NULL PRIMARY KEY,
channel_id VARCHAR(64) NOT NULL,
payload_json TEXT NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
`)
if err != nil {
return oops.New(err, "failed to create discord outgoing messages table")
}
return nil
}
func (m AddDiscordBotTables) Down(ctx context.Context, tx pgx.Tx) error {
panic("Implement me")
}

View File

@ -4,6 +4,18 @@ import (
"time" "time"
) )
type DiscordSession struct {
ID string `db:"session_id"`
SequenceNumber int `db:"sequence_number"`
}
type DiscordOutgoingMessage struct {
ID int `db:"id"`
ChannelID string `db:"channel_id"`
PayloadJSON string `db:"payload_json"`
ExpiresAt time.Time `db:"expires_at"`
}
type DiscordMessage struct { type DiscordMessage struct {
ID string `db:"id"` ID string `db:"id"`
ChannelID string `db:"channel_id"` ChannelID string `db:"channel_id"`

View File

@ -14,7 +14,11 @@ type Error struct {
} }
func (e *Error) Error() string { func (e *Error) Error() string {
if e.Wrapped == nil {
return e.Message
} else {
return fmt.Sprintf("%s: %v", e.Message, e.Wrapped) return fmt.Sprintf("%s: %v", e.Message, e.Wrapped)
}
} }
func (e *Error) Unwrap() error { func (e *Error) Unwrap() error {

View File

@ -1,5 +1,14 @@
package utils package utils
import (
"context"
"errors"
"fmt"
"time"
"git.handmade.network/hmn/hmn/src/oops"
)
func IntMin(a, b int) int { func IntMin(a, b int) int {
if a < b { if a < b {
return a return a
@ -17,3 +26,38 @@ func IntMax(a, b int) int {
func IntClamp(min, t, max int) int { func IntClamp(min, t, max int) int {
return IntMax(min, IntMin(t, max)) return IntMax(min, IntMin(t, max))
} }
/*
Recover a panic and convert it to a returned error. Call it like so:
func MyFunc() (err error) {
defer utils.RecoverPanicAsError(&err)
}
If an error was already present, the panicked error will take precedence. Unfortunately there's
no good way to include both errors because you can't really have two chains of errors and still
play nice with the standard library's Unwrap behavior. But most of the time this shouldn't be an
issue, since the panic will probably occur before a meaningful error value was set.
*/
func RecoverPanicAsError(err *error) {
if r := recover(); r != nil {
var recoveredErr error
if rerr, ok := r.(error); ok {
recoveredErr = rerr
} else {
recoveredErr = fmt.Errorf("panic with value: %v", r)
}
*err = oops.New(recoveredErr, "panic recovered as error")
}
}
var ErrSleepInterrupted = errors.New("sleep interrupted by context cancellation")
func SleepContext(ctx context.Context, d time.Duration) error {
select {
case <-ctx.Done():
return ErrSleepInterrupted
case <-time.After(d):
return nil
}
}

View File

@ -13,6 +13,7 @@ import (
"git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/discord"
"git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/perf" "git.handmade.network/hmn/hmn/src/perf"
"git.handmade.network/hmn/hmn/src/templates" "git.handmade.network/hmn/hmn/src/templates"
@ -42,6 +43,7 @@ var WebsiteCommand = &cobra.Command{
auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn),
auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn), auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn),
perfCollector.Done, perfCollector.Done,
discord.RunDiscordBot(backgroundJobContext, conn),
) )
signals := make(chan os.Signal, 1) signals := make(chan os.Signal, 1)
@ -52,7 +54,9 @@ var WebsiteCommand = &cobra.Command{
go func() { go func() {
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
logging.Info().Msg("shutting down web server")
server.Shutdown(timeout) server.Shutdown(timeout)
logging.Info().Msg("cancelling background jobs")
cancelBackgroundJobs() cancelBackgroundJobs()
}() }()