From 0210a0784be3ff859540d5d9fd3fdf6c0fd9a6f7 Mon Sep 17 00:00:00 2001 From: bvisness Date: Sat, 6 May 2023 19:38:50 +0000 Subject: [PATCH] Add Discord login (#106) This leverages our existing Discord OAuth implementation. Any users with a linked Discord account will be able to log in immediately. When logging in, we request the `email` scope in addition to `identity`, so existing users will be prompted one time to accept the new permissions. On subsequent logins, Discord will skip the prompt. When linking your Discord account to an existing HMN account, we continue to only request the `identity` scope, so we do not receive the user's Discord email. Both login and linking go through the same Discord OAuth callback. All flows through the callback try to achieve the same end goal: a logged-in HMN user with a linked Discord account. Linking works the same as it ever has. Login, however, is different because we do not have a session ID to use as the OAuth state. To account for this, I have added a `pending_login` table that stores a secure unique ID and the eventual destination URL. These pending logins expire after 10 minutes. When we receive the OAuth callback, we look up the pending login by the OAuth `state` and immediately delete it. The destination URL will be used to redirect the user to the right place. If we have a `discord_user` entry for the OAuth'd Discord user, we immediately log the user into the associated HMN account. This is the typical login case. If we do not have a `discord_user`, but there is exactly one HMN user with the same email address as the Discord user, we will link the two accounts and log into the HMN account. (It is possible for multiple HMN accounts to have the same email, because we don't have a uniqueness constraint there. We fail the login in this case rather than link to the wrong account.) Finally, if no associated HMN user exists, a new one will be created. It will use the Discord user's username, email, and avatar. This user will have no password, but they can set or reset a password through the usual flows. Co-authored-by: Ben Visness Reviewed-on: https://git.handmade.network/hmn/hmn/pulls/106 --- public/discord-login.svg | 24 ++ public/style.css | 24 +- src/auth/session.go | 39 ++- src/discord/message_handling.go | 9 +- src/discord/payloads.go | 11 +- src/discord/ratelimiting.go | 10 +- src/discord/rest.go | 10 +- src/hmnurl/urls.go | 7 + .../2023-05-04T024712Z_AddPendingSignups.go | 53 +++ src/migration/seed.go | 2 +- src/models/user.go | 6 + src/rawdata/scss/_header.scss | 26 +- src/templates/src/auth_login.html | 12 + src/templates/src/include/header.html | 22 +- src/templates/src/user_settings.html | 14 +- src/templates/types.go | 15 +- src/twitch/twitch.go | 2 +- src/utils/utils.go | 16 + src/website/auth.go | 48 ++- src/website/base_data.go | 13 +- src/website/discord.go | 316 +++++++++++++++--- src/website/routes.go | 3 +- src/website/user.go | 50 +-- src/website/website.go | 2 +- 24 files changed, 586 insertions(+), 148 deletions(-) create mode 100644 public/discord-login.svg create mode 100644 src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go diff --git a/public/discord-login.svg b/public/discord-login.svg new file mode 100644 index 0000000..0e1c8fb --- /dev/null +++ b/public/discord-login.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/public/style.css b/public/style.css index e2f0530..574382b 100644 --- a/public/style.css +++ b/public/style.css @@ -4602,7 +4602,7 @@ code, .code { .pa2, .tab, figure, header .root-item > a, header .submenu > a { padding: 0.5rem; } -.pa3, header #login-popup { +.pa3 { padding: 1rem; } .pa4 { @@ -7422,7 +7422,7 @@ article code { color: #ccc; color: var(--theme-color-dimmest); } -.b--dimmest, .optionbar, blockquote, .post-content th, .post-content td, header #login-popup { +.b--dimmest, .optionbar, blockquote, .post-content th, .post-content td { border-color: #bbb; border-color: var(--dimmest-color); } @@ -8936,25 +8936,19 @@ header #login-popup { background-color: var(--login-popup-background); color: black; color: var(--fg-font-color); - border-width: 1px; - border-style: dashed; visibility: hidden; position: absolute; z-index: 12; - margin-top: 10px; - right: 0px; - top: 20px; - width: 290px; - max-height: 0px; overflow: hidden; - opacity: 0; - transition: all 0.2s; } + right: 0; + top: 100%; + width: 100%; } header #login-popup.open { - max-height: 170px; - opacity: 1; visibility: visible; } - header #login-popup label { - padding-right: 10px; } + @media screen and (min-width: 35em) { + header #login-popup { + top: 2.2rem; + width: 17rem; } } @font-face { font-family: icons; diff --git a/src/auth/session.go b/src/auth/session.go index 997704d..61c0648 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -24,7 +24,7 @@ const CSRFFieldName = "csrf_token" const sessionDuration = time.Hour * 24 * 14 -func makeSessionId() string { +func MakeSessionId() string { idBytes := make([]byte, 40) _, err := io.ReadFull(rand.Reader, idBytes) if err != nil { @@ -47,7 +47,16 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM session WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, + ` + SELECT $columns + FROM session + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + id, + ) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -61,7 +70,7 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses func CreateSession(ctx context.Context, conn *pgxpool.Pool, username string) (*models.Session, error) { session := models.Session{ - ID: makeSessionId(), + ID: MakeSessionId(), Username: username, ExpiresAt: time.Now().Add(sessionDuration), CSRFToken: makeCSRFToken(), @@ -134,7 +143,16 @@ func DeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) (int64, erro return tag.RowsAffected(), nil } -func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) jobs.Job { +func DeleteExpiredPendingLogins(ctx context.Context, conn *pgxpool.Pool) (int64, error) { + tag, err := conn.Exec(ctx, "DELETE FROM pending_login WHERE expires_at <= CURRENT_TIMESTAMP") + if err != nil { + return 0, oops.New(err, "failed to delete expired pending logins") + } + + return tag.RowsAffected(), nil +} + +func PeriodicallyDeleteExpiredStuff(ctx context.Context, conn *pgxpool.Pool) jobs.Job { job := jobs.New() go func() { defer job.Done() @@ -145,6 +163,7 @@ func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) case <-t.C: err := func() (err error) { defer utils.RecoverPanicAsError(&err) + n, err := DeleteExpiredSessions(ctx, conn) if err == nil { if n > 0 { @@ -153,10 +172,20 @@ func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) } else { logging.Error().Err(err).Msg("Failed to delete expired sessions") } + + n, err = DeleteExpiredPendingLogins(ctx, conn) + if err == nil { + if n > 0 { + logging.Info().Int64("num deleted pending logins", n).Msg("Deleted expired pending logins") + } + } else { + logging.Error().Err(err).Msg("Failed to delete expired pending logins") + } + return nil }() if err != nil { - logging.Error().Err(err).Msg("Panicked in PeriodicallyDeleteExpiredSessions") + logging.Error().Err(err).Msg("Panicked in PeriodicallyDeleteExpiredStuff") } case <-ctx.Done(): return diff --git a/src/discord/message_handling.go b/src/discord/message_handling.go index cfd9df1..6cbc2da 100644 --- a/src/discord/message_handling.go +++ b/src/discord/message_handling.go @@ -431,7 +431,7 @@ var discordDownloadClient = &http.Client{ type DiscordResourceBadStatusCode error -func downloadDiscordResource(ctx context.Context, url string) ([]byte, string, error) { +func DownloadDiscordResource(ctx context.Context, url string) ([]byte, string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, "", oops.New(err, "failed to make Discord download request") @@ -491,7 +491,7 @@ func saveAttachment( height = *attachment.Height } - content, _, err := downloadDiscordResource(ctx, attachment.Url) + content, _, err := DownloadDiscordResource(ctx, attachment.Url) if err != nil { return nil, oops.New(err, "failed to download Discord attachment") } @@ -561,7 +561,7 @@ func saveEmbed( } maybeSaveImageish := func(i EmbedImageish, contentTypeCheck func(string) bool) (*uuid.UUID, error) { - content, contentType, err := downloadDiscordResource(ctx, *i.Url) + content, contentType, err := DownloadDiscordResource(ctx, *i.Url) if err != nil { var statusError DiscordResourceBadStatusCode if errors.As(err, &statusError) { @@ -838,7 +838,8 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in } // TODO(asaf): I believe this will also match https://example.com?hello=1&whatever=5 -// Probably need to add word boundaries. +// +// Probably need to add word boundaries. var REDiscordTag = regexp.MustCompile(`&([a-zA-Z0-9]+(-[a-zA-Z0-9]+)*)`) func getDiscordTags(content string) []string { diff --git a/src/discord/payloads.go b/src/discord/payloads.go index d477d22..c103894 100644 --- a/src/discord/payloads.go +++ b/src/discord/payloads.go @@ -346,6 +346,7 @@ type User struct { Avatar *string `json:"avatar"` IsBot bool `json:"bot"` Locale string `json:"locale"` + Email string `json:"email"` } func UserFromMap(m interface{}, k string) *User { @@ -387,8 +388,9 @@ func GuildFromMap(m interface{}, k string) *Guild { // https://discord.com/developers/docs/resources/guild#guild-member-object type GuildMember struct { - User *User `json:"user"` - Nick *string `json:"nick"` + User *User `json:"user"` + Nick *string `json:"nick"` + Avatar *string `json:"avatar"` // more fields not yet handled here } @@ -409,8 +411,9 @@ func GuildMemberFromMap(m interface{}, k string) *GuildMember { } gm := &GuildMember{ - User: UserFromMap(m, "user"), - Nick: maybeStringP(mmap, "nick"), + User: UserFromMap(m, "user"), + Nick: maybeStringP(mmap, "nick"), + Avatar: maybeStringP(mmap, "avatar"), } return gm diff --git a/src/discord/ratelimiting.go b/src/discord/ratelimiting.go index 8d74e1d..1f9e44b 100644 --- a/src/discord/ratelimiting.go +++ b/src/discord/ratelimiting.go @@ -110,7 +110,7 @@ func createLimiter(headers rateLimitHeaders, routeName string) { buckets.Store(routeName, headers.Bucket) ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{ - requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts + requests: make(chan struct{}, 200), // presumably this is big enough to handle bursts refills: make(chan rateLimiterRefill), }) if !loaded { @@ -124,7 +124,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("remaining", headers.Remaining). + Msg("rate limiting channel was too small; you should increase the default capacity") break prefillloop } } @@ -158,7 +160,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("maxRequests", refill.maxRequests). + Msg("rate limiting channel was too small; you should increase the default capacity") break refillloop } } diff --git a/src/discord/rest.go b/src/discord/rest.go index 9124858..417be53 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -652,11 +652,17 @@ func EditOriginalInteractionResponse(ctx context.Context, interactionToken strin return &msg, nil } -func GetAuthorizeUrl(state string) string { +func GetAuthorizeUrl(state string, includeEmail bool) string { + scope := "identify" + if includeEmail { + scope = "identify email" + } + params := make(url.Values) params.Set("response_type", "code") params.Set("client_id", config.Config.Discord.OAuthClientID) - params.Set("scope", "identify") + params.Set("scope", scope) + params.Set("prompt", "none") // immediately redirect back to HMN if already authorized params.Set("state", state) params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback()) return fmt.Sprintf("%s?%s", buildUrl("/oauth2/authorize"), params.Encode()) diff --git a/src/hmnurl/urls.go b/src/hmnurl/urls.go index f175a7d..a1fae05 100644 --- a/src/hmnurl/urls.go +++ b/src/hmnurl/urls.go @@ -121,6 +121,13 @@ func BuildLoginPage(redirectTo string) string { return Url("/login", []Q{{Name: "redirect", Value: redirectTo}}) } +var RegexLoginWithDiscord = regexp.MustCompile("^/login-with-discord$") + +func BuildLoginWithDiscord(redirectTo string) string { + defer CatchPanic() + return Url("/login-with-discord", []Q{{Name: "redirect", Value: redirectTo}}) +} + var RegexLogoutAction = regexp.MustCompile("^/logout$") func BuildLogoutAction(redir string) string { diff --git a/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go new file mode 100644 index 0000000..ed59719 --- /dev/null +++ b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go @@ -0,0 +1,53 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "github.com/jackc/pgx/v5" +) + +func init() { + registerMigration(AddPendingSignups{}) +} + +type AddPendingSignups struct{} + +func (m AddPendingSignups) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2023, 5, 4, 2, 47, 12, 0, time.UTC)) +} + +func (m AddPendingSignups) Name() string { + return "AddPendingSignups" +} + +func (m AddPendingSignups) Description() string { + return "Adds the pending login table" +} + +func (m AddPendingSignups) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, + ` + CREATE TABLE pending_login ( + id VARCHAR(40) NOT NULL PRIMARY KEY, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + destination_url VARCHAR(999) NOT NULL + ) + `, + ) + if err != nil { + return err + } + + return nil +} + +func (m AddPendingSignups) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, `DROP TABLE pending_login`) + if err != nil { + return err + } + + return nil +} diff --git a/src/migration/seed.go b/src/migration/seed.go index 0b6aff7..4d89e6c 100644 --- a/src/migration/seed.go +++ b/src/migration/seed.go @@ -189,7 +189,7 @@ func seedUser(ctx context.Context, conn db.ConnOrTx, input models.User) *models. $6, $7, $8, $9, TRUE, $10, - '2017-01-01T00:00:00Z', '192.168.2.1', null + '2017-01-01T00:00:00Z', '192.168.2.1', NULL ) RETURNING $columns `, diff --git a/src/models/user.go b/src/models/user.go index 5caaf6d..d838324 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -71,3 +71,9 @@ func (u *User) CanSeeUnpublishedEducationContent() bool { func (u *User) CanAuthorEducation() bool { return u.IsStaff || u.EducationRole == EduRoleAuthor } + +type PendingLogin struct { + ID string `db:"id"` + ExpiresAt time.Time `db:"expires_at"` + DestinationUrl string `db:"destination_url"` +} diff --git a/src/rawdata/scss/_header.scss b/src/rawdata/scss/_header.scss index 3a403e0..1460410 100644 --- a/src/rawdata/scss/_header.scss +++ b/src/rawdata/scss/_header.scss @@ -130,33 +130,21 @@ header { @include usevar(background-color, login-popup-background); @include usevar(color, fg-font-color); - @extend .pa3; - - border-width: 1px; - border-style: dashed; - @extend .b--dimmest; - visibility: hidden; position: absolute; z-index: 12; - margin-top: 10px; - right: 0px; - top: 20px; - width: 290px; - max-height: 0px; overflow: hidden; - opacity: 0; - - transition: all 0.2s; + right: 0; + top: 100%; + width: 100%; &.open { - max-height: 170px; - opacity: 1; visibility: visible; } - - label { - padding-right:10px; + + @media #{$breakpoint-not-small} { + top: 2.2rem; + width: 17rem; } } } diff --git a/src/templates/src/auth_login.html b/src/templates/src/auth_login.html index 666611f..4a40738 100644 --- a/src/templates/src/auth_login.html +++ b/src/templates/src/auth_login.html @@ -37,6 +37,18 @@
Need an account? Sign up.
+ +
+
Third-party login
+
+ + Log in with Discord + +
+
diff --git a/src/templates/src/include/header.html b/src/templates/src/include/header.html index 60f4532..751dfaa 100644 --- a/src/templates/src/include/header.html +++ b/src/templates/src/include/header.html @@ -13,14 +13,25 @@ {{ else }} Register -
-
+
+ - Forgot your password? + Forgot your password? -
- +
+ +
+
+
Third-party login
+
+ + Log in with Discord + +
@@ -166,7 +177,6 @@ const loginLink = document.getElementById("login-link"); if (loginPopup !== null) { - loginLink.onclick = (e) => { e.preventDefault(); loginPopup.classList.toggle("open"); diff --git a/src/templates/src/user_settings.html b/src/templates/src/user_settings.html index 68ef0fc..7b82eca 100644 --- a/src/templates/src/user_settings.html +++ b/src/templates/src/user_settings.html @@ -81,18 +81,20 @@
-
-
Old password:
-
- + {{ if .HasPassword }} +
+
Old password:
+
+ +
-
+ {{ end }}
New password:
- Your password must be 8 or more characters, and must differ from your username and current password. + Your password must be 8 or more characters, and must differ from your username{{ if .HasPassword }} and current password{{ end }}. Other than that, please follow best practices.
diff --git a/src/templates/types.go b/src/templates/types.go index acbf187..d72acf0 100644 --- a/src/templates/types.go +++ b/src/templates/types.go @@ -38,13 +38,14 @@ func (bd *BaseData) AddImmediateNotice(class, content string) { } type Header struct { - AdminUrl string - UserProfileUrl string - UserSettingsUrl string - LoginActionUrl string - LogoutActionUrl string - ForgotPasswordUrl string - RegisterUrl string + AdminUrl string + UserProfileUrl string + UserSettingsUrl string + LoginActionUrl string + LogoutActionUrl string + ForgotPasswordUrl string + RegisterUrl string + LoginWithDiscordUrl string HMNHomepageUrl string ProjectIndexUrl string diff --git a/src/twitch/twitch.go b/src/twitch/twitch.go index 90b6adc..938ada1 100644 --- a/src/twitch/twitch.go +++ b/src/twitch/twitch.go @@ -931,7 +931,7 @@ func updateStreamHistory(ctx context.Context, dbConn db.ConnOrTx, status *models ` INSERT INTO twitch_stream_history (stream_id, twitch_id, twitch_login, started_at, stream_ended, ended_at, end_approximated, title, category_id, tags, discord_needs_update) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (stream_id) DO UPDATE SET stream_ended = EXCLUDED.stream_ended, ended_at = EXCLUDED.ended_at, diff --git a/src/utils/utils.go b/src/utils/utils.go index 6df2c0d..ea2fd3d 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -104,3 +104,19 @@ func SleepContext(ctx context.Context, d time.Duration) error { return nil } } + +// Panics if the provided value is falsy (so, zero). This works for booleans +// but also normal values, through the magic of generics. +func Assert[T comparable](value T, msg ...any) { + var zero T + if value == zero { + finalMsg := "" + for i, arg := range msg { + if i > 0 { + finalMsg += " " + } + finalMsg += fmt.Sprintf("%v", arg) + } + panic(finalMsg) + } +} diff --git a/src/website/auth.go b/src/website/auth.go index 6a308ee..c8ab4f3 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -11,6 +11,7 @@ import ( "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/email" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/logging" @@ -23,9 +24,10 @@ var UsernameRegex = regexp.MustCompile(`^[0-9a-zA-Z][\w-]{2,29}$`) type LoginPageData struct { templates.BaseData - RedirectUrl string - RegisterUrl string - ForgotPasswordUrl string + RedirectUrl string + RegisterUrl string + ForgotPasswordUrl string + LoginWithDiscordUrl string } func LoginPage(c *RequestContext) ResponseData { @@ -37,10 +39,11 @@ func LoginPage(c *RequestContext) ResponseData { var res ResponseData res.MustWriteTemplate("auth_login.html", LoginPageData{ - BaseData: getBaseData(c, "Log in", nil), - RedirectUrl: redirect, - RegisterUrl: hmnurl.BuildRegister(redirect), - ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + BaseData: getBaseData(c, "Log in", nil), + RedirectUrl: redirect, + RegisterUrl: hmnurl.BuildRegister(redirect), + ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + LoginWithDiscordUrl: hmnurl.BuildLoginWithDiscord(redirect), }, c.Perf) return res } @@ -120,6 +123,28 @@ func Login(c *RequestContext) ResponseData { return res } +func LoginWithDiscord(c *RequestContext) ResponseData { + destinationUrl := c.URL().Query().Get("redirect") + if c.CurrentUser != nil { + return c.Redirect(destinationUrl, http.StatusSeeOther) + } + + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + INSERT INTO pending_login (id, expires_at, destination_url) + VALUES ($1, $2, $3) + RETURNING $columns + `, + auth.MakeSessionId(), time.Now().Add(time.Minute*10), destinationUrl, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save pending login")) + } + + discordAuthUrl := discord.GetAuthorizeUrl(pendingLogin.ID, true) + return c.Redirect(discordAuthUrl, http.StatusSeeOther) +} + func Logout(c *RequestContext) ResponseData { redirect := c.Req.URL.Query().Get("redirect") @@ -460,7 +485,8 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali } // NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email, -// not to changing your password through the user settings page. +// +// not to changing your password through the user settings page. func RequestPasswordReset(c *RequestContext) ResponseData { if c.CurrentUser != nil { return c.Redirect(hmnurl.BuildHomepage(), http.StatusSeeOther) @@ -717,7 +743,11 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro defer c.Perf.EndBlock() hashed, err := auth.ParsePasswordString(user.Password) if err != nil { - return false, oops.New(err, "failed to parse password string") + if user.Password == "" { + return false, nil + } else { + return false, oops.New(err, "failed to parse password string") + } } passwordsMatch, err := auth.CheckPassword(password, hashed) diff --git a/src/website/base_data.go b/src/website/base_data.go index d6a95c6..8f391c5 100644 --- a/src/website/base_data.go +++ b/src/website/base_data.go @@ -63,12 +63,13 @@ func getBaseData(c *RequestContext, title string, breadcrumbs []templates.Breadc IsProjectPage: !project.IsHMN(), Header: templates.Header{ - AdminUrl: hmnurl.BuildAdminApprovalQueue(), // TODO(asaf): Replace with general-purpose admin page - UserSettingsUrl: hmnurl.BuildUserSettings(""), - LoginActionUrl: hmnurl.BuildLoginAction(c.FullUrl()), - LogoutActionUrl: hmnurl.BuildLogoutAction(c.FullUrl()), - ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), - RegisterUrl: hmnurl.BuildRegister(""), + AdminUrl: hmnurl.BuildAdminApprovalQueue(), // TODO(asaf): Replace with general-purpose admin page + UserSettingsUrl: hmnurl.BuildUserSettings(""), + LoginActionUrl: hmnurl.BuildLoginAction(c.FullUrl()), + LogoutActionUrl: hmnurl.BuildLogoutAction(c.FullUrl()), + ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + RegisterUrl: hmnurl.BuildRegister(""), + LoginWithDiscordUrl: hmnurl.BuildLoginWithDiscord(c.FullUrl()), HMNHomepageUrl: hmnurl.BuildHomepage(), ProjectIndexUrl: hmnurl.BuildProjectIndex(1), diff --git a/src/website/discord.go b/src/website/discord.go index b9db312..52345e5 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -1,39 +1,115 @@ package website import ( + "context" "errors" + "fmt" "net/http" + "strings" "time" + "git.handmade.network/hmn/hmn/src/assets" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/utils" + "github.com/google/uuid" ) +// This callback handles Discord account linking whether the user is signed in +// or not. In all cases, the end state is that the user is signed into a +// Handmade Network account with a linked Discord account. HMN accounts will be +// created as necessary. +// +// If we initiate OAuth while logged in, we will use the current session's CSRF +// token as the OAuth state. Otherwise, we will generate a new entry in the +// pending_login table with an equivalently random token and use that token for +// the state. +// +// Considerations: +// +// | | Already signed in | Not signed in | +// |-----------------------|----------------------|-------------------------------| +// | No matching info | Link Discord account | Create HMN account | +// |-----------------------| to current HMN |-------------------------------| +// | Matching Discord user | account (stealing is | Log into HMN account and link | +// |-----------------------| ok, but make sure | Discord user to it. (Double- | +// | One matching email | any other accounts | check Discord link settings.) | +// |-----------------------| are unlinked) |-------------------------------| +// | More than one | | Fail login | +// | matching email | | | +// |-----------------------|----------------------|-------------------------------| func DiscordOAuthCallback(c *RequestContext) ResponseData { query := c.Req.URL.Query() - // Check the state + var destinationUrl string + + tx, err := c.Conn.Begin(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start transaction for Discord OAuth")) + } + defer tx.Rollback(c) + + // Check the state, figure out where we're going state := query.Get("state") - if state != c.CurrentSession.CSRFToken { - // CSRF'd!!!! + if c.CurrentUser == nil { + // Check the state against all our pending signins - if none is found, + // then CSRF'd!!!! (or the login just expired) + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + SELECT $columns + FROM pending_login + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + state, + ) + if err == db.NotFound { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up pending login")) + } + destinationUrl = pendingLogin.DestinationUrl - c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") - - res := c.Redirect("/", http.StatusSeeOther) - logoutUser(c, &res) - - return res + // Delete the pending login; we're done with it + _, err = tx.Exec(c, `DELETE FROM pending_login WHERE id = $1`, pendingLogin.ID) + if err != nil { + c.Logger.Warn().Str("id", pendingLogin.ID).Err(err).Msg("failed to delete pending login") + } + } else { + // Check the state against the current session - if it does not match, + // then CSRF'd!!!! + if state != c.CurrentSession.CSRFToken { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } + // The only way into OAuth when logged in is when linking your Discord + // account in settings. + destinationUrl = hmnurl.BuildUserSettings("discord") } - // Check for error values and redirect back to user settings + // Check for error values and redirect back to from whence they came if errCode := query.Get("error"); errCode != "" { if errCode == "access_denied" { - // This occurs when the user cancels. Just go back to the profile page. - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + // This occurs when the user cancels. Just go back so they can try again. + var dest string + if c.CurrentUser == nil { + // Send 'em back to the login page for another go, with the + // same destination + dest = hmnurl.BuildLoginPage(destinationUrl) + } else { + dest = hmnurl.BuildUserSettings("discord") + } + return c.Redirect(dest, http.StatusSeeOther) } else { return c.RejectRequest("Failed to authenticate with Discord.") } @@ -41,60 +117,198 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { // Do the actual token exchange code := query.Get("code") - res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) + authRes, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) } - expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) + expiry := time.Now().Add(time.Duration(authRes.ExpiresIn) * time.Second) - user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken) + user, err := discord.GetCurrentUserAsOAuth(c, authRes.AccessToken) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) } - // Add the role on Discord - err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) + hmnMember, err := discord.GetGuildMember(c, config.Config.Discord.GuildID, user.ID) if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) + if err == discord.NotFound { + // nothing, this is fine + } else { + c.Logger.Error().Err(err).Msg("failed to get HMN Discord member for Discord user") + } } - // Add the user to our database - _, err = c.Conn.Exec(c, + // Make the necessary updates in our database (see table above) + + // Determine which HMN user to associate this Discord login with. This + // may not turn anything up, in which case we need to make an account. + var hmnUser *models.User + if c.CurrentUser != nil { + hmnUser = c.CurrentUser + } else { + utils.Assert(user.Email, "didn't get an email from Discord! bad scopes?") + + userFromDiscordID, err := db.QueryOne[models.User](c, tx, + ` + SELECT $columns{hmn_user} + FROM + discord_user + JOIN hmn_user ON discord_user.hmn_user_id = hmn_user.id + WHERE userid = $1 + `, + user.ID, + ) + if err == nil { + hmnUser = userFromDiscordID + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user from Discord OAuth")) + } + + if hmnUser == nil { + usersFromDiscordEmail, err := db.Query[models.User](c, tx, + ` + SELECT $columns + FROM hmn_user + WHERE + LOWER(email) = LOWER($1) + `, + user.Email, + ) + if err == nil { + if len(usersFromDiscordEmail) > 1 { + // oh no why don't we have a unique constraint on emails + return c.RejectRequest("There are multiple Handmade Network accounts with this email address. Please sign into one of them separately.") + } else if len(usersFromDiscordEmail) == 1 { + hmnUser = usersFromDiscordEmail[0] + } + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user by email")) + } + } + } + + // Create a new HMN account if no existing account matches + if hmnUser == nil { + // Check if an HMN account already has this username. We don't link + // in this case because usernames can be changed and we don't want + // account takeovers. + usernameTaken, err := db.QueryOneScalar[bool](c, tx, + `SELECT COUNT(*) > 0 FROM hmn_user WHERE LOWER(username) = LOWER($1)`, + user.Username, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check if username was taken when logging in with Discord")) + } + if usernameTaken { + return c.RejectRequest(fmt.Sprintf("There is already a Handmade Network account with the username \"%s\".", user.Username)) + } + + var displayName string + if hmnMember != nil && hmnMember.Nick != nil { + displayName = *hmnMember.Nick + } + + var avatarHash *string + if hmnMember != nil && hmnMember.Avatar != nil { + avatarHash = hmnMember.Avatar + } else if user.Avatar != nil { + avatarHash = user.Avatar + } + + var avatarAssetID *uuid.UUID + if avatarHash != nil { + // Note! Not using the transaction here. Don't want to fail the login due to avatars. + if avatarAsset, err := saveDiscordAvatar(c, c.Conn, user.ID, *user.Avatar); err == nil { + avatarAssetID = &avatarAsset.ID + } else { + c.Logger.Warn().Err(err).Msg("failed to save Discord avatar") + } + } + + newHMNUser, err := db.QueryOne[models.User](c, tx, + ` + INSERT INTO hmn_user ( + username, name, email, password, avatar_asset_id, date_joined, registration_ip + ) VALUES ( + $1, $2, $3, '', $4, $5, $6 + ) + RETURNING $columns + `, + user.Username, displayName, strings.ToLower(user.Email), avatarAssetID, time.Now(), c.GetIP(), + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new HMN user for Discord login")) + } + hmnUser = newHMNUser + } + + // Add the Discord user data to our database + _, err = tx.Exec(c, ` - INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + INSERT INTO + discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (userid) DO UPDATE SET + username = EXCLUDED.username, + discriminator = EXCLUDED.discriminator, + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + avatar = EXCLUDED.avatar, + locale = EXCLUDED.locale, + expiry = EXCLUDED.expiry, + hmn_user_id = EXCLUDED.hmn_user_id `, user.Username, user.Discriminator, - res.AccessToken, - res.RefreshToken, + authRes.AccessToken, + authRes.RefreshToken, user.Avatar, user.Locale, user.ID, expiry, - c.CurrentUser.ID, + hmnUser.ID, ) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info")) } - if c.CurrentUser.Status == models.UserStatusConfirmed { - _, err = c.Conn.Exec(c, - ` - UPDATE hmn_user - SET status = $1 - WHERE id = $2 - `, - models.UserStatusApproved, - c.CurrentUser.ID, - ) + // Mark the HMN user as confirmed - Discord is good enough auth for us + _, err = tx.Exec(c, + ` + UPDATE hmn_user + SET status = $1 + WHERE id = $2 + `, + models.UserStatusApproved, + hmnUser.ID, + ) + if err != nil { + c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") + // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. + } + + err = tx.Commit(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save updates from Discord OAuth")) + } + + // Add the role on Discord + if hmnMember != nil { + err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) if err != nil { - c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") - // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. + c.Logger.Error().Err(err).Msg("failed to add member role") } } - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + res := c.Redirect(destinationUrl, http.StatusSeeOther) + err = loginUser(c, hmnUser, &res) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, err) + } + return res } func DiscordUnlink(c *RequestContext) ResponseData { @@ -188,3 +402,29 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { return c.Redirect(hmnurl.BuildUserProfile(c.CurrentUser.Username), http.StatusSeeOther) } + +func saveDiscordAvatar(ctx context.Context, conn db.ConnOrTx, userID, avatarHash string) (*models.Asset, error) { + const size = 256 + + filename := fmt.Sprintf("%s.png", avatarHash) + url := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s?size=%d", userID, filename, size) + + content, _, err := discord.DownloadDiscordResource(ctx, url) + if err != nil { + return nil, oops.New(err, "failed to download Discord avatar") + } + + asset, err := assets.Create(ctx, conn, assets.CreateInput{ + Content: content, + Filename: filename, + ContentType: "image/png", + + Width: size, + Height: size, + }) + if err != nil { + return nil, oops.New(err, "failed to save asset for Discord attachment") + } + + return asset, nil +} diff --git a/src/website/routes.go b/src/website/routes.go index 219da3d..ff3d336 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -75,6 +75,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.POST(hmnurl.RegexLoginAction, securityTimerMiddleware(time.Millisecond*100, Login)) hmnOnly.GET(hmnurl.RegexLogoutAction, Logout) hmnOnly.GET(hmnurl.RegexLoginPage, LoginPage) + hmnOnly.GET(hmnurl.RegexLoginWithDiscord, LoginWithDiscord) hmnOnly.GET(hmnurl.RegexRegister, RegisterNewUser) hmnOnly.POST(hmnurl.RegexRegister, securityTimerMiddleware(email.ExpectedEmailSendDuration, RegisterNewUserSubmit)) @@ -104,7 +105,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew)) hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit))) - hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback)) + hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, DiscordOAuthCallback) hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink))) hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog))) diff --git a/src/website/user.go b/src/website/user.go index d364606..d72e5ed 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -213,10 +213,11 @@ func UserSettings(c *RequestContext) ResponseData { AvatarMaxFileSize int DefaultAvatarUrl string - User templates.User - Email string // these fields are handled specially on templates.User - ShowEmail bool - LinksText string + User templates.User + Email string // these fields are handled specially on templates.User + ShowEmail bool + LinksText string + HasPassword bool SubmitUrl string ContactUrl string @@ -292,13 +293,14 @@ func UserSettings(c *RequestContext) ResponseData { Email: c.CurrentUser.Email, ShowEmail: c.CurrentUser.ShowEmail, LinksText: linksText, + HasPassword: c.CurrentUser.Password != "", SubmitUrl: hmnurl.BuildUserSettings(""), ContactUrl: hmnurl.BuildContactPage(), DiscordUser: tduser, DiscordNumUnsavedMessages: numUnsavedMessages, - DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken), + DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken, false), DiscordUnlinkUrl: hmnurl.BuildDiscordUnlink(), DiscordShowcaseBacklogUrl: hmnurl.BuildDiscordShowcaseBacklog(), }, c.Perf) @@ -424,7 +426,13 @@ func UserSettingsSave(c *RequestContext) ResponseData { // Update password oldPassword := form.Get("old_password") newPassword := form.Get("new_password") - if oldPassword != "" && newPassword != "" { + var doChangePassword bool + if c.CurrentUser.Password == "" { + doChangePassword = newPassword != "" + } else { + doChangePassword = oldPassword != "" && newPassword != "" + } + if doChangePassword { errorRes := updatePassword(c, tx, oldPassword, newPassword) if errorRes != nil { return *errorRes @@ -558,25 +566,27 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData { } func updatePassword(c *RequestContext, tx pgx.Tx, old, new string) *ResponseData { - oldHashedPassword, err := auth.ParsePasswordString(c.CurrentUser.Password) - if err != nil { - c.Logger.Warn().Err(err).Msg("failed to parse user's password string") - return nil - } + if c.CurrentUser.Password != "" { + oldHashedPassword, err := auth.ParsePasswordString(c.CurrentUser.Password) + if err != nil { + c.Logger.Warn().Err(err).Msg("failed to parse user's password string") + return nil + } - ok, err := auth.CheckPassword(old, oldHashedPassword) - if err != nil { - res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check user's password")) - return &res - } + ok, err := auth.CheckPassword(old, oldHashedPassword) + if err != nil { + res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check user's password")) + return &res + } - if !ok { - res := c.RejectRequest("The old password you provided was not correct.") - return &res + if !ok { + res := c.RejectRequest("The old password you provided was not correct.") + return &res + } } newHashedPassword := auth.HashPassword(new) - err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword) + err := auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword) if err != nil { res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password")) return &res diff --git a/src/website/website.go b/src/website/website.go index 6b30627..aaca3ef 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -43,7 +43,7 @@ var WebsiteCommand = &cobra.Command{ } backgroundJobsDone := jobs.Zip( - auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), + auth.PeriodicallyDeleteExpiredStuff(backgroundJobContext, conn), auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn), perfCollector.Job, discord.RunDiscordBot(backgroundJobContext, conn),