From d92bf9a9b82b5f8acd746c19a5be197d21f664f2 Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Sun, 15 Aug 2021 23:40:56 -0500 Subject: [PATCH] Add Discord account linking --- src/config/types.go | 5 + src/discord/payloads.go | 10 +- src/discord/rest.go | 123 ++++++++++++++- src/hmnurl/urls.go | 16 ++ .../2021-08-16T023440Z_DiscordData.go | 76 ++++++++++ src/models/discord.go | 67 +++++++-- src/templates/mapping.go | 14 ++ src/templates/src/discordtest.html | 14 ++ src/templates/types.go | 6 + src/website/discord.go | 140 ++++++++++++++++++ src/website/routes.go | 3 + 11 files changed, 459 insertions(+), 15 deletions(-) create mode 100644 src/migration/migrations/2021-08-16T023440Z_DiscordData.go create mode 100644 src/templates/src/discordtest.html create mode 100644 src/website/discord.go diff --git a/src/config/types.go b/src/config/types.go index 19485b98..19c01950 100644 --- a/src/config/types.go +++ b/src/config/types.go @@ -67,6 +67,11 @@ type DiscordConfig struct { BotToken string BotUserID string + OAuthClientID string + OAuthClientSecret string + + GuildID string + MemberRoleID string ShowcaseChannelID string LibraryChannelID string } diff --git a/src/discord/payloads.go b/src/discord/payloads.go index 4b5870e1..edb061ae 100644 --- a/src/discord/payloads.go +++ b/src/discord/payloads.go @@ -217,10 +217,12 @@ func MessageFromMap(m interface{}) Message { // https://discord.com/developers/docs/resources/user#user-object type User struct { - ID string `json:"id"` - Username string `json:"username"` - Discriminator string `json:"discriminator"` - IsBot bool `json:"bot"` + ID string `json:"id"` + Username string `json:"username"` + Discriminator string `json:"discriminator"` + Avatar *string `json:"avatar"` + IsBot bool `json:"bot"` + Locale string `json:"locale"` } func UserFromMap(m interface{}) User { diff --git a/src/discord/rest.go b/src/discord/rest.go index 506c7854..6b7e55b7 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "net/http/httputil" + "net/url" + "strings" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/logging" @@ -26,13 +28,17 @@ var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersi var httpClient = &http.Client{} +func buildUrl(path string) string { + return fmt.Sprintf("%s%s", BaseUrl, path) +} + func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request { var bodyReader io.Reader if body != nil { bodyReader = bytes.NewBuffer(body) } - req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader) + req, err := http.NewRequestWithContext(ctx, method, buildUrl(path), bodyReader) if err != nil { panic(err) } @@ -171,6 +177,121 @@ func CreateDM(ctx context.Context, recipientID string) (*Channel, error) { return &channel, nil } +type OAuthCodeExchangeResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` +} + +func ExchangeOAuthCode(ctx context.Context, code, redirectURI string) (*OAuthCodeExchangeResponse, error) { + const name = "OAuth Code Exchange" + + body := make(url.Values) + body.Set("client_id", config.Config.Discord.OAuthClientID) + body.Set("client_secret", config.Config.Discord.OAuthClientSecret) + body.Set("grant_type", "authorization_code") + body.Set("code", code) + body.Set("redirect_uri", redirectURI) + bodyStr := body.Encode() + + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + "https://discord.com/api/oauth2/token", + strings.NewReader(bodyStr), + ) + if err != nil { + panic(err) + } + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var tokenResponse OAuthCodeExchangeResponse + err = json.Unmarshal(bodyBytes, &tokenResponse) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord OAuth token") + } + + return &tokenResponse, nil +} + +func GetCurrentUserAsOAuth(ctx context.Context, accessToken string) (*User, error) { + const name = "Get Current User" + + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildUrl("/users/@me"), nil) + if err != nil { + panic(err) + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + req.Header.Add("User-Agent", UserAgent) + + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var user User + err = json.Unmarshal(bodyBytes, &user) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord user") + } + + return &user, nil +} + +func AddGuildMemberRole(ctx context.Context, userID, roleID string) error { + const name = "Delete Message" + + path := fmt.Sprintf("/guilds/%s/members/%s/roles/%s", config.Config.Discord.GuildID, userID, roleID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodPut, path, nil) + }) + if err != nil { + return err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + logErrorResponse(ctx, name, res, "") + return oops.New(nil, "got unexpected status code when adding role") + } + + return nil +} + func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { dump, err := httputil.DumpResponse(res, true) if err != nil { diff --git a/src/hmnurl/urls.go b/src/hmnurl/urls.go index 8d777f2a..a07d6f37 100644 --- a/src/hmnurl/urls.go +++ b/src/hmnurl/urls.go @@ -545,6 +545,22 @@ func BuildLibraryResource(projectSlug string, resourceId int) string { return ProjectUrl(builder.String(), nil, projectSlug) } +/* +* Discord OAuth + */ + +var RegexDiscordTest = regexp.MustCompile("^/discord$") + +func BuildDiscordTest() string { + return Url("/discord", nil) +} + +var RegexDiscordOAuthCallback = regexp.MustCompile("^/_discord_callback$") + +func BuildDiscordOAuthCallback() string { + return Url("/_discord_callback", nil) +} + /* * Assets */ diff --git a/src/migration/migrations/2021-08-16T023440Z_DiscordData.go b/src/migration/migrations/2021-08-16T023440Z_DiscordData.go new file mode 100644 index 00000000..a1114800 --- /dev/null +++ b/src/migration/migrations/2021-08-16T023440Z_DiscordData.go @@ -0,0 +1,76 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "git.handmade.network/hmn/hmn/src/oops" + "github.com/jackc/pgx/v4" +) + +func init() { + registerMigration(DiscordData{}) +} + +type DiscordData struct{} + +func (m DiscordData) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2021, 8, 16, 2, 34, 40, 0, time.UTC)) +} + +func (m DiscordData) Name() string { + return "DiscordData" +} + +func (m DiscordData) Description() string { + return "Clean up Discord data models" +} + +func (m DiscordData) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + ALTER TABLE handmade_discord RENAME TO handmade_discorduser; + ALTER TABLE handmade_discorduser + ALTER username SET NOT NULL, + ALTER discriminator SET NOT NULL, + ALTER access_token SET NOT NULL, + ALTER refresh_token SET NOT NULL, + ALTER locale SET NOT NULL, + ALTER userid SET NOT NULL, + ALTER expiry SET NOT NULL; + `) + if err != nil { + return oops.New(err, "failed to fix up discord table") + } + + _, err = tx.Exec(ctx, ` + ALTER TABLE handmade_discordmessagecontent + DROP CONSTRAINT handmade_discordmess_discord_id_1acc147f_fk_handmade_, + DROP CONSTRAINT handmade_discordmess_message_id_4dfde67d_fk_handmade_, + ADD FOREIGN KEY (discord_id) REFERENCES handmade_discorduser (id) ON DELETE CASCADE, + ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE; + + ALTER TABLE handmade_discordmessageattachment + DROP CONSTRAINT handmade_discordmess_asset_id_c64a3c31_fk_handmade_, + DROP CONSTRAINT handmade_discordmess_message_id_d39da9b3_fk_handmade_, + ADD FOREIGN KEY (asset_id) REFERENCES handmade_asset (id) ON DELETE CASCADE, + ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE; + + ALTER TABLE handmade_discordmessageembed + DROP CONSTRAINT handmade_discordmess_image_id_9b04bb5f_fk_handmade_, + DROP CONSTRAINT handmade_discordmess_message_id_04f15ce6_fk_handmade_, + DROP CONSTRAINT handmade_discordmess_video_id_1c41289f_fk_handmade_, + ADD FOREIGN KEY (image_id) REFERENCES handmade_asset (id) ON DELETE SET NULL, + ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE, + ADD FOREIGN KEY (video_id) REFERENCES handmade_asset (id) ON DELETE SET NULL; + `) + if err != nil { + return oops.New(err, "failed to fix constraints") + } + + return nil +} + +func (m DiscordData) Down(ctx context.Context, tx pgx.Tx) error { + panic("Implement me") +} diff --git a/src/models/discord.go b/src/models/discord.go index 18433417..4b2eeb74 100644 --- a/src/models/discord.go +++ b/src/models/discord.go @@ -2,8 +2,65 @@ package models import ( "time" + + "github.com/google/uuid" ) +type DiscordUser struct { + ID int `db:"id"` + Username string `db:"username"` + Discriminator string `db:"discriminator"` + AccessToken string `db:"access_token"` + RefreshToken string `db:"refresh_token"` + Avatar *string `db:"avatar"` + Locale string `db:"locale"` + UserID string `db:"userid"` + Expiry time.Time `db:"expiry"` + HMNUserId int `db:"hmn_user_id"` +} + +/* +Logs the existence of a Discord message and what we've done with it. +Created unconditionally for all users, regardless of link status. +Therefore, it must not contain any actual content. +*/ +type DiscordMessage struct { + ID string `db:"id"` + ChannelID string `db:"channel_id"` + GuildID *string `db:"guild_id"` + Url string `db:"url"` + UserID string `db:"user_id"` + SentAt time.Time `db:"sent_at"` + SnippetCreated bool `db:"snippet_created"` +} + +/* +Stores the content of a Discord message for users with a linked +Discord account. Always created for users with a linked Discord +account, regardless of whether we create snippets or not. +*/ +type DiscordMessageContent struct { + MessageID string `db:"message_id"` + LastContent string `db:"last_content"` + DiscordID int `db:"discord_id"` +} + +type DiscordMessageAttachment struct { + ID string `db:"id"` + AssetID uuid.UUID `db:"asset_id"` + MessageID string `db:"message_id"` +} + +type DiscordMessageEmbed struct { + ID int `db:"id"` + Title *string `db:"title"` + Description *string `db:"description"` + URL *string `db:"url"` + ImageID *uuid.UUID `db:"image_id"` + MessageID string `db:"message_id"` + VideoID *uuid.UUID `db:"video_id"` +} + type DiscordSession struct { ID string `db:"session_id"` SequenceNumber int `db:"sequence_number"` @@ -15,13 +72,3 @@ type DiscordOutgoingMessage struct { PayloadJSON string `db:"payload_json"` ExpiresAt time.Time `db:"expires_at"` } - -type DiscordMessage struct { - ID string `db:"id"` - ChannelID string `db:"channel_id"` - GuildID *string `db:"guild_id"` - Url string `db:"url"` - UserID string `db:"user_id"` - SentAt time.Time `db:"sent_at"` - SnippetCreated bool `db:"snippet_created"` -} diff --git a/src/templates/mapping.go b/src/templates/mapping.go index 9c14599f..33981615 100644 --- a/src/templates/mapping.go +++ b/src/templates/mapping.go @@ -1,6 +1,7 @@ package templates import ( + "fmt" "html/template" "net" "regexp" @@ -331,6 +332,19 @@ func PodcastEpisodeToTemplate(projectSlug string, episode *models.PodcastEpisode } } +func DiscordUserToTemplate(d *models.DiscordUser) DiscordUser { + var avatarUrl string // TODO: Default avatar image + if d.Avatar != nil { + avatarUrl = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", d.UserID, *d.Avatar) + } + + return DiscordUser{ + Username: d.Username, + Discriminator: d.Discriminator, + Avatar: avatarUrl, + } +} + func maybeString(s *string) string { if s == nil { return "" diff --git a/src/templates/src/discordtest.html b/src/templates/src/discordtest.html new file mode 100644 index 00000000..c899ebff --- /dev/null +++ b/src/templates/src/discordtest.html @@ -0,0 +1,14 @@ +{{ template "base.html" . }} + +{{ define "content" }} +Wow a discord + +
+ {{ with .DiscordUser }} + + {{ .Username }}#{{ .Discriminator }} + {{ else }} + Link your account + {{ end }} +
+{{ end }} \ No newline at end of file diff --git a/src/templates/types.go b/src/templates/types.go index d4edfae7..bcc90a96 100644 --- a/src/templates/types.go +++ b/src/templates/types.go @@ -303,3 +303,9 @@ type EmailBaseData struct { Subject template.HTML Separator template.HTML } + +type DiscordUser struct { + Username string + Discriminator string + Avatar string +} diff --git a/src/website/discord.go b/src/website/discord.go new file mode 100644 index 00000000..1f005d38 --- /dev/null +++ b/src/website/discord.go @@ -0,0 +1,140 @@ +package website + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "git.handmade.network/hmn/hmn/src/auth" + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" + "git.handmade.network/hmn/hmn/src/hmnurl" + "git.handmade.network/hmn/hmn/src/models" + "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/templates" +) + +func DiscordTest(c *RequestContext) ResponseData { + var userDiscord *models.DiscordUser + iUserDiscord, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{}, + ` + SELECT $columns + FROM handmade_discorduser + WHERE hmn_user_id = $1 + `, + c.CurrentUser.ID, + ) + if err != nil { + if errors.Is(err, db.ErrNoMatchingRows) { + // we're ok, just no user + } else { + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current user's Discord account")) + } + } else { + userDiscord = iUserDiscord.(*models.DiscordUser) + } + + type templateData struct { + templates.BaseData + DiscordUser *templates.DiscordUser + AuthorizeURL string + } + + baseData := getBaseData(c) + baseData.Title = "Discord Test" + + params := make(url.Values) + params.Set("response_type", "code") + params.Set("client_id", config.Config.Discord.OAuthClientID) + params.Set("scope", "identify") + params.Set("state", c.CurrentSession.CSRFToken) + params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback()) + + td := templateData{ + BaseData: baseData, + AuthorizeURL: fmt.Sprintf("https://discord.com/api/oauth2/authorize?%s", params.Encode()), + } + + if userDiscord != nil { + u := templates.DiscordUserToTemplate(userDiscord) + td.DiscordUser = &u + } + + var res ResponseData + res.MustWriteTemplate("discordtest.html", td, c.Perf) + return res +} + +func DiscordOAuthCallback(c *RequestContext) ResponseData { + query := c.Req.URL.Query() + + // Check the state + state := query.Get("state") + if state != c.CurrentSession.CSRFToken { + // CSRF'd!!!! + + // TODO(compression): Should this and the CSRF middleware be pulled out to + // a separate function? + + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + + err := auth.DeleteSession(c.Context(), c.Conn, c.CurrentSession.ID) + if err != nil { + c.Logger.Error().Err(err).Msg("failed to delete session on Discord OAuth state failure") + } + + res := c.Redirect("/", http.StatusSeeOther) + res.SetCookie(auth.DeleteSessionCookie) + + return res + } + + // Check for error values and redirect back to ???? + if query.Get("error") != "" { + // TODO: actually handle these errors + return ErrorResponse(http.StatusBadRequest, errors.New(query.Get("error"))) + } + + // Do the actual token exchange and redirect back to ???? + code := query.Get("code") + res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) // TODO: Redirect to the right place + if err != nil { + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) + } + expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) + + user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken) + if err != nil { + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) + } + + // TODO: Add the role on Discord + err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID) + if err != nil { + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) + } + + _, err = c.Conn.Exec(c.Context(), + ` + INSERT INTO handmade_discorduser (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + `, + user.Username, + user.Discriminator, + res.AccessToken, + res.RefreshToken, + user.Avatar, + user.Locale, + user.ID, + expiry, + c.CurrentUser.ID, + ) + if err != nil { + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info")) + } + + return c.Redirect(hmnurl.BuildDiscordTest(), http.StatusSeeOther) +} diff --git a/src/website/routes.go b/src/website/routes.go index 5397fa12..fe2f963c 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -197,6 +197,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode) mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS) + mainRoutes.GET(hmnurl.RegexDiscordTest, authMiddleware(DiscordTest)) // TODO: Delete this route + mainRoutes.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback)) + mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS) mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData { var res ResponseData