diff --git a/go.mod b/go.mod index f9b87fe9..16fef9d0 100644 --- a/go.mod +++ b/go.mod @@ -10,11 +10,13 @@ require ( github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 github.com/go-stack/stack v1.8.0 github.com/google/uuid v1.2.0 + github.com/gorilla/websocket v1.4.2 github.com/huandu/xstrings v1.3.2 // indirect github.com/imdario/mergo v0.3.12 // indirect github.com/jackc/pgconn v1.8.0 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.10.1 + github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.1.1 // indirect github.com/rs/zerolog v1.21.0 github.com/spf13/cobra v1.1.3 diff --git a/go.sum b/go.sum index 241c838c..474fba92 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,7 @@ github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= @@ -185,6 +186,8 @@ github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv github.com/jackc/puddle v1.1.3 h1:JnPg/5Q9xVJGfjsO5CPUOjnJps1JaRUm8I9FXVCFK94= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= diff --git a/src/config/types.go b/src/config/types.go index f76a88b8..19485b98 100644 --- a/src/config/types.go +++ b/src/config/types.go @@ -25,6 +25,7 @@ type HMNConfig struct { Auth AuthConfig Email EmailConfig DigitalOcean DigitalOceanConfig + Discord DiscordConfig } type PostgresConfig struct { @@ -62,6 +63,14 @@ type EmailConfig struct { OverrideRecipientEmail string } +type DiscordConfig struct { + BotToken string + BotUserID string + + ShowcaseChannelID string + LibraryChannelID string +} + func (info PostgresConfig) DSN() string { return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName) } diff --git a/src/discord/gateway.go b/src/discord/gateway.go new file mode 100644 index 00000000..7cdfdac4 --- /dev/null +++ b/src/discord/gateway.go @@ -0,0 +1,643 @@ +package discord + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net" + "runtime" + "sync" + "time" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/models" + "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/utils" + "github.com/gorilla/websocket" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/jpillora/backoff" +) + +func RunDiscordBot(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{} { + log := logging.ExtractLogger(ctx).With().Str("module", "discord").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + if config.Config.Discord.BotToken == "" { + log.Warn().Msg("No Discord bot token was provided, so the Discord bot cannot run.") + done := make(chan struct{}, 1) + done <- struct{}{} + return done + } + + done := make(chan struct{}) + go func() { + defer func() { + log.Debug().Msg("shut down Discord bot") + done <- struct{}{} + }() + + boff := backoff.Backoff{ + Min: 1 * time.Second, + Max: 5 * time.Minute, + } + + for { + select { + case <-ctx.Done(): + return + default: + } + + func() { + log.Info().Msg("Connecting to the Discord gateway") + bot := newBotInstance(dbConn) + err := bot.Run(ctx) + if err != nil { + dur := boff.Duration() + log.Error(). + Err(err). + Dur("retrying after", dur). + Msg("failed to run Discord bot") + + timer := time.NewTimer(dur) + select { + case <-ctx.Done(): + case <-timer.C: + } + + return + } + + select { + case <-ctx.Done(): + return + default: + } + + // This delay satisfies the 1 to 5 second delay Discord + // wants on reconnects, and seems fine to do every time. + delay := time.Duration(int64(time.Second) + rand.Int63n(int64(time.Second*4))) + log.Info().Dur("delay", delay).Msg("Reconnecting to Discord") + time.Sleep(delay) + + boff.Reset() + }() + } + }() + return done +} + +var outgoingMessagesReady = make(chan struct{}, 1) + +type discordBotInstance struct { + conn *websocket.Conn + dbConn *pgxpool.Pool + + heartbeatIntervalMs int + forceHeartbeat chan struct{} + + /* + Every time we send a heartbeat, we set this variable to false. + Whenever we ack a heartbeat, we set this variable to true. + If we try to send a heartbeat but the previous one was not + acked, then we close the connection and try to reconnect. + */ + didAckHeartbeat bool + + /* + All goroutines should call this when they exit, to ensure that + the other goroutines shut down as well. + */ + cancel context.CancelFunc + wg sync.WaitGroup +} + +func newBotInstance(dbConn *pgxpool.Pool) *discordBotInstance { + return &discordBotInstance{ + dbConn: dbConn, + forceHeartbeat: make(chan struct{}), + didAckHeartbeat: true, + } +} + +/* +Runs a bot instance to completion. It will start up a gateway connection and return when the +connection is closed. It only returns an error when something unexpected occurs; if so, you should +do exponential backoff before reconnecting. Otherwise you can reconnect right away. +*/ +func (bot *discordBotInstance) Run(ctx context.Context) (err error) { + defer utils.RecoverPanicAsError(&err) + + ctx, bot.cancel = context.WithCancel(ctx) + defer bot.cancel() + + err = bot.connect(ctx) + if err != nil { + return oops.New(err, "failed to connect to Discord gateway") + } + defer bot.conn.Close() + + bot.wg.Add(1) + go bot.doSender(ctx) + + // Wait for child goroutines to exit (they will do so when context is canceled). This ensures + // that nothing is in the middle of sending. Then close the connection, so that this goroutine + // can finish as well. + go func() { + bot.wg.Wait() + bot.conn.Close() + }() + + for { + msg, err := bot.receiveGatewayMessage(ctx) + if err != nil { + // TODO: Are there other kinds of connection close events that we need to handle? Probably? + if errors.Is(err, net.ErrClosed) { + // If the connection is closed, that's our cue to shut down the bot. Any errors + // related to the closure will have been logged elsewhere anyway. + return nil + } else { + return oops.New(err, "failed to receive message from the gateway") + } + } + + // Update the sequence number in the db + if msg.SequenceNumber != nil { + _, err = bot.dbConn.Exec(ctx, `UPDATE discord_session SET sequence_number = $1`, *msg.SequenceNumber) + if err != nil { + return oops.New(err, "failed to save latest sequence number") + } + } + + switch msg.Opcode { + case OpcodeDispatch: + // Just a normal event + err := bot.processEventMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process gateway event") + } + case OpcodeHeartbeat: + bot.forceHeartbeat <- struct{}{} + case OpcodeHeartbeatACK: + bot.didAckHeartbeat = true + case OpcodeReconnect: + logging.ExtractLogger(ctx).Info().Msg("Discord asked us to reconnect to the gateway") + return nil + case OpcodeInvalidSession: + // We tried to resume but the session was invalid. + // Delete the session and reconnect from scratch again. + _, err := bot.dbConn.Exec(ctx, `DELETE FROM discord_session`) + if err != nil { + return oops.New(err, "failed to delete invalid session") + } + return nil + } + } +} + +/* +The connection process in short: +- Gateway sends Hello, asking the client to heartbeat on some interval +- Client sends Identify and starts heartbeat process +- Gateway sends Ready, client is now connected to gateway + +Or, if we have an existing session: +- Gateway sends Hello, asking the client to heartbeat on some interval +- Client sends Resume and starts heartbeat process +- Gateway sends all missed events followed by a RESUMED event, or an Invalid Session if the + session is ded + +Note that some events probably won't be received until the Guild Create message is received. + +It's a little annoying to handle resumes since we want to handle the missed messages as if we were +receiving them in real time. But we're kind of in a different state from from when we're normally +receiving messages, because we are expecting a RESUMED event at the end, and the first message we +receive might be an Invalid Session. So, unfortunately, we just have to handle the Invalid Session +and RESUMED messages in our main message receiving loop instead of here. + +(Discord could have prevented this if they send a "Resume ACK" message before replaying events. +That way, we could receive exactly one message after sending Resume, either a Resume ACK or an +Invalid Session, and from there it would be crystal clear what to do. Alas!) +*/ +func (bot *discordBotInstance) connect(ctx context.Context) error { + res, err := GetGatewayBot(ctx) + if err != nil { + return oops.New(err, "failed to get gateway URL") + } + + conn, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s/?v=9&encoding=json", res.URL), nil) + if err != nil { + return oops.New(err, "failed to connect to the Discord gateway") + } + bot.conn = conn + + helloMessage, err := bot.receiveGatewayMessage(ctx) + if err != nil { + return oops.New(err, "failed to read Hello message") + } + if helloMessage.Opcode != OpcodeHello { + return oops.New(nil, "expected a Hello (opcode %d), but got opcode %d", OpcodeHello, helloMessage.Opcode) + } + helloData := HelloFromMap(helloMessage.Data) + bot.heartbeatIntervalMs = helloData.HeartbeatIntervalMs + + // Now that the gateway has said hello, we need to establish a new session, either resuming + // an old one or starting a new one. + + shouldResume := true + isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) + if err != nil { + if errors.Is(err, db.ErrNoMatchingRows) { + // No session yet! Just identify and get on with it + shouldResume = false + } else { + return oops.New(err, "failed to get current session from database") + } + } + + if shouldResume { + // Reconnect to the previous session + session := isession.(*models.DiscordSession) + + err := bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeResume, + Data: Resume{ + Token: config.Config.Discord.BotToken, + SessionID: session.ID, + SequenceNumber: session.SequenceNumber, + }, + }) + if err != nil { + return oops.New(err, "failed to send Resume message") + } + + return nil + } else { + // Start a new session + err := bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeIdentify, + Data: Identify{ + Token: config.Config.Discord.BotToken, + Properties: IdentifyConnectionProperties{ + OS: runtime.GOOS, + Browser: BotName, + Device: BotName, + }, + Intents: IntentGuilds | IntentGuildMessages, + }, + }) + if err != nil { + return oops.New(err, "failed to send Identify message") + } + + readyMessage, err := bot.receiveGatewayMessage(ctx) + if err != nil { + return oops.New(err, "failed to read Ready message") + } + if readyMessage.Opcode != OpcodeDispatch { + return oops.New(err, "expected a READY event, but got a message with opcode %d", readyMessage.Opcode) + } + if *readyMessage.EventName != "READY" { + return oops.New(err, "expected a READY event, but got a %s event", *readyMessage.EventName) + } + readyData := ReadyFromMap(readyMessage.Data) + + _, err = bot.dbConn.Exec(ctx, + ` + INSERT INTO discord_session (session_id, sequence_number) + VALUES ($1, $2) + ON CONFLICT (pk) DO UPDATE + SET session_id = $1, sequence_number = $2 + `, + readyData.SessionID, + *readyMessage.SequenceNumber, + ) + if err != nil { + return oops.New(err, "failed to save new bot session in the database") + } + } + + return nil +} + +/* +Sends outgoing gateway messages and channel messages. Handles heartbeats. This function should be +run as its own goroutine. +*/ +func (bot *discordBotInstance) doSender(ctx context.Context) { + defer bot.wg.Done() + defer bot.cancel() + + log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + defer log.Info().Msg("shutting down Discord sender") + + /* + The first heartbeat is supposed to occur at a random time within + the first heartbeat interval. + + https://discord.com/developers/docs/topics/gateway#heartbeating + */ + dur := time.Duration(bot.heartbeatIntervalMs) * time.Millisecond + firstDelay := time.NewTimer(time.Duration(rand.Int63n(int64(dur)))) + heartbeatTicker := &time.Ticker{} // this will start never ticking, and get initialized after the first heartbeat + + // Returns false if the heartbeat failed + sendHeartbeat := func() bool { + if !bot.didAckHeartbeat { + log.Error().Msg("did not receive a heartbeat ACK in between heartbeats") + return false + } + bot.didAckHeartbeat = false + + latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) + if err != nil { + log.Error().Err(err).Msg("failed to fetch latest sequence number from the db") + return false + } + + err = bot.sendGatewayMessage(ctx, GatewayMessage{ + Opcode: OpcodeHeartbeat, + Data: latestSequenceNumber, + }) + if err != nil { + log.Error().Err(err).Msg("failed to send heartbeat") + return false + } + + return true + } + + /* + Start a goroutine to fetch outgoing messages from the db. We do this in a separate goroutine + to ensure that issues talking to the database don't prevent us from sending heartbeats. + */ + messages := make(chan *models.DiscordOutgoingMessage) + bot.wg.Add(1) + go func(ctx context.Context) { + defer bot.wg.Done() + defer bot.cancel() + + log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "sender db reader").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + defer log.Info().Msg("stopping db reader") + + // We will poll the database just in case the notification mechanism doesn't work. + ticker := time.NewTicker(time.Second * 5) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + case <-outgoingMessagesReady: + } + + func() { + tx, err := bot.dbConn.Begin(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to start transaction") + return + } + defer tx.Rollback(ctx) + + itMessages, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` + SELECT $columns + FROM discord_outgoingmessages + ORDER BY id ASC + `) + if err != nil { + log.Error().Err(err).Msg("failed to fetch outgoing Discord messages") + return + } + + msgs := itMessages.ToSlice() + for _, imsg := range msgs { + msg := imsg.(*models.DiscordOutgoingMessage) + if time.Now().After(msg.ExpiresAt) { + continue + } + messages <- msg + } + + /* + NOTE(ben): Doing this in a transaction means that we will only delete the + messages that we originally fetched. At least, as long as the database's + isolation level is Read Committed, which is the default. + + https://www.postgresql.org/docs/current/transaction-iso.html + */ + _, err = tx.Exec(ctx, `DELETE FROM discord_outgoingmessages`) + if err != nil { + log.Error().Err(err).Msg("failed to delete outgoing messages") + return + } + + err = tx.Commit(ctx) + if err != nil { + log.Error().Err(err).Msg("failed to read and delete outgoing messages") + return + } + + if len(msgs) > 0 { + log.Debug().Int("num messages", len(msgs)).Msg("Sent and deleted outgoing messages") + } + }() + } + }(ctx) + + /* + Whenever we want to send a gateway message, we must receive a value from + this channel first. A goroutine continuously fills the channel at a rate + that respects Discord's gateway rate limit. + + Don't use this for heartbeats; heartbeats should go out immediately. + Don't forget that the server can request a heartbeat at any time. + + See the docs for more details. The capacity of this channel is chosen to + always leave us overhead for heartbeats and other shenanigans. + + https://discord.com/developers/docs/topics/gateway#rate-limiting + */ + rateLimiter := make(chan struct{}, 100) + go func() { + for { + rateLimiter <- struct{}{} + time.Sleep(500 * time.Millisecond) + } + }() + /* + NOTE(ben): This rate limiter is actually not used right now + because we're not actually sending any meaningful gateway + messages. But in the future, if we end up sending presence + updates or other gateway commands, we need to make sure to + put this limiter on all of those outgoing commands. + */ + + for { + select { + case <-ctx.Done(): + return + case <-firstDelay.C: + if ok := sendHeartbeat(); !ok { + return + } + heartbeatTicker = time.NewTicker(dur) + case <-heartbeatTicker.C: + if ok := sendHeartbeat(); !ok { + return + } + case <-bot.forceHeartbeat: + if ok := sendHeartbeat(); !ok { + return + } + heartbeatTicker.Reset(dur) + case msg := <-messages: + _, err := CreateMessage(ctx, msg.ChannelID, msg.PayloadJSON) + if err != nil { + log.Error().Err(err).Msg("failed to send Discord message") + } + } + } +} + +func (bot *discordBotInstance) receiveGatewayMessage(ctx context.Context) (*GatewayMessage, error) { + _, msgBytes, err := bot.conn.ReadMessage() + if err != nil { + return nil, err + } + + var msg GatewayMessage + err = json.Unmarshal(msgBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord gateway message") + } + + logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("received gateway message") + + return &msg, nil +} + +func (bot *discordBotInstance) sendGatewayMessage(ctx context.Context, msg GatewayMessage) error { + logging.ExtractLogger(ctx).Debug().Interface("msg", msg).Msg("sending gateway message") + return bot.conn.WriteMessage(websocket.TextMessage, msg.ToJSON()) +} + +/* +Processes a single event message from Discord. If this returns an error, it means something has +really gone wrong, bad enough that the connection should be shut down. Otherwise it will just log +any errors that occur. +*/ +func (bot *discordBotInstance) processEventMsg(ctx context.Context, msg *GatewayMessage) error { + if msg.Opcode != OpcodeDispatch { + panic(fmt.Sprintf("processEventMsg must only be used on Dispatch messages (opcode %d). Validate this before you call this function.", OpcodeDispatch)) + } + + switch *msg.EventName { + case "RESUMED": + // Nothing to do, but at least we can log something + logging.ExtractLogger(ctx).Info().Msg("Finished resuming gateway session") + case "MESSAGE_CREATE": + newMessage := MessageFromMap(msg.Data) + + err := bot.messageCreateOrUpdate(ctx, &newMessage) + if err != nil { + return oops.New(err, "error on new message") + } + case "MESSAGE_UPDATE": + newMessage := MessageFromMap(msg.Data) + + err := bot.messageCreateOrUpdate(ctx, &newMessage) + if err != nil { + return oops.New(err, "error on updated message") + } + } + + return nil +} + +func (bot *discordBotInstance) messageCreateOrUpdate(ctx context.Context, msg *Message) error { + if msg.Author != nil && msg.Author.ID == config.Config.Discord.BotUserID { + // Don't process your own messages + return nil + } + + if msg.ChannelID == config.Config.Discord.ShowcaseChannelID { + err := bot.processShowcaseMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process showcase message") + } + return nil + } + + if msg.ChannelID == config.Config.Discord.LibraryChannelID { + err := bot.processLibraryMsg(ctx, msg) + if err != nil { + return oops.New(err, "failed to process library message") + } + return nil + } + + return nil +} + +type MessageToSend struct { + ChannelID string + Req CreateMessageRequest + ExpiresAt time.Time +} + +func SendMessages( + ctx context.Context, + conn *pgxpool.Pool, + msgs ...MessageToSend, +) error { + tx, err := conn.Begin(ctx) + if err != nil { + return oops.New(err, "failed to start transaction") + } + defer tx.Rollback(ctx) + + for _, msg := range msgs { + if msg.ExpiresAt.IsZero() { + msg.ExpiresAt = time.Now().Add(30 * time.Second) + } + + reqBytes, err := json.Marshal(msg.Req) + if err != nil { + return oops.New(err, "failed to marshal Discord message to JSON") + } + + _, err = tx.Exec(ctx, + ` + INSERT INTO discord_outgoingmessages (channel_id, payload_json, expires_at) + VALUES ($1, $2, $3) + `, + msg.ChannelID, + string(reqBytes), + msg.ExpiresAt, + ) + if err != nil { + return oops.New(err, "failed to save outgoing Discord message to the database") + } + } + + err = tx.Commit(ctx) + if err != nil { + return oops.New(err, "failed to commit outgoing Discord messages") + } + + // Notify the sender that messages are ready to go + select { + case outgoingMessagesReady <- struct{}{}: + default: + } + + return nil +} diff --git a/src/discord/payloads.go b/src/discord/payloads.go new file mode 100644 index 00000000..4b5870e1 --- /dev/null +++ b/src/discord/payloads.go @@ -0,0 +1,292 @@ +package discord + +import ( + "encoding/json" +) + +type Opcode int + +// https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-opcodes +// NOTE(ben): I'm not using iota because 5 is missing +const ( + OpcodeDispatch Opcode = 0 + OpcodeHeartbeat Opcode = 1 + OpcodeIdentify Opcode = 2 + OpcodePresenceUpdate Opcode = 3 + OpcodeVoiceStateUpdate Opcode = 4 + OpcodeResume Opcode = 6 + OpcodeReconnect Opcode = 7 + OpcodeRequestGuildMembers Opcode = 8 + OpcodeInvalidSession Opcode = 9 + OpcodeHello Opcode = 10 + OpcodeHeartbeatACK Opcode = 11 +) + +type Intent int + +// https://discord.com/developers/docs/topics/gateway#list-of-intents +// NOTE(ben): I'm not using iota because the opcode thing made me paranoid +const ( + IntentGuilds Intent = 1 << 0 + IntentGuildMembers Intent = 1 << 1 + IntentGuildBans Intent = 1 << 2 + IntentGuildEmojisAndStickers Intent = 1 << 3 + IntentGuildIntegrations Intent = 1 << 4 + IntentGuildWebhooks Intent = 1 << 5 + IntentGuildInvites Intent = 1 << 6 + IntentGuildVoiceStates Intent = 1 << 7 + IntentGuildPresences Intent = 1 << 8 + IntentGuildMessages Intent = 1 << 9 + IntentGuildMessageReactions Intent = 1 << 10 + IntentGuildMessageTyping Intent = 1 << 11 + IntentDirectMessages Intent = 1 << 12 + IntentDirectMessageReactions Intent = 1 << 13 + IntentDirectMessageTyping Intent = 1 << 14 +) + +type GatewayMessage struct { + Opcode Opcode `json:"op"` + Data interface{} `json:"d"` + SequenceNumber *int `json:"s,omitempty"` + EventName *string `json:"t,omitempty"` +} + +func (m *GatewayMessage) ToJSON() []byte { + mBytes, err := json.Marshal(m) + if err != nil { + panic(err) + } + + // TODO: check if the payload is too big, either here or where we actually send + // https://discord.com/developers/docs/topics/gateway#sending-payloads + + return mBytes +} + +type Hello struct { + HeartbeatIntervalMs int `json:"heartbeat_interval"` +} + +func HelloFromMap(m interface{}) Hello { + // TODO: This should probably have some error handling, right? + return Hello{ + HeartbeatIntervalMs: int(m.(map[string]interface{})["heartbeat_interval"].(float64)), + } +} + +type Identify struct { + Token string `json:"token"` + Properties IdentifyConnectionProperties `json:"properties"` + Intents Intent `json:"intents"` +} + +type IdentifyConnectionProperties struct { + OS string `json:"$os"` + Browser string `json:"$browser"` + Device string `json:"$device"` +} + +type Ready struct { + GatewayVersion int `json:"v"` + User User `json:"user"` + SessionID string `json:"session_id"` +} + +func ReadyFromMap(m interface{}) Ready { + mmap := m.(map[string]interface{}) + + return Ready{ + GatewayVersion: int(mmap["v"].(float64)), + User: UserFromMap(mmap["user"]), + SessionID: mmap["session_id"].(string), + } +} + +type Resume struct { + Token string `json:"token"` + SessionID string `json:"session_id"` + SequenceNumber int `json:"seq"` +} + +type ChannelType int + +const ( + ChannelTypeGuildext ChannelType = 0 + ChannelTypeDM ChannelType = 1 + ChannelTypeGuildVoice ChannelType = 2 + ChannelTypeGroupDM ChannelType = 3 + ChannelTypeGuildCategory ChannelType = 4 + ChannelTypeGuildNews ChannelType = 5 + ChannelTypeGuildStore ChannelType = 6 + ChannelTypeGuildNewsThread ChannelType = 10 + ChannelTypeGuildPublicThread ChannelType = 11 + ChannelTypeGuildPrivateThread ChannelType = 12 + ChannelTypeGuildStageVoice ChannelType = 13 +) + +type Channel struct { + ID string `json:"id"` + Type ChannelType `json:"type"` + GuildID string `json:"guild_id"` + Name string `json:"name"` + Receipients []User `json:"recipients"` + OwnerID User `json:"owner_id"` + ParentID *string `json:"parent_id"` +} + +type MessageType int + +const ( + MessageTypeDefault MessageType = 0 + + MessageTypeRecipientAdd MessageType = 1 + MessageTypeRecipientRemove MessageType = 2 + MessageTypeCall MessageType = 3 + + MessageTypeChannelNameChange MessageType = 4 + MessageTypeChannelIconChange MessageType = 5 + MessageTypeChannelPinnedMessage MessageType = 6 + + MessageTypeGuildMemberJoin MessageType = 7 + + MessageTypeUserPremiumGuildSubscription MessageType = 8 + MessageTypeUserPremiumGuildSubscriptionTier1 MessageType = 9 + MessageTypeUserPremiumGuildSubscriptionTier2 MessageType = 10 + MessageTypeUserPremiumGuildSubscriptionTier3 MessageType = 11 + + MessageTypeChannelFollowAdd MessageType = 12 + + MessageTypeGuildDiscoveryDisqualified MessageType = 14 + MessageTypeGuildDiscoveryRequalified MessageType = 15 + MessageTypeGuildDiscoveryGracePeriodInitialWarning MessageType = 16 + MessageTypeGuildDiscoveryGracePeriodFinalWarning MessageType = 17 + + MessageTypeThreadCreated MessageType = 18 + MessageTypeReply MessageType = 19 + MessageTypeApplicationCommand MessageType = 20 + MessageTypeThreadStarterMessage MessageType = 21 + MessageTypeGuildInviteReminder MessageType = 22 +) + +// https://discord.com/developers/docs/resources/channel#message-object +type Message struct { + ID string `json:"id"` + ChannelID string `json:"channel_id"` + Content string `json:"content"` + Author *User `json:"author"` // note that this may not be an actual valid user (see the docs) + // TODO: Author info + // TODO: Timestamp parsing, yay + Type MessageType `json:"type"` + + Attachments []Attachment `json:"attachments"` + + originalMap map[string]interface{} +} + +func MessageFromMap(m interface{}) Message { + /* + Some gateway events, like MESSAGE_UPDATE, do not contain the + entire message body. So we need to be defensive on all fields here, + except the most basic identifying information. + */ + + mmap := m.(map[string]interface{}) + msg := Message{ + ID: mmap["id"].(string), + ChannelID: mmap["channel_id"].(string), + Content: maybeString(mmap, "content"), + Type: MessageType(maybeInt(mmap, "type")), + + originalMap: mmap, + } + + if author, ok := mmap["author"]; ok { + u := UserFromMap(author) + msg.Author = &u + } + + if iattachments, ok := mmap["attachments"]; ok { + attachments := iattachments.([]interface{}) + for _, iattachment := range attachments { + msg.Attachments = append(msg.Attachments, AttachmentFromMap(iattachment)) + } + } + + return msg +} + +// https://discord.com/developers/docs/resources/user#user-object +type User struct { + ID string `json:"id"` + Username string `json:"username"` + Discriminator string `json:"discriminator"` + IsBot bool `json:"bot"` +} + +func UserFromMap(m interface{}) User { + mmap := m.(map[string]interface{}) + + u := User{ + ID: mmap["id"].(string), + Username: mmap["username"].(string), + Discriminator: mmap["discriminator"].(string), + } + + if isBot, ok := mmap["bot"]; ok { + u.IsBot = isBot.(bool) + } + + return u +} + +type Attachment struct { + ID string `json:"id"` + Filename string `json:"filename"` + ContentType string `json:"content_type"` + Size int `json:"size"` + Url string `json:"url"` + ProxyUrl string `json:"proxy_url"` + Height *int `json:"height"` + Width *int `json:"width"` +} + +func AttachmentFromMap(m interface{}) Attachment { + mmap := m.(map[string]interface{}) + a := Attachment{ + ID: mmap["id"].(string), + Filename: mmap["filename"].(string), + ContentType: maybeString(mmap, "content_type"), + Size: int(mmap["size"].(float64)), + Url: mmap["url"].(string), + ProxyUrl: mmap["proxy_url"].(string), + Height: maybeIntP(mmap, "height"), + Width: maybeIntP(mmap, "width"), + } + + return a +} + +func maybeString(m map[string]interface{}, k string) string { + val, ok := m[k] + if !ok { + return "" + } + return val.(string) +} + +func maybeInt(m map[string]interface{}, k string) int { + val, ok := m[k] + if !ok { + return 0 + } + return int(val.(float64)) +} + +func maybeIntP(m map[string]interface{}, k string) *int { + val, ok := m[k] + if !ok { + return nil + } + intval := int(val.(float64)) + return &intval +} diff --git a/src/discord/ratelimiting.go b/src/discord/ratelimiting.go new file mode 100644 index 00000000..8d74e1d9 --- /dev/null +++ b/src/discord/ratelimiting.go @@ -0,0 +1,284 @@ +package discord + +import ( + "context" + "errors" + "math" + "net/http" + "strconv" + "sync" + "time" + + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/utils" +) + +var limiterLog = logging.GlobalLogger().With(). + Str("module", "discord"). + Str("discord actor", "rate limiter"). + Logger() + +var buckets sync.Map // map[route name]bucket name +var rateLimiters sync.Map // map[bucket name]*restRateLimiter +var limiterInitMutex sync.Mutex + +type restRateLimiter struct { + requests chan struct{} + refills chan rateLimiterRefill +} + +type rateLimiterRefill struct { + resetAfter time.Duration + maxRequests int +} + +/* +Whenever we send a request, we must sleep until this time +(if it is in the future, of course). This is a quick and +dirty way to pause all sending in case of a global rate +limit. + +I could put a mutex on this but I don't think it's actually +a problem to have race conditions here. Just set it when +you get throttled. EZ. +*/ +var globalRateLimitTime time.Time + +type rateLimitHeaders struct { + Bucket string + Limit int + Remaining int + ResetAfter time.Duration +} + +func parseRateLimitHeaders(header http.Header) (rateLimitHeaders, bool) { + var err error + + bucket := header.Get("X-RateLimit-Bucket") + var limit int + var remaining int + var resetAfter time.Duration + + limitStr := header.Get("X-RateLimit-Limit") + if limitStr != "" { + limit, err = strconv.Atoi(limitStr) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", limitStr). + Msg("failed to parse X-RateLimit-Limit header") + return rateLimitHeaders{}, false + } + } + + remainingStr := header.Get("X-RateLimit-Remaining") + if remainingStr != "" { + remaining, err = strconv.Atoi(remainingStr) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", remainingStr). + Msg("failed to parse X-RateLimit-Remaining header") + return rateLimitHeaders{}, false + } + } + + resetAfterStr := header.Get("X-RateLimit-Reset-After") + if resetAfterStr != "" { + resetAfterSeconds, err := strconv.ParseFloat(resetAfterStr, 64) + if err != nil { + limiterLog.Error(). + Err(err). + Str("value", resetAfterStr). + Msg("failed to parse X-RateLimit-Reset-After header") + return rateLimitHeaders{}, false + } + resetAfter = time.Duration(math.Ceil(resetAfterSeconds)) * time.Second + } + + return rateLimitHeaders{ + Bucket: bucket, + Limit: limit, + Remaining: remaining, + ResetAfter: resetAfter, + }, true +} + +func createLimiter(headers rateLimitHeaders, routeName string) { + limiterInitMutex.Lock() + defer limiterInitMutex.Unlock() + + buckets.Store(routeName, headers.Bucket) + ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{ + requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts + refills: make(chan rateLimiterRefill), + }) + if !loaded { + limiter := ilimiter.(*restRateLimiter) + + log := limiterLog.With().Str("bucket", headers.Bucket).Logger() + + prefillloop: + // Pre-fill the limiter with remaining requests + for i := 0; i < headers.Remaining; i++ { + select { + case limiter.requests <- struct{}{}: + default: + log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + break prefillloop + } + } + + /* + Start the refiller for this bucket. It waits for a request to tell + it when to next reset the rate limit, and how full to fill the bucket. + It then sleeps and refills the bucket, just like it should :) + */ + go func() { + for { + // Wake up on the first request after refilling + refill := <-limiter.refills + + // Sleep for the remainder of the bucket's time + time.Sleep(refill.resetAfter) + + drainloop: + // drain the bucket + for { + select { + case <-limiter.requests: + default: + break drainloop + } + } + + refillloop: + // refill it with the max number of requests + for i := 0; i < refill.maxRequests; i++ { + select { + case limiter.requests <- struct{}{}: + default: + log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + break refillloop + } + } + + // And then we wait again to hear about our next + // bucket's worth of requests. + } + }() + + // Tell the refiller about its first refill + limiter.refills <- rateLimiterRefill{ + resetAfter: headers.ResetAfter, + maxRequests: headers.Limit, + } + } +} + +func (l *restRateLimiter) update(headers rateLimitHeaders) { + refill := rateLimiterRefill{ + resetAfter: headers.ResetAfter, + maxRequests: headers.Limit, + } + + /* + Tell the refiller about this request. If the refiller is already + busy sleeping, this will have no effect, which is what we want. + (It's already sleeping for as long as it needs to.) + */ + select { + case l.refills <- refill: + default: + } +} + +func doWithRateLimiting(ctx context.Context, routeName string, getReq func(ctx context.Context) *http.Request) (*http.Response, error) { + var bucket string + ibucket, ok := buckets.Load(routeName) + if ok { + bucket = ibucket.(string) + } + + for { + var limiter *restRateLimiter + if bucket != "" { + ilimiter, ok := rateLimiters.Load(bucket) + if ok { + limiter = ilimiter.(*restRateLimiter) + } + } + + if globalRateLimitTime.After(time.Now()) { + // oh boy, global rate limit, pause until the coast is clear + err := utils.SleepContext(ctx, globalRateLimitTime.Sub(time.Now())+1*time.Second) + if err != nil { + return nil, err + } + } + + if limiter != nil { + select { + case <-limiter.requests: + case <-ctx.Done(): + return nil, errors.New("request interrupted during rate limiting") + } + } + + res, err := httpClient.Do(getReq(ctx)) + if err != nil { + return nil, err + } + + headers, headersOk := parseRateLimitHeaders(res.Header) + if headersOk { + if limiter == nil || headers.Bucket != bucket { + createLimiter(headers, routeName) + } else { + limiter.update(headers) + } + } + + if res.StatusCode == 429 { + if res.Header.Get("X-RateLimit-Global") != "" { + // globally rate limited + logging.ExtractLogger(ctx).Warn().Msg("got globally rate limited by Discord") + retryAfter, err := strconv.Atoi(res.Header.Get("Retry-After")) + if err == nil { + globalRateLimitTime = time.Now().Add(time.Duration(retryAfter) * time.Second) + } else { + // well this is bad, just sleep for 60 seconds and pray that it's long enough + logging.ExtractLogger(ctx).Warn(). + Err(err). + Msg("got globally rate limited but couldn't determine how long to wait") + globalRateLimitTime = time.Now().Add(60 * time.Second) + } + } else { + // locally rate limited + + /* + Despite our best efforts, we ended up rate limited anyway. + Simply wait the amount of time Discord asks, and then try + again. On the next go-around, hopefully we'll either succeed + or have a rate limiter initialized and ready to go. + */ + logging.ExtractLogger(ctx).Warn().Msg("got rate limited by Discord") + if headersOk { + err := utils.SleepContext(ctx, headers.ResetAfter) + if err != nil { + return nil, err + } + } else { + logging.ExtractLogger(ctx).Warn().Msg("got rate limited, but didn't have the headers??") + err := utils.SleepContext(ctx, 1*time.Second) + if err != nil { + return nil, err + } + } + } + continue + } + + return res, nil + } +} diff --git a/src/discord/rest.go b/src/discord/rest.go new file mode 100644 index 00000000..506c7854 --- /dev/null +++ b/src/discord/rest.go @@ -0,0 +1,182 @@ +package discord + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httputil" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/oops" +) + +const ( + BotName = "HandmadeNetwork" + BaseUrl = "https://discord.com/api/v9" + + UserAgentURL = "https://handmade.network/" + UserAgentVersion = "1.0" +) + +var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersion) + +var httpClient = &http.Client{} + +func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request { + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewBuffer(body) + } + + req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader) + if err != nil { + panic(err) + } + req.Header.Add("Authorization", fmt.Sprintf("Bot %s", config.Config.Discord.BotToken)) + req.Header.Add("User-Agent", UserAgent) + + return req +} + +type GetGatewayBotResponse struct { + URL string `json:"url"` + // We don't care about shards or session limit stuff; we will never hit those limits +} + +func GetGatewayBot(ctx context.Context) (*GetGatewayBotResponse, error) { + const name = "Get Gateway Bot" + + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodGet, "/gateway/bot", nil) + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != 200 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + body, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var result GetGatewayBotResponse + err = json.Unmarshal(body, &result) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord response") + } + + return &result, nil +} + +type CreateMessageRequest struct { + Content string `json:"content"` +} + +func CreateMessage(ctx context.Context, channelID string, payloadJSON string) (*Message, error) { + const name = "Create Message" + + path := fmt.Sprintf("/channels/%s/messages", channelID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodPost, path, []byte(payloadJSON)) + req.Header.Add("Content-Type", "application/json") + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + // Maybe in the future we could more nicely handle errors like "bad channel", + // but honestly what are the odds that we mess that up... + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var msg Message + err = json.Unmarshal(bodyBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord message") + } + + return &msg, nil +} + +func DeleteMessage(ctx context.Context, channelID string, messageID string) error { + const name = "Delete Message" + + path := fmt.Sprintf("/channels/%s/messages/%s", channelID, messageID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodDelete, path, nil) + }) + if err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + logErrorResponse(ctx, name, res, "") + return oops.New(nil, "got unexpected status code when deleting message") + } + + return nil +} + +func CreateDM(ctx context.Context, recipientID string) (*Channel, error) { + const name = "Create DM" + + path := "/users/@me/channels" + body := []byte(fmt.Sprintf(`{"recipient_id":"%s"}`, recipientID)) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodPost, path, body) + req.Header.Add("Content-Type", "application/json") + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var channel Channel + err = json.Unmarshal(bodyBytes, &channel) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord channel") + } + + return &channel, nil +} + +func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { + dump, err := httputil.DumpResponse(res, true) + if err != nil { + panic(err) + } + + logging.ExtractLogger(ctx).Error().Str("name", name).Msg(msg) + fmt.Println(string(dump)) +} diff --git a/src/discord/showcase.go b/src/discord/showcase.go new file mode 100644 index 00000000..6dacb161 --- /dev/null +++ b/src/discord/showcase.go @@ -0,0 +1,115 @@ +package discord + +import ( + "context" + "net/url" + "regexp" + "strings" + + "git.handmade.network/hmn/hmn/src/oops" +) + +var reDiscordMessageLink = regexp.MustCompile(`https?://.+?(\s|$)`) + +func (bot *discordBotInstance) processShowcaseMsg(ctx context.Context, msg *Message) error { + switch msg.Type { + case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand: + default: + return nil + } + + hasGoodContent := true + if originalMessageHasField(msg, "content") && !messageHasLinks(msg.Content) { + hasGoodContent = false + } + + hasGoodAttachments := true + if originalMessageHasField(msg, "attachments") && len(msg.Attachments) == 0 { + hasGoodAttachments = false + } + + if !hasGoodContent && !hasGoodAttachments { + err := DeleteMessage(ctx, msg.ChannelID, msg.ID) + if err != nil { + return oops.New(err, "failed to delete message") + } + + if msg.Author != nil && !msg.Author.IsBot { + channel, err := CreateDM(ctx, msg.Author.ID) + if err != nil { + return oops.New(err, "failed to create DM channel") + } + + err = SendMessages(ctx, bot.dbConn, MessageToSend{ + ChannelID: channel.ID, + Req: CreateMessageRequest{ + Content: "Posts in #project-showcase are required to have either an image/video or a link. Discuss showcase content in #projects.", + }, + }) + if err != nil { + return oops.New(err, "failed to send showcase warning message") + } + } + } + + return nil +} + +func (bot *discordBotInstance) processLibraryMsg(ctx context.Context, msg *Message) error { + switch msg.Type { + case MessageTypeDefault, MessageTypeReply, MessageTypeApplicationCommand: + default: + return nil + } + + if !originalMessageHasField(msg, "content") { + return nil + } + + if !messageHasLinks(msg.Content) { + err := DeleteMessage(ctx, msg.ChannelID, msg.ID) + if err != nil { + return oops.New(err, "failed to delete message") + } + + if msg.Author != nil && !msg.Author.IsBot { + channel, err := CreateDM(ctx, msg.Author.ID) + if err != nil { + return oops.New(err, "failed to create DM channel") + } + + err = SendMessages(ctx, bot.dbConn, MessageToSend{ + ChannelID: channel.ID, + Req: CreateMessageRequest{ + Content: "Posts in #the-library are required to have a link. Discuss library content in other relevant channels.", + }, + }) + if err != nil { + return oops.New(err, "failed to send showcase warning message") + } + } + } + + return nil +} + +func messageHasLinks(content string) bool { + links := reDiscordMessageLink.FindAllString(content, -1) + for _, link := range links { + _, err := url.Parse(strings.TrimSpace(link)) + if err == nil { + return true + } + } + + return false +} + +func originalMessageHasField(msg *Message, field string) bool { + if msg.originalMap == nil { + return false + } + + _, ok := msg.originalMap[field] + return ok +} diff --git a/src/logging/logging.go b/src/logging/logging.go index 8b1ed9b2..243765fc 100644 --- a/src/logging/logging.go +++ b/src/logging/logging.go @@ -1,6 +1,7 @@ package logging import ( + "context" "encoding/json" "os" "sort" @@ -16,7 +17,7 @@ import ( func init() { zerolog.ErrorStackMarshaler = oops.ZerologStackMarshaler - log.Logger = log.Output(NewPrettyZerologWriter()) + log.Logger = log.Output(NewPrettyZerologWriter()).With().Stack().Logger() zerolog.SetGlobalLevel(config.Config.LogLevel) } @@ -203,3 +204,17 @@ func LogPanicValue(logger *zerolog.Logger, val interface{}, msg string) { Msg(msg) } } + +const LoggerContextKey = "logger" + +func AttachLoggerToContext(logger *zerolog.Logger, ctx context.Context) context.Context { + return context.WithValue(ctx, LoggerContextKey, logger) +} + +func ExtractLogger(ctx context.Context) *zerolog.Logger { + ilogger := ctx.Value(LoggerContextKey) + if ilogger == nil { + return GlobalLogger() + } + return ilogger.(*zerolog.Logger) +} diff --git a/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go b/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go new file mode 100644 index 00000000..272cbd24 --- /dev/null +++ b/src/migration/migrations/2021-08-07T185330Z_AddDiscordBotTables.go @@ -0,0 +1,61 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "git.handmade.network/hmn/hmn/src/oops" + "github.com/jackc/pgx/v4" +) + +func init() { + registerMigration(AddDiscordBotTables{}) +} + +type AddDiscordBotTables struct{} + +func (m AddDiscordBotTables) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2021, 8, 7, 18, 53, 30, 0, time.UTC)) +} + +func (m AddDiscordBotTables) Name() string { + return "AddDiscordBotTables" +} + +func (m AddDiscordBotTables) Description() string { + return "Add tables for Discord bot sessions and messages" +} + +func (m AddDiscordBotTables) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + CREATE TABLE discord_session ( + pk INT NOT NULL DEFAULT 1337 PRIMARY KEY, -- this should always be set to 1337 to ensure that we only have one row :) + session_id VARCHAR(255) NOT NULL, + sequence_number INT NOT NULL, + + CONSTRAINT only_one_session CHECK (pk = 1337) + ); + `) + if err != nil { + return oops.New(err, "failed to create discord session table") + } + + _, err = tx.Exec(ctx, ` + CREATE TABLE discord_outgoingmessages ( + id SERIAL NOT NULL PRIMARY KEY, + channel_id VARCHAR(64) NOT NULL, + payload_json TEXT NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL + ); + `) + if err != nil { + return oops.New(err, "failed to create discord outgoing messages table") + } + + return nil +} + +func (m AddDiscordBotTables) Down(ctx context.Context, tx pgx.Tx) error { + panic("Implement me") +} diff --git a/src/models/discord.go b/src/models/discord.go index 52c0f4f6..18433417 100644 --- a/src/models/discord.go +++ b/src/models/discord.go @@ -4,6 +4,18 @@ import ( "time" ) +type DiscordSession struct { + ID string `db:"session_id"` + SequenceNumber int `db:"sequence_number"` +} + +type DiscordOutgoingMessage struct { + ID int `db:"id"` + ChannelID string `db:"channel_id"` + PayloadJSON string `db:"payload_json"` + ExpiresAt time.Time `db:"expires_at"` +} + type DiscordMessage struct { ID string `db:"id"` ChannelID string `db:"channel_id"` diff --git a/src/oops/oops.go b/src/oops/oops.go index 53d8d0bf..e92a0876 100644 --- a/src/oops/oops.go +++ b/src/oops/oops.go @@ -14,7 +14,11 @@ type Error struct { } func (e *Error) Error() string { - return fmt.Sprintf("%s: %v", e.Message, e.Wrapped) + if e.Wrapped == nil { + return e.Message + } else { + return fmt.Sprintf("%s: %v", e.Message, e.Wrapped) + } } func (e *Error) Unwrap() error { diff --git a/src/utils/utils.go b/src/utils/utils.go index 6fb9926a..394fcd16 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -1,5 +1,14 @@ package utils +import ( + "context" + "errors" + "fmt" + "time" + + "git.handmade.network/hmn/hmn/src/oops" +) + func IntMin(a, b int) int { if a < b { return a @@ -17,3 +26,38 @@ func IntMax(a, b int) int { func IntClamp(min, t, max int) int { return IntMax(min, IntMin(t, max)) } + +/* +Recover a panic and convert it to a returned error. Call it like so: + + func MyFunc() (err error) { + defer utils.RecoverPanicAsError(&err) + } + +If an error was already present, the panicked error will take precedence. Unfortunately there's +no good way to include both errors because you can't really have two chains of errors and still +play nice with the standard library's Unwrap behavior. But most of the time this shouldn't be an +issue, since the panic will probably occur before a meaningful error value was set. +*/ +func RecoverPanicAsError(err *error) { + if r := recover(); r != nil { + var recoveredErr error + if rerr, ok := r.(error); ok { + recoveredErr = rerr + } else { + recoveredErr = fmt.Errorf("panic with value: %v", r) + } + *err = oops.New(recoveredErr, "panic recovered as error") + } +} + +var ErrSleepInterrupted = errors.New("sleep interrupted by context cancellation") + +func SleepContext(ctx context.Context, d time.Duration) error { + select { + case <-ctx.Done(): + return ErrSleepInterrupted + case <-time.After(d): + return nil + } +} diff --git a/src/website/website.go b/src/website/website.go index c3c22bfe..79f3dc24 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -13,6 +13,7 @@ import ( "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/perf" "git.handmade.network/hmn/hmn/src/templates" @@ -42,6 +43,7 @@ var WebsiteCommand = &cobra.Command{ auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn), perfCollector.Done, + discord.RunDiscordBot(backgroundJobContext, conn), ) signals := make(chan os.Signal, 1) @@ -52,7 +54,9 @@ var WebsiteCommand = &cobra.Command{ go func() { timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + logging.Info().Msg("shutting down web server") server.Shutdown(timeout) + logging.Info().Msg("cancelling background jobs") cancelBackgroundJobs() }()