From 38a1188be7f8285b61171509849c31a9bfc85d6a Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Fri, 6 Aug 2021 18:23:51 -0500 Subject: [PATCH] Add Discord integration Clean up several TODOs Implement the full disconnect / resume flow Detect zombied connections and restart Implement the random delay on reconnect Implement message sending!! (with a goofy feedback loop on the echo bot) Fix the feedback loop in the echo bot Clean up the Discord gateway code Many things are methods now to reduce the amount of explicit plumbing. Connection handling should be a little more robust, and we have an actual error handling strategy now. Allow sending multiple Discord messages at once Delete irrelevant tests uhh, start rate limiting Add per-route rate limiting Add global rate limit handling Handle context cancellation in Discord REST code Allow changing buckets per route Add the showcase rejection bot Add library bot --- go.mod | 2 + go.sum | 3 + src/config/types.go | 9 + src/discord/gateway.go | 643 ++++++++++++++++++ src/discord/payloads.go | 292 ++++++++ src/discord/ratelimiting.go | 284 ++++++++ src/discord/rest.go | 182 +++++ src/discord/showcase.go | 115 ++++ src/logging/logging.go | 17 +- .../2021-08-07T185330Z_AddDiscordBotTables.go | 61 ++ src/models/discord.go | 12 + src/oops/oops.go | 6 +- src/utils/utils.go | 44 ++ src/website/website.go | 4 + 14 files changed, 1672 insertions(+), 2 deletions(-) create mode 100644 src/discord/gateway.go create mode 100644 src/discord/payloads.go create mode 100644 src/discord/ratelimiting.go create mode 100644 src/discord/rest.go create mode 100644 src/discord/showcase.go create mode 100644 src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go diff --git a/go.mod b/go.mod index f9b87fe..16fef9d 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,13 @@ require ( github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 github.com/go-stack/stack v1.8.0 github.com/google/uuid v1.2.0 + github.com/gorilla/websocket v1.4.2 github.com/huandu/xstrings v1.3.2 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/jackc/pgconn v1.8.0 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.10.1 + github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.1.1 // indirect github.com/rs/zerolog v1.21.0 github.com/spf13/cobra v1.1.3 diff --git a/go.sum b/go.sum index 241c838..474fba9 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,7 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= @@ -185,6 +186,8 @@ github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= diff --git a/src/config/types.go b/src/config/types.go index f76a88b..19485b9 100644 --- a/src/config/types.go +++ b/src/config/types.go @@ -25,6 +25,7 @@ type HMNConfig struct { Auth AuthConfig Email EmailConfig DigitalOcean DigitalOceanConfig + Discord DiscordConfig } type PostgresConfig struct { @@ -62,6 +63,14 @@ type EmailConfig struct { OverrideRecipientEmail string } +type DiscordConfig struct { + BotToken string + BotUserID string + + ShowcaseChannelID string + LibraryChannelID string +} + func (info PostgresConfig) DSN() string { return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName) } diff --git a/src/discord/gateway.go b/src/discord/gateway.go new file mode 100644 index 0000000..7cdfdac --- /dev/null +++ b/src/discord/gateway.go @@ -0,0 +1,643 @@ +package discord + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "runtime" + "sync" + "time" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/models" + "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/utils" + "github.com/gorilla/websocket" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/jpillora/backoff" +) + +func RunDiscordBot(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{} { + log := logging.ExtractLogger(ctx).With().Str("module", "discord").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + if config.Config.Discord.BotToken == "" { + log.Warn().Msg("No Discord bot token was provided, so the Discord bot cannot run.") + done := make(chan struct{}, 1) + done <- struct{}{} + return done + } + + done := make(chan struct{}) + go func() { + defer func() { + log.Debug().Msg("shut down Discord bot") + done <- struct{}{} + }() + + boff := backoff.Backoff{ + Min: 1 * time.Second, + Max: 5 * time.Minute, + } + + for { + select { + case <-ctx.Done(): + return + default: + } + + func() { + log.Info().Msg("Connecting to the Discord gateway") + bot := newBotInstance(dbConn) + err := bot.Run(ctx) + if err != nil { + dur := boff.Duration() + log.Error(). + Err(err). + Dur("retrying after", dur). + Msg("failed to run Discord bot") + + timer := time.NewTimer(dur) + select { + case <-ctx.Done(): + case <-timer.C: + } + + return + } + + select { + case <-ctx.Done(): + return + default: + } + + // This delay satisfies the 1 to 5 second delay Discord + // wants on reconnects, and seems fine to do every time. + delay := time.Duration(int64(time.Second) + rand.Int63n(int64(time.Second*4))) + log.Info().Dur("delay", delay).Msg("Reconnecting to Discord") + time.Sleep(delay) + + boff.Reset() + }() + } + }() + return done +} + +var outgoingMessagesReady = make(chan struct{}, 1) + +type discordBotInstance struct { + conn *websocket.Conn + dbConn *pgxpool.Pool + + heartbeatIntervalMs int + forceHeartbeat chan struct{} + + /* + Every time we send a heartbeat, we set this variable to false. + Whenever we ack a heartbeat, we set this variable to true. + If we try to send a heartbeat but the previous one was not + acked, then we close the connection and try to reconnect. + */ + didAckHeartbeat bool + + /* + All goroutines should call this when they exit, to ensure that + the other goroutines shut down as well. + */ + cancel context.CancelFunc + wg sync.WaitGroup +} + +func newBotInstance(dbConn *pgxpool.Pool) *discordBotInstance { + return &discordBotInstance{ + dbConn: dbConn, + forceHeartbeat: make(chan struct{}), + didAckHeartbeat: true, + } +} + +/* +Runs a bot instance to completion. It will start up a gateway connection and return when the +connection is closed. It only returns an error when something unexpected occurs; if so, you should +do exponential backoff before reconnecting. Otherwise you can reconnect right away. +*/ +func (bot *discordBotInstance) Run(ctx context.Context) (err error) { + defer utils.RecoverPanicAsError(&err) + + ctx, bot.cancel = context.WithCancel(ctx) + defer bot.cancel() + + err = bot.connect(ctx) + if err != nil { + return oops.New(err, "failed to connect to Discord gateway") + } + defer bot.conn.Close() + + bot.wg.Add(1) + go bot.doSender(ctx) + + // Wait for child goroutines to exit (they will do so when context is canceled). This ensures + // that nothing is in the middle of sending. Then close the connection, so that this goroutine + // can finish as well. + go func() { + bot.wg.Wait() + bot.conn.Close() + }() + + for { + msg, err := bot.receiveGatewayMessage(ctx) + if err != nil { + // TODO: Are there other kinds of connection close events that we need to handle? Probably? + if errors.Is(err, net.ErrClosed) { + // If the connection is closed, that's our cue to shut down the bot. Any errors + // related to the closure will have been logged elsewhere anyway. + return nil + } else { + return oops.New(err, "failed to receive message from the gateway") + } + } + + // Update the sequence number in the db + if msg.SequenceNumber != nil { + _, err = bot.dbConn.Exec(ctx, `UPDATE discord_session SET sequence_number = $1`, *msg.SequenceNumber) + if err != nil { + return oops.New(err, "failed to save latest sequence number") + } + } + + switch msg.Opcode { + case OpcodeDispatch: + // Just a normal event + err := bot.processEventMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process gateway event") + } + case OpcodeHeartbeat: + bot.forceHeartbeat <- struct{}{} + case OpcodeHeartbeatACK: + bot.didAckHeartbeat = true + case OpcodeReconnect: + logging.ExtractLogger(ctx).Info().Msg("Discord asked us to reconnect to the gateway") + return nil + case OpcodeInvalidSession: + // We tried to resume but the session was invalid. + // Delete the session and reconnect from scratch again. + _, err := bot.dbConn.Exec(ctx, `DELETE FROM discord_session`) + if err != nil { + return oops.New(err, "failed to delete invalid session") + } + return nil + } + } +} + +/* +The connection process in short: +- Gateway sends Hello, asking the client to heartbeat on some interval +- Client sends Identify and starts heartbeat process +- Gateway sends Ready, client is now connected to gateway + +Or, if we have an existing session: +- Gateway sends Hello, asking the client to heartbeat on some interval +- Client sends Resume and starts heartbeat process +- Gateway sends all missed events followed by a RESUMED event, or an Invalid Session if the + session is ded + +Note that some events probably won't be received until the Guild Create message is received. + +It's a little annoying to handle resumes since we want to handle the missed messages as if we were +receiving them in real time. But we're kind of in a different state from from when we're normally +receiving messages, because we are expecting a RESUMED event at the end, and the first message we +receive might be an Invalid Session. So, unfortunately, we just have to handle the Invalid Session +and RESUMED messages in our main message receiving loop instead of here. + +(Discord could have prevented this if they send a "Resume ACK" message before replaying events. +That way, we could receive exactly one message after sending Resume, either a Resume ACK or an +Invalid Session, and from there it would be crystal clear what to do. Alas!) +*/ +func (bot *discordBotInstance) connect(ctx context.Context) error { + res, err := GetGatewayBot(ctx) + if err != nil { + return oops.New(err, "failed to get gateway URL") + } + + conn, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s/?v=9&encoding=json", res.URL), nil) + if err != nil { + return oops.New(err, "failed to connect to the Discord gateway") + } + bot.conn = conn + + helloMessage, err := bot.receiveGatewayMessage(ctx) + if err != nil { + return oops.New(err, "failed to read Hello message") + } + if helloMessage.Opcode != OpcodeHello { + return oops.New(nil, "expected a Hello (opcode %d), but got opcode %d", OpcodeHello, helloMessage.Opcode) + } + helloData := HelloFromMap(helloMessage.Data) + bot.heartbeatIntervalMs = helloData.HeartbeatIntervalMs + + // Now that the gateway has said hello, we need to establish a new session, either resuming + // an old one or starting a new one. + + shouldResume := true + isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) + if err != nil { + if errors.Is(err, db.ErrNoMatchingRows) { + // No session yet! Just identify and get on with it + shouldResume = false + } else { + return oops.New(err, "failed to get current session from database") + } + } + + if shouldResume { + // Reconnect to the previous session + session := isession.(*models.DiscordSession) + + err := bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeResume, + Data: Resume{ + Token: config.Config.Discord.BotToken, + SessionID: session.ID, + SequenceNumber: session.SequenceNumber, + }, + }) + if err != nil { + return oops.New(err, "failed to send Resume message") + } + + return nil + } else { + // Start a new session + err := bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeIdentify, + Data: Identify{ + Token: config.Config.Discord.BotToken, + Properties: IdentifyConnectionProperties{ + OS: runtime.GOOS, + Browser: BotName, + Device: BotName, + }, + Intents: IntentGuilds | IntentGuildMessages, + }, + }) + if err != nil { + return oops.New(err, "failed to send Identify message") + } + + readyMessage, err := bot.receiveGatewayMessage(ctx) + if err != nil { + return oops.New(err, "failed to read Ready message") + } + if readyMessage.Opcode != OpcodeDispatch { + return oops.New(err, "expected a READY event, but got a message with opcode %d", readyMessage.Opcode) + } + if *readyMessage.EventName != "READY" { + return oops.New(err, "expected a READY event, but got a %s event", *readyMessage.EventName) + } + readyData := ReadyFromMap(readyMessage.Data) + + _, err = bot.dbConn.Exec(ctx, + ` + INSERT INTO discord_session (session_id, sequence_number) + VALUES ($1, $2) + ON CONFLICT (pk) DO UPDATE + SET session_id = $1, sequence_number = $2 + `, + readyData.SessionID, + *readyMessage.SequenceNumber, + ) + if err != nil { + return oops.New(err, "failed to save new bot session in the database") + } + } + + return nil +} + +/* +Sends outgoing gateway messages and channel messages. Handles heartbeats. This function should be +run as its own goroutine. +*/ +func (bot *discordBotInstance) doSender(ctx context.Context) { + defer bot.wg.Done() + defer bot.cancel() + + log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + defer log.Info().Msg("shutting down Discord sender") + + /* + The first heartbeat is supposed to occur at a random time within + the first heartbeat interval. + + https://discord.com/developers/docs/topics/gateway#heartbeating + */ + dur := time.Duration(bot.heartbeatIntervalMs) * time.Millisecond + firstDelay := time.NewTimer(time.Duration(rand.Int63n(int64(dur)))) + heartbeatTicker := &time.Ticker{} // this will start never ticking, and get initialized after the first heartbeat + + // Returns false if the heartbeat failed + sendHeartbeat := func() bool { + if !bot.didAckHeartbeat { + log.Error().Msg("did not receive a heartbeat ACK in between heartbeats") + return false + } + bot.didAckHeartbeat = false + + latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) + if err != nil { + log.Error().Err(err).Msg("failed to fetch latest sequence number from the db") + return false + } + + err = bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeHeartbeat, + Data: latestSequenceNumber, + }) + if err != nil { + log.Error().Err(err).Msg("failed to send heartbeat") + return false + } + + return true + } + + /* + Start a goroutine to fetch outgoing messages from the db. We do this in a separate goroutine + to ensure that issues talking to the database don't prevent us from sending heartbeats. + */ + messages := make(chan *models.DiscordOutgoingMessage) + bot.wg.Add(1) + go func(ctx context.Context) { + defer bot.wg.Done() + defer bot.cancel() + + log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender db reader").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + defer log.Info().Msg("stopping db reader") + + // We will poll the database just in case the notification mechanism doesn't work. + ticker := time.NewTicker(time.Second * 5) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + case <-outgoingMessagesReady: + } + + func() { + tx, err := bot.dbConn.Begin(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to start transaction") + return + } + defer tx.Rollback(ctx) + + itMessages, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` + SELECT $columns + FROM discord_outgoingmessages + ORDER BY id ASC + `) + if err != nil { + log.Error().Err(err).Msg("failed to fetch outgoing Discord messages") + return + } + + msgs := itMessages.ToSlice() + for _, imsg := range msgs { + msg := imsg.(*models.DiscordOutgoingMessage) + if time.Now().After(msg.ExpiresAt) { + continue + } + messages <- msg + } + + /* + NOTE(ben): Doing this in a transaction means that we will only delete the + messages that we originally fetched. At least, as long as the database's + isolation level is Read Committed, which is the default. + + https://www.postgresql.org/docs/current/transaction-iso.html + */ + _, err = tx.Exec(ctx, `DELETE FROM discord_outgoingmessages`) + if err != nil { + log.Error().Err(err).Msg("failed to delete outgoing messages") + return + } + + err = tx.Commit(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to read and delete outgoing messages") + return + } + + if len(msgs) > 0 { + log.Debug().Int("num messages", len(msgs)).Msg("Sent and deleted outgoing messages") + } + }() + } + }(ctx) + + /* + Whenever we want to send a gateway message, we must receive a value from + this channel first. A goroutine continuously fills the channel at a rate + that respects Discord's gateway rate limit. + + Don't use this for heartbeats; heartbeats should go out immediately. + Don't forget that the server can request a heartbeat at any time. + + See the docs for more details. The capacity of this channel is chosen to + always leave us overhead for heartbeats and other shenanigans. + + https://discord.com/developers/docs/topics/gateway#rate-limiting + */ + rateLimiter := make(chan struct{}, 100) + go func() { + for { + rateLimiter <- struct{}{} + time.Sleep(500 * time.Millisecond) + } + }() + /* + NOTE(ben): This rate limiter is actually not used right now + because we're not actually sending any meaningful gateway + messages. But in the future, if we end up sending presence + updates or other gateway commands, we need to make sure to + put this limiter on all of those outgoing commands. + */ + + for { + select { + case <-ctx.Done(): + return + case <-firstDelay.C: + if ok := sendHeartbeat(); !ok { + return + } + heartbeatTicker = time.NewTicker(dur) + case <-heartbeatTicker.C: + if ok := sendHeartbeat(); !ok { + return + } + case <-bot.forceHeartbeat: + if ok := sendHeartbeat(); !ok { + return + } + heartbeatTicker.Reset(dur) + case msg := <-messages: + _, err := CreateMessage(ctx, msg.ChannelID, msg.PayloadJSON) + if err != nil { + log.Error().Err(err).Msg("failed to send Discord message") + } + } + } +} + +func (bot *discordBotInstance) receiveGatewayMessage(ctx context.Context) (*GatewayMessage, error) { + _, msgBytes, err := bot.conn.ReadMessage() + if err != nil { + return nil, err + } + + var msg GatewayMessage + err = json.Unmarshal(msgBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord gateway message") + } + + logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("received gateway message") + + return &msg, nil +} + +func (bot *discordBotInstance) sendGatewayMessage(ctx context.Context, msg GatewayMessage) error { + logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("sending gateway message") + return bot.conn.WriteMessage(websocket.TextMessage, msg.ToJSON()) +} + +/* +Processes a single event message from Discord. If this returns an error, it means something has +really gone wrong, bad enough that the connection should be shut down. Otherwise it will just log +any errors that occur. +*/ +func (bot *discordBotInstance) processEventMsg(ctx context.Context, msg *GatewayMessage) error { + if msg.Opcode != OpcodeDispatch { + panic(fmt.Sprintf("processEventMsg must only be used on Dispatch messages (opcode %d). Validate this before you call this function.", OpcodeDispatch)) + } + + switch *msg.EventName { + case "RESUMED": + // Nothing to do, but at least we can log something + logging.ExtractLogger(ctx).Info().Msg("Finished resuming gateway session") + case "MESSAGE_CREATE": + newMessage := MessageFromMap(msg.Data) + + err := bot.messageCreateOrUpdate(ctx, &newMessage) + if err != nil { + return oops.New(err, "error on new message") + } + case "MESSAGE_UPDATE": + newMessage := MessageFromMap(msg.Data) + + err := bot.messageCreateOrUpdate(ctx, &newMessage) + if err != nil { + return oops.New(err, "error on updated message") + } + } + + return nil +} + +func (bot *discordBotInstance) messageCreateOrUpdate(ctx context.Context, msg *Message) error { + if msg.Author != nil && msg.Author.ID == config.Config.Discord.BotUserID { + // Don't process your own messages + return nil + } + + if msg.ChannelID == config.Config.Discord.ShowcaseChannelID { + err := bot.processShowcaseMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process showcase message") + } + return nil + } + + if msg.ChannelID == config.Config.Discord.LibraryChannelID { + err := bot.processLibraryMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process library message") + } + return nil + } + + return nil +} + +type MessageToSend struct { + ChannelID string + Req CreateMessageRequest + ExpiresAt time.Time +} + +func SendMessages( + ctx context.Context, + conn *pgxpool.Pool, + msgs ...MessageToSend, +) error { + tx, err := conn.Begin(ctx) + if err != nil { + return oops.New(err, "failed to start transaction") + } + defer tx.Rollback(ctx) + + for _, msg := range msgs { + if msg.ExpiresAt.IsZero() { + msg.ExpiresAt = time.Now().Add(30 * time.Second) + } + + reqBytes, err := json.Marshal(msg.Req) + if err != nil { + return oops.New(err, "failed to marshal Discord message to JSON") + } + + _, err = tx.Exec(ctx, + ` + INSERT INTO discord_outgoingmessages (channel_id, payload_json, expires_at) + VALUES ($1, $2, $3) + `, + msg.ChannelID, + string(reqBytes), + msg.ExpiresAt, + ) + if err != nil { + return oops.New(err, "failed to save outgoing Discord message to the database") + } + } + + err = tx.Commit(ctx) + if err != nil { + return oops.New(err, "failed to commit outgoing Discord messages") + } + + // Notify the sender that messages are ready to go + select { + case outgoingMessagesReady <- struct{}{}: + default: + } + + return nil +} diff --git a/src/discord/payloads.go b/src/discord/payloads.go new file mode 100644 index 0000000..4b5870e --- /dev/null +++ b/src/discord/payloads.go @@ -0,0 +1,292 @@ +package discord + +import ( + "encoding/json" +) + +type Opcode int + +// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes +// NOTE(ben): I'm not using iota because 5 is missing +const ( + OpcodeDispatch Opcode = 0 + OpcodeHeartbeat Opcode = 1 + OpcodeIdentify Opcode = 2 + OpcodePresenceUpdate Opcode = 3 + OpcodeVoiceStateUpdate Opcode = 4 + OpcodeResume Opcode = 6 + OpcodeReconnect Opcode = 7 + OpcodeRequestGuildMembers Opcode = 8 + OpcodeInvalidSession Opcode = 9 + OpcodeHello Opcode = 10 + OpcodeHeartbeatACK Opcode = 11 +) + +type Intent int + +// https://discord.com/developers/docs/topics/gateway#list-of-intents +// NOTE(ben): I'm not using iota because the opcode thing made me paranoid +const ( + IntentGuilds Intent = 1 << 0 + IntentGuildMembers Intent = 1 << 1 + IntentGuildBans Intent = 1 << 2 + IntentGuildEmojisAndStickers Intent = 1 << 3 + IntentGuildIntegrations Intent = 1 << 4 + IntentGuildWebhooks Intent = 1 << 5 + IntentGuildInvites Intent = 1 << 6 + IntentGuildVoiceStates Intent = 1 << 7 + IntentGuildPresences Intent = 1 << 8 + IntentGuildMessages Intent = 1 << 9 + IntentGuildMessageReactions Intent = 1 << 10 + IntentGuildMessageTyping Intent = 1 << 11 + IntentDirectMessages Intent = 1 << 12 + IntentDirectMessageReactions Intent = 1 << 13 + IntentDirectMessageTyping Intent = 1 << 14 +) + +type GatewayMessage struct { + Opcode Opcode `json:"op"` + Data interface{} `json:"d"` + SequenceNumber *int `json:"s,omitempty"` + EventName *string `json:"t,omitempty"` +} + +func (m *GatewayMessage) ToJSON() []byte { + mBytes, err := json.Marshal(m) + if err != nil { + panic(err) + } + + // TODO: check if the payload is too big, either here or where we actually send + // https://discord.com/developers/docs/topics/gateway#sending-payloads + + return mBytes +} + +type Hello struct { + HeartbeatIntervalMs int `json:"heartbeat_interval"` +} + +func HelloFromMap(m interface{}) Hello { + // TODO: This should probably have some error handling, right? + return Hello{ + HeartbeatIntervalMs: int(m.(map[string]interface{})["heartbeat_interval"].(float64)), + } +} + +type Identify struct { + Token string `json:"token"` + Properties IdentifyConnectionProperties `json:"properties"` + Intents Intent `json:"intents"` +} + +type IdentifyConnectionProperties struct { + OS string `json:"$os"` + Browser string `json:"$browser"` + Device string `json:"$device"` +} + +type Ready struct { + GatewayVersion int `json:"v"` + User User `json:"user"` + SessionID string `json:"session_id"` +} + +func ReadyFromMap(m interface{}) Ready { + mmap := m.(map[string]interface{}) + + return Ready{ + GatewayVersion: int(mmap["v"].(float64)), + User: UserFromMap(mmap["user"]), + SessionID: mmap["session_id"].(string), + } +} + +type Resume struct { + Token string `json:"token"` + SessionID string `json:"session_id"` + SequenceNumber int `json:"seq"` +} + +type ChannelType int + +const ( + ChannelTypeGuildext ChannelType = 0 + ChannelTypeDM ChannelType = 1 + ChannelTypeGuildVoice ChannelType = 2 + ChannelTypeGroupDM ChannelType = 3 + ChannelTypeGuildCategory ChannelType = 4 + ChannelTypeGuildNews ChannelType = 5 + ChannelTypeGuildStore ChannelType = 6 + ChannelTypeGuildNewsThread ChannelType = 10 + ChannelTypeGuildPublicThread ChannelType = 11 + ChannelTypeGuildPrivateThread ChannelType = 12 + ChannelTypeGuildStageVoice ChannelType = 13 +) + +type Channel struct { + ID string `json:"id"` + Type ChannelType `json:"type"` + GuildID string `json:"guild_id"` + Name string `json:"name"` + Receipients []User `json:"recipients"` + OwnerID User `json:"owner_id"` + ParentID *string `json:"parent_id"` +} + +type MessageType int + +const ( + MessageTypeDefault MessageType = 0 + + MessageTypeRecipientAdd MessageType = 1 + MessageTypeRecipientRemove MessageType = 2 + MessageTypeCall MessageType = 3 + + MessageTypeChannelNameChange MessageType = 4 + MessageTypeChannelIconChange MessageType = 5 + MessageTypeChannelPinnedMessage MessageType = 6 + + MessageTypeGuildMemberJoin MessageType = 7 + + MessageTypeUserPremiumGuildSubscription MessageType = 8 + MessageTypeUserPremiumGuildSubscriptionTier1 MessageType = 9 + MessageTypeUserPremiumGuildSubscriptionTier2 MessageType = 10 + MessageTypeUserPremiumGuildSubscriptionTier3 MessageType = 11 + + MessageTypeChannelFollowAdd MessageType = 12 + + MessageTypeGuildDiscoveryDisqualified MessageType = 14 + MessageTypeGuildDiscoveryRequalified MessageType = 15 + MessageTypeGuildDiscoveryGracePeriodInitialWarning MessageType = 16 + MessageTypeGuildDiscoveryGracePeriodFinalWarning MessageType = 17 + + MessageTypeThreadCreated MessageType = 18 + MessageTypeReply MessageType = 19 + MessageTypeApplicationCommand MessageType = 20 + MessageTypeThreadStarterMessage MessageType = 21 + MessageTypeGuildInviteReminder MessageType = 22 +) + +// https://discord.com/developers/docs/resources/channel#message-object +type Message struct { + ID string `json:"id"` + ChannelID string `json:"channel_id"` + Content string `json:"content"` + Author *User `json:"author"` // note that this may not be an actual valid user (see the docs) + // TODO: Author info + // TODO: Timestamp parsing, yay + Type MessageType `json:"type"` + + Attachments []Attachment `json:"attachments"` + + originalMap map[string]interface{} +} + +func MessageFromMap(m interface{}) Message { + /* + Some gateway events, like MESSAGE_UPDATE, do not contain the + entire message body. So we need to be defensive on all fields here, + except the most basic identifying information. + */ + + mmap := m.(map[string]interface{}) + msg := Message{ + ID: mmap["id"].(string), + ChannelID: mmap["channel_id"].(string), + Content: maybeString(mmap, "content"), + Type: MessageType(maybeInt(mmap, "type")), + + originalMap: mmap, + } + + if author, ok := mmap["author"]; ok { + u := UserFromMap(author) + msg.Author = &u + } + + if iattachments, ok := mmap["attachments"]; ok { + attachments := iattachments.([]interface{}) + for _, iattachment := range attachments { + msg.Attachments = append(msg.Attachments, AttachmentFromMap(iattachment)) + } + } + + return msg +} + +// https://discord.com/developers/docs/resources/user#user-object +type User struct { + ID string `json:"id"` + Username string `json:"username"` + Discriminator string `json:"discriminator"` + IsBot bool `json:"bot"` +} + +func UserFromMap(m interface{}) User { + mmap := m.(map[string]interface{}) + + u := User{ + ID: mmap["id"].(string), + Username: mmap["username"].(string), + Discriminator: mmap["discriminator"].(string), + } + + if isBot, ok := mmap["bot"]; ok { + u.IsBot = isBot.(bool) + } + + return u +} + +type Attachment struct { + ID string `json:"id"` + Filename string `json:"filename"` + ContentType string `json:"content_type"` + Size int `json:"size"` + Url string `json:"url"` + ProxyUrl string `json:"proxy_url"` + Height *int `json:"height"` + Width *int `json:"width"` +} + +func AttachmentFromMap(m interface{}) Attachment { + mmap := m.(map[string]interface{}) + a := Attachment{ + ID: mmap["id"].(string), + Filename: mmap["filename"].(string), + ContentType: maybeString(mmap, "content_type"), + Size: int(mmap["size"].(float64)), + Url: mmap["url"].(string), + ProxyUrl: mmap["proxy_url"].(string), + Height: maybeIntP(mmap, "height"), + Width: maybeIntP(mmap, "width"), + } + + return a +} + +func maybeString(m map[string]interface{}, k string) string { + val, ok := m[k] + if !ok { + return "" + } + return val.(string) +} + +func maybeInt(m map[string]interface{}, k string) int { + val, ok := m[k] + if !ok { + return 0 + } + return int(val.(float64)) +} + +func maybeIntP(m map[string]interface{}, k string) *int { + val, ok := m[k] + if !ok { + return nil + } + intval := int(val.(float64)) + return &intval +} diff --git a/src/discord/ratelimiting.go b/src/discord/ratelimiting.go new file mode 100644 index 0000000..8d74e1d --- /dev/null +++ b/src/discord/ratelimiting.go @@ -0,0 +1,284 @@ +package discord + +import ( + "context" + "errors" + "math" + "net/http" + "strconv" + "sync" + "time" + + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/utils" +) + +var limiterLog = logging.GlobalLogger().With(). + Str("module", "discord"). + Str("discord actor", "rate limiter"). + Logger() + +var buckets sync.Map // map[route name]bucket name +var rateLimiters sync.Map // map[bucket name]*restRateLimiter +var limiterInitMutex sync.Mutex + +type restRateLimiter struct { + requests chan struct{} + refills chan rateLimiterRefill +} + +type rateLimiterRefill struct { + resetAfter time.Duration + maxRequests int +} + +/* +Whenever we send a request, we must sleep until this time +(if it is in the future, of course). This is a quick and +dirty way to pause all sending in case of a global rate +limit. + +I could put a mutex on this but I don't think it's actually +a problem to have race conditions here. Just set it when +you get throttled. EZ. +*/ +var globalRateLimitTime time.Time + +type rateLimitHeaders struct { + Bucket string + Limit int + Remaining int + ResetAfter time.Duration +} + +func parseRateLimitHeaders(header http.Header) (rateLimitHeaders, bool) { + var err error + + bucket := header.Get("X-RateLimit-Bucket") + var limit int + var remaining int + var resetAfter time.Duration + + limitStr := header.Get("X-RateLimit-Limit") + if limitStr != "" { + limit, err = strconv.Atoi(limitStr) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", limitStr). + Msg("failed to parse X-RateLimit-Limit header") + return rateLimitHeaders{}, false + } + } + + remainingStr := header.Get("X-RateLimit-Remaining") + if remainingStr != "" { + remaining, err = strconv.Atoi(remainingStr) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", remainingStr). + Msg("failed to parse X-RateLimit-Remaining header") + return rateLimitHeaders{}, false + } + } + + resetAfterStr := header.Get("X-RateLimit-Reset-After") + if resetAfterStr != "" { + resetAfterSeconds, err := strconv.ParseFloat(resetAfterStr, 64) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", resetAfterStr). + Msg("failed to parse X-RateLimit-Reset-After header") + return rateLimitHeaders{}, false + } + resetAfter = time.Duration(math.Ceil(resetAfterSeconds)) * time.Second + } + + return rateLimitHeaders{ + Bucket: bucket, + Limit: limit, + Remaining: remaining, + ResetAfter: resetAfter, + }, true +} + +func createLimiter(headers rateLimitHeaders, routeName string) { + limiterInitMutex.Lock() + defer limiterInitMutex.Unlock() + + buckets.Store(routeName, headers.Bucket) + ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{ + requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts + refills: make(chan rateLimiterRefill), + }) + if !loaded { + limiter := ilimiter.(*restRateLimiter) + + log := limiterLog.With().Str("bucket", headers.Bucket).Logger() + + prefillloop: + // Pre-fill the limiter with remaining requests + for i := 0; i < headers.Remaining; i++ { + select { + case limiter.requests <- struct{}{}: + default: + log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + break prefillloop + } + } + + /* + Start the refiller for this bucket. It waits for a request to tell + it when to next reset the rate limit, and how full to fill the bucket. + It then sleeps and refills the bucket, just like it should :) + */ + go func() { + for { + // Wake up on the first request after refilling + refill := <-limiter.refills + + // Sleep for the remainder of the bucket's time + time.Sleep(refill.resetAfter) + + drainloop: + // drain the bucket + for { + select { + case <-limiter.requests: + default: + break drainloop + } + } + + refillloop: + // refill it with the max number of requests + for i := 0; i < refill.maxRequests; i++ { + select { + case limiter.requests <- struct{}{}: + default: + log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + break refillloop + } + } + + // And then we wait again to hear about our next + // bucket's worth of requests. + } + }() + + // Tell the refiller about its first refill + limiter.refills <- rateLimiterRefill{ + resetAfter: headers.ResetAfter, + maxRequests: headers.Limit, + } + } +} + +func (l *restRateLimiter) update(headers rateLimitHeaders) { + refill := rateLimiterRefill{ + resetAfter: headers.ResetAfter, + maxRequests: headers.Limit, + } + + /* + Tell the refiller about this request. If the refiller is already + busy sleeping, this will have no effect, which is what we want. + (It's already sleeping for as long as it needs to.) + */ + select { + case l.refills <- refill: + default: + } +} + +func doWithRateLimiting(ctx context.Context, routeName string, getReq func(ctx context.Context) *http.Request) (*http.Response, error) { + var bucket string + ibucket, ok := buckets.Load(routeName) + if ok { + bucket = ibucket.(string) + } + + for { + var limiter *restRateLimiter + if bucket != "" { + ilimiter, ok := rateLimiters.Load(bucket) + if ok { + limiter = ilimiter.(*restRateLimiter) + } + } + + if globalRateLimitTime.After(time.Now()) { + // oh boy, global rate limit, pause until the coast is clear + err := utils.SleepContext(ctx, globalRateLimitTime.Sub(time.Now())+1*time.Second) + if err != nil { + return nil, err + } + } + + if limiter != nil { + select { + case <-limiter.requests: + case <-ctx.Done(): + return nil, errors.New("request interrupted during rate limiting") + } + } + + res, err := httpClient.Do(getReq(ctx)) + if err != nil { + return nil, err + } + + headers, headersOk := parseRateLimitHeaders(res.Header) + if headersOk { + if limiter == nil || headers.Bucket != bucket { + createLimiter(headers, routeName) + } else { + limiter.update(headers) + } + } + + if res.StatusCode == 429 { + if res.Header.Get("X-RateLimit-Global") != "" { + // globally rate limited + logging.ExtractLogger(ctx).Warn().Msg("got globally rate limited by Discord") + retryAfter, err := strconv.Atoi(res.Header.Get("Retry-After")) + if err == nil { + globalRateLimitTime = time.Now().Add(time.Duration(retryAfter) * time.Second) + } else { + // well this is bad, just sleep for 60 seconds and pray that it's long enough + logging.ExtractLogger(ctx).Warn(). + Err(err). + Msg("got globally rate limited but couldn't determine how long to wait") + globalRateLimitTime = time.Now().Add(60 * time.Second) + } + } else { + // locally rate limited + + /* + Despite our best efforts, we ended up rate limited anyway. + Simply wait the amount of time Discord asks, and then try + again. On the next go-around, hopefully we'll either succeed + or have a rate limiter initialized and ready to go. + */ + logging.ExtractLogger(ctx).Warn().Msg("got rate limited by Discord") + if headersOk { + err := utils.SleepContext(ctx, headers.ResetAfter) + if err != nil { + return nil, err + } + } else { + logging.ExtractLogger(ctx).Warn().Msg("got rate limited, but didn't have the headers??") + err := utils.SleepContext(ctx, 1*time.Second) + if err != nil { + return nil, err + } + } + } + continue + } + + return res, nil + } +} diff --git a/src/discord/rest.go b/src/discord/rest.go new file mode 100644 index 0000000..506c785 --- /dev/null +++ b/src/discord/rest.go @@ -0,0 +1,182 @@ +package discord + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httputil" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/oops" +) + +const ( + BotName = "HandmadeNetwork" + BaseUrl = "https://discord.com/api/v9" + + UserAgentURL = "https://handmade.network/" + UserAgentVersion = "1.0" +) + +var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersion) + +var httpClient = &http.Client{} + +func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request { + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewBuffer(body) + } + + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader) + if err != nil { + panic(err) + } + req.Header.Add("Authorization", fmt.Sprintf("Bot %s", config.Config.Discord.BotToken)) + req.Header.Add("User-Agent", UserAgent) + + return req +} + +type GetGatewayBotResponse struct { + URL string `json:"url"` + // We don't care about shards or session limit stuff; we will never hit those limits +} + +func GetGatewayBot(ctx context.Context) (*GetGatewayBotResponse, error) { + const name = "Get Gateway Bot" + + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodGet, "/gateway/bot", nil) + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != 200 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + body, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var result GetGatewayBotResponse + err = json.Unmarshal(body, &result) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord response") + } + + return &result, nil +} + +type CreateMessageRequest struct { + Content string `json:"content"` +} + +func CreateMessage(ctx context.Context, channelID string, payloadJSON string) (*Message, error) { + const name = "Create Message" + + path := fmt.Sprintf("/channels/%s/messages", channelID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodPost, path, []byte(payloadJSON)) + req.Header.Add("Content-Type", "application/json") + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + // Maybe in the future we could more nicely handle errors like "bad channel", + // but honestly what are the odds that we mess that up... + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var msg Message + err = json.Unmarshal(bodyBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord message") + } + + return &msg, nil +} + +func DeleteMessage(ctx context.Context, channelID string, messageID string) error { + const name = "Delete Message" + + path := fmt.Sprintf("/channels/%s/messages/%s", channelID, messageID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodDelete, path, nil) + }) + if err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + logErrorResponse(ctx, name, res, "") + return oops.New(nil, "got unexpected status code when deleting message") + } + + return nil +} + +func CreateDM(ctx context.Context, recipientID string) (*Channel, error) { + const name = "Create DM" + + path := "/users/@me/channels" + body := []byte(fmt.Sprintf(`{"recipient_id":"%s"}`, recipientID)) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodPost, path, body) + req.Header.Add("Content-Type", "application/json") + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var channel Channel + err = json.Unmarshal(bodyBytes, &channel) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord channel") + } + + return &channel, nil +} + +func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { + dump, err := httputil.DumpResponse(res, true) + if err != nil { + panic(err) + } + + logging.ExtractLogger(ctx).Error().Str("name", name).Msg(msg) + fmt.Println(string(dump)) +} diff --git a/src/discord/showcase.go b/src/discord/showcase.go new file mode 100644 index 0000000..6dacb16 --- /dev/null +++ b/src/discord/showcase.go @@ -0,0 +1,115 @@ +package discord + +import ( + "context" + "net/url" + "regexp" + "strings" + + "git.handmade.network/hmn/hmn/src/oops" +) + +var reDiscordMessageLink = regexp.MustCompile(`https?://.+?(\s|$)`) + +func (bot *discordBotInstance) processShowcaseMsg(ctx context.Context, msg *Message) error { + switch msg.Type { + case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand: + default: + return nil + } + + hasGoodContent := true + if originalMessageHasField(msg, "content") && !messageHasLinks(msg.Content) { + hasGoodContent = false + } + + hasGoodAttachments := true + if originalMessageHasField(msg, "attachments") && len(msg.Attachments) == 0 { + hasGoodAttachments = false + } + + if !hasGoodContent && !hasGoodAttachments { + err := DeleteMessage(ctx, msg.ChannelID, msg.ID) + if err != nil { + return oops.New(err, "failed to delete message") + } + + if msg.Author != nil && !msg.Author.IsBot { + channel, err := CreateDM(ctx, msg.Author.ID) + if err != nil { + return oops.New(err, "failed to create DM channel") + } + + err = SendMessages(ctx, bot.dbConn, MessageToSend{ + ChannelID: channel.ID, + Req: CreateMessageRequest{ + Content: "Posts in #project-showcase are required to have either an image/video or a link. Discuss showcase content in #projects.", + }, + }) + if err != nil { + return oops.New(err, "failed to send showcase warning message") + } + } + } + + return nil +} + +func (bot *discordBotInstance) processLibraryMsg(ctx context.Context, msg *Message) error { + switch msg.Type { + case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand: + default: + return nil + } + + if !originalMessageHasField(msg, "content") { + return nil + } + + if !messageHasLinks(msg.Content) { + err := DeleteMessage(ctx, msg.ChannelID, msg.ID) + if err != nil { + return oops.New(err, "failed to delete message") + } + + if msg.Author != nil && !msg.Author.IsBot { + channel, err := CreateDM(ctx, msg.Author.ID) + if err != nil { + return oops.New(err, "failed to create DM channel") + } + + err = SendMessages(ctx, bot.dbConn, MessageToSend{ + ChannelID: channel.ID, + Req: CreateMessageRequest{ + Content: "Posts in #the-library are required to have a link. Discuss library content in other relevant channels.", + }, + }) + if err != nil { + return oops.New(err, "failed to send showcase warning message") + } + } + } + + return nil +} + +func messageHasLinks(content string) bool { + links := reDiscordMessageLink.FindAllString(content, -1) + for _, link := range links { + _, err := url.Parse(strings.TrimSpace(link)) + if err == nil { + return true + } + } + + return false +} + +func originalMessageHasField(msg *Message, field string) bool { + if msg.originalMap == nil { + return false + } + + _, ok := msg.originalMap[field] + return ok +} diff --git a/src/logging/logging.go b/src/logging/logging.go index 8b1ed9b..243765f 100644 --- a/src/logging/logging.go +++ b/src/logging/logging.go @@ -1,6 +1,7 @@ package logging import ( + "context" "encoding/json" "os" "sort" @@ -16,7 +17,7 @@ import ( func init() { zerolog.ErrorStackMarshaler = oops.ZerologStackMarshaler - log.Logger = log.Output(NewPrettyZerologWriter()) + log.Logger = log.Output(NewPrettyZerologWriter()).With().Stack().Logger() zerolog.SetGlobalLevel(config.Config.LogLevel) } @@ -203,3 +204,17 @@ func LogPanicValue(logger *zerolog.Logger, val interface{}, msg string) { Msg(msg) } } + +const LoggerContextKey = "logger" + +func AttachLoggerToContext(logger *zerolog.Logger, ctx context.Context) context.Context { + return context.WithValue(ctx, LoggerContextKey, logger) +} + +func ExtractLogger(ctx context.Context) *zerolog.Logger { + ilogger := ctx.Value(LoggerContextKey) + if ilogger == nil { + return GlobalLogger() + } + return ilogger.(*zerolog.Logger) +} diff --git a/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go b/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go new file mode 100644 index 0000000..272cbd2 --- /dev/null +++ b/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go @@ -0,0 +1,61 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "git.handmade.network/hmn/hmn/src/oops" + "github.com/jackc/pgx/v4" +) + +func init() { + registerMigration(AddDiscordBotTables{}) +} + +type AddDiscordBotTables struct{} + +func (m AddDiscordBotTables) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2021, 8, 7, 18, 53, 30, 0, time.UTC)) +} + +func (m AddDiscordBotTables) Name() string { + return "AddDiscordBotTables" +} + +func (m AddDiscordBotTables) Description() string { + return "Add tables for Discord bot sessions and messages" +} + +func (m AddDiscordBotTables) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + CREATE TABLE discord_session ( + pk INT NOT NULL DEFAULT 1337 PRIMARY KEY, -- this should always be set to 1337 to ensure that we only have one row :) + session_id VARCHAR(255) NOT NULL, + sequence_number INT NOT NULL, + + CONSTRAINT only_one_session CHECK (pk = 1337) + ); + `) + if err != nil { + return oops.New(err, "failed to create discord session table") + } + + _, err = tx.Exec(ctx, ` + CREATE TABLE discord_outgoingmessages ( + id SERIAL NOT NULL PRIMARY KEY, + channel_id VARCHAR(64) NOT NULL, + payload_json TEXT NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL + ); + `) + if err != nil { + return oops.New(err, "failed to create discord outgoing messages table") + } + + return nil +} + +func (m AddDiscordBotTables) Down(ctx context.Context, tx pgx.Tx) error { + panic("Implement me") +} diff --git a/src/models/discord.go b/src/models/discord.go index 52c0f4f..1843341 100644 --- a/src/models/discord.go +++ b/src/models/discord.go @@ -4,6 +4,18 @@ import ( "time" ) +type DiscordSession struct { + ID string `db:"session_id"` + SequenceNumber int `db:"sequence_number"` +} + +type DiscordOutgoingMessage struct { + ID int `db:"id"` + ChannelID string `db:"channel_id"` + PayloadJSON string `db:"payload_json"` + ExpiresAt time.Time `db:"expires_at"` +} + type DiscordMessage struct { ID string `db:"id"` ChannelID string `db:"channel_id"` diff --git a/src/oops/oops.go b/src/oops/oops.go index 53d8d0b..e92a087 100644 --- a/src/oops/oops.go +++ b/src/oops/oops.go @@ -14,7 +14,11 @@ type Error struct { } func (e *Error) Error() string { - return fmt.Sprintf("%s: %v", e.Message, e.Wrapped) + if e.Wrapped == nil { + return e.Message + } else { + return fmt.Sprintf("%s: %v", e.Message, e.Wrapped) + } } func (e *Error) Unwrap() error { diff --git a/src/utils/utils.go b/src/utils/utils.go index 6fb9926..394fcd1 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -1,5 +1,14 @@ package utils +import ( + "context" + "errors" + "fmt" + "time" + + "git.handmade.network/hmn/hmn/src/oops" +) + func IntMin(a, b int) int { if a < b { return a @@ -17,3 +26,38 @@ func IntMax(a, b int) int { func IntClamp(min, t, max int) int { return IntMax(min, IntMin(t, max)) } + +/* +Recover a panic and convert it to a returned error. Call it like so: + + func MyFunc() (err error) { + defer utils.RecoverPanicAsError(&err) + } + +If an error was already present, the panicked error will take precedence. Unfortunately there's +no good way to include both errors because you can't really have two chains of errors and still +play nice with the standard library's Unwrap behavior. But most of the time this shouldn't be an +issue, since the panic will probably occur before a meaningful error value was set. +*/ +func RecoverPanicAsError(err *error) { + if r := recover(); r != nil { + var recoveredErr error + if rerr, ok := r.(error); ok { + recoveredErr = rerr + } else { + recoveredErr = fmt.Errorf("panic with value: %v", r) + } + *err = oops.New(recoveredErr, "panic recovered as error") + } +} + +var ErrSleepInterrupted = errors.New("sleep interrupted by context cancellation") + +func SleepContext(ctx context.Context, d time.Duration) error { + select { + case <-ctx.Done(): + return ErrSleepInterrupted + case <-time.After(d): + return nil + } +} diff --git a/src/website/website.go b/src/website/website.go index c3c22bf..79f3dc2 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -13,6 +13,7 @@ import ( "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/perf" "git.handmade.network/hmn/hmn/src/templates" @@ -42,6 +43,7 @@ var WebsiteCommand = &cobra.Command{ auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn), perfCollector.Done, + discord.RunDiscordBot(backgroundJobContext, conn), ) signals := make(chan os.Signal, 1) @@ -52,7 +54,9 @@ var WebsiteCommand = &cobra.Command{ go func() { timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + logging.Info().Msg("shutting down web server") server.Shutdown(timeout) + logging.Info().Msg("cancelling background jobs") cancelBackgroundJobs() }()