From eb6cd6e89cd0237a87b7e14ce5601c97dbbcebbc Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Thu, 4 May 2023 21:03:40 -0500 Subject: [PATCH] Get login with Discord working --- src/auth/session.go | 15 +- src/discord/payloads.go | 1 + src/discord/ratelimiting.go | 10 +- src/discord/rest.go | 10 +- src/hmnurl/urls.go | 7 + .../2023-05-04T024712Z_AddPendingSignups.go | 53 ++++ ...-05-05T003208Z_NoDefaultDiscordShowcase.go | 51 ++++ src/migration/seed.go | 6 +- src/models/user.go | 6 + src/twitch/twitch.go | 2 +- src/utils/utils.go | 16 + src/website/auth.go | 43 ++- src/website/discord.go | 274 ++++++++++++++---- src/website/routes.go | 3 +- src/website/user.go | 2 +- 15 files changed, 431 insertions(+), 68 deletions(-) create mode 100644 src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go create mode 100644 src/migration/migrations/2023-05-05T003208Z_NoDefaultDiscordShowcase.go diff --git a/src/auth/session.go b/src/auth/session.go index 997704d..eaace17 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -24,7 +24,7 @@ const CSRFFieldName = "csrf_token" const sessionDuration = time.Hour * 24 * 14 -func makeSessionId() string { +func MakeSessionId() string { idBytes := make([]byte, 40) _, err := io.ReadFull(rand.Reader, idBytes) if err != nil { @@ -47,7 +47,16 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM session WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, + ` + SELECT $columns + FROM session + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + id, + ) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -61,7 +70,7 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses func CreateSession(ctx context.Context, conn *pgxpool.Pool, username string) (*models.Session, error) { session := models.Session{ - ID: makeSessionId(), + ID: MakeSessionId(), Username: username, ExpiresAt: time.Now().Add(sessionDuration), CSRFToken: makeCSRFToken(), diff --git a/src/discord/payloads.go b/src/discord/payloads.go index d477d22..698bc9e 100644 --- a/src/discord/payloads.go +++ b/src/discord/payloads.go @@ -346,6 +346,7 @@ type User struct { Avatar *string `json:"avatar"` IsBot bool `json:"bot"` Locale string `json:"locale"` + Email string `json:"email"` } func UserFromMap(m interface{}, k string) *User { diff --git a/src/discord/ratelimiting.go b/src/discord/ratelimiting.go index 8d74e1d..1f9e44b 100644 --- a/src/discord/ratelimiting.go +++ b/src/discord/ratelimiting.go @@ -110,7 +110,7 @@ func createLimiter(headers rateLimitHeaders, routeName string) { buckets.Store(routeName, headers.Bucket) ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{ - requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts + requests: make(chan struct{}, 200), // presumably this is big enough to handle bursts refills: make(chan rateLimiterRefill), }) if !loaded { @@ -124,7 +124,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("remaining", headers.Remaining). + Msg("rate limiting channel was too small; you should increase the default capacity") break prefillloop } } @@ -158,7 +160,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("maxRequests", refill.maxRequests). + Msg("rate limiting channel was too small; you should increase the default capacity") break refillloop } } diff --git a/src/discord/rest.go b/src/discord/rest.go index 9124858..417be53 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -652,11 +652,17 @@ func EditOriginalInteractionResponse(ctx context.Context, interactionToken strin return &msg, nil } -func GetAuthorizeUrl(state string) string { +func GetAuthorizeUrl(state string, includeEmail bool) string { + scope := "identify" + if includeEmail { + scope = "identify email" + } + params := make(url.Values) params.Set("response_type", "code") params.Set("client_id", config.Config.Discord.OAuthClientID) - params.Set("scope", "identify") + params.Set("scope", scope) + params.Set("prompt", "none") // immediately redirect back to HMN if already authorized params.Set("state", state) params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback()) return fmt.Sprintf("%s?%s", buildUrl("/oauth2/authorize"), params.Encode()) diff --git a/src/hmnurl/urls.go b/src/hmnurl/urls.go index f175a7d..a1fae05 100644 --- a/src/hmnurl/urls.go +++ b/src/hmnurl/urls.go @@ -121,6 +121,13 @@ func BuildLoginPage(redirectTo string) string { return Url("/login", []Q{{Name: "redirect", Value: redirectTo}}) } +var RegexLoginWithDiscord = regexp.MustCompile("^/login-with-discord$") + +func BuildLoginWithDiscord(redirectTo string) string { + defer CatchPanic() + return Url("/login-with-discord", []Q{{Name: "redirect", Value: redirectTo}}) +} + var RegexLogoutAction = regexp.MustCompile("^/logout$") func BuildLogoutAction(redir string) string { diff --git a/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go new file mode 100644 index 0000000..ed59719 --- /dev/null +++ b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go @@ -0,0 +1,53 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "github.com/jackc/pgx/v5" +) + +func init() { + registerMigration(AddPendingSignups{}) +} + +type AddPendingSignups struct{} + +func (m AddPendingSignups) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2023, 5, 4, 2, 47, 12, 0, time.UTC)) +} + +func (m AddPendingSignups) Name() string { + return "AddPendingSignups" +} + +func (m AddPendingSignups) Description() string { + return "Adds the pending login table" +} + +func (m AddPendingSignups) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, + ` + CREATE TABLE pending_login ( + id VARCHAR(40) NOT NULL PRIMARY KEY, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + destination_url VARCHAR(999) NOT NULL + ) + `, + ) + if err != nil { + return err + } + + return nil +} + +func (m AddPendingSignups) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, `DROP TABLE pending_login`) + if err != nil { + return err + } + + return nil +} diff --git a/src/migration/migrations/2023-05-05T003208Z_NoDefaultDiscordShowcase.go b/src/migration/migrations/2023-05-05T003208Z_NoDefaultDiscordShowcase.go new file mode 100644 index 0000000..470b95d --- /dev/null +++ b/src/migration/migrations/2023-05-05T003208Z_NoDefaultDiscordShowcase.go @@ -0,0 +1,51 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "github.com/jackc/pgx/v5" +) + +func init() { + registerMigration(NoDefaultDiscordShowcase{}) +} + +type NoDefaultDiscordShowcase struct{} + +func (m NoDefaultDiscordShowcase) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2023, 5, 5, 0, 32, 8, 0, time.UTC)) +} + +func (m NoDefaultDiscordShowcase) Name() string { + return "NoDefaultDiscordShowcase" +} + +func (m NoDefaultDiscordShowcase) Description() string { + return "Make the Discord showcase setting default to false" +} + +func (m NoDefaultDiscordShowcase) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + ALTER TABLE hmn_user + ALTER COLUMN discord_save_showcase SET DEFAULT FALSE + `) + if err != nil { + return err + } + + return nil +} + +func (m NoDefaultDiscordShowcase) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + ALTER TABLE hmn_user + ALTER COLUMN discord_save_showcase SET DEFAULT TRUE + `) + if err != nil { + return err + } + + return nil +} diff --git a/src/migration/seed.go b/src/migration/seed.go index 0b6aff7..536c592 100644 --- a/src/migration/seed.go +++ b/src/migration/seed.go @@ -180,7 +180,8 @@ func seedUser(ctx context.Context, conn db.ConnOrTx, input models.User) *models. name, bio, blurb, signature, darktheme, showemail, - date_joined, registration_ip, avatar_asset_id + date_joined, registration_ip, avatar_asset_id, + discord_save_showcase ) VALUES ( $1, $2, $3, @@ -189,7 +190,8 @@ func seedUser(ctx context.Context, conn db.ConnOrTx, input models.User) *models. $6, $7, $8, $9, TRUE, $10, - '2017-01-01T00:00:00Z', '192.168.2.1', null + '2017-01-01T00:00:00Z', '192.168.2.1', NULL, + TRUE ) RETURNING $columns `, diff --git a/src/models/user.go b/src/models/user.go index 5caaf6d..d838324 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -71,3 +71,9 @@ func (u *User) CanSeeUnpublishedEducationContent() bool { func (u *User) CanAuthorEducation() bool { return u.IsStaff || u.EducationRole == EduRoleAuthor } + +type PendingLogin struct { + ID string `db:"id"` + ExpiresAt time.Time `db:"expires_at"` + DestinationUrl string `db:"destination_url"` +} diff --git a/src/twitch/twitch.go b/src/twitch/twitch.go index 90b6adc..938ada1 100644 --- a/src/twitch/twitch.go +++ b/src/twitch/twitch.go @@ -931,7 +931,7 @@ func updateStreamHistory(ctx context.Context, dbConn db.ConnOrTx, status *models ` INSERT INTO twitch_stream_history (stream_id, twitch_id, twitch_login, started_at, stream_ended, ended_at, end_approximated, title, category_id, tags, discord_needs_update) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (stream_id) DO UPDATE SET stream_ended = EXCLUDED.stream_ended, ended_at = EXCLUDED.ended_at, diff --git a/src/utils/utils.go b/src/utils/utils.go index 6df2c0d..ea2fd3d 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -104,3 +104,19 @@ func SleepContext(ctx context.Context, d time.Duration) error { return nil } } + +// Panics if the provided value is falsy (so, zero). This works for booleans +// but also normal values, through the magic of generics. +func Assert[T comparable](value T, msg ...any) { + var zero T + if value == zero { + finalMsg := "" + for i, arg := range msg { + if i > 0 { + finalMsg += " " + } + finalMsg += fmt.Sprintf("%v", arg) + } + panic(finalMsg) + } +} diff --git a/src/website/auth.go b/src/website/auth.go index 6a308ee..1cbae09 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -11,6 +11,7 @@ import ( "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/email" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/logging" @@ -120,6 +121,30 @@ func Login(c *RequestContext) ResponseData { return res } +func LoginWithDiscord(c *RequestContext) ResponseData { + destinationUrl := c.URL().Query().Get("redirectTo") + if c.CurrentUser != nil { + return c.Redirect(destinationUrl, http.StatusSeeOther) + } + + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + INSERT INTO pending_login (id, expires_at, destination_url) + VALUES ($1, $2, $3) + RETURNING $columns + `, + auth.MakeSessionId(), time.Now().Add(time.Minute*10), destinationUrl, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save pending login")) + } + + // TODO: EXPIRE THESE + + discordAuthUrl := discord.GetAuthorizeUrl(pendingLogin.ID, true) + return c.Redirect(discordAuthUrl, http.StatusSeeOther) +} + func Logout(c *RequestContext) ResponseData { redirect := c.Req.URL.Query().Get("redirect") @@ -244,8 +269,13 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { var newUserId int err = tx.QueryRow(c, ` - INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO hmn_user ( + username, email, password, date_joined, name, registration_ip, + discord_save_showcase + ) VALUES ( + $1, $2, $3, $4, $5, $6, + TRUE + ) RETURNING id `, username, emailAddress, hashed.String(), now, displayName, c.GetIP(), @@ -460,7 +490,8 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali } // NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email, -// not to changing your password through the user settings page. +// +// not to changing your password through the user settings page. func RequestPasswordReset(c *RequestContext) ResponseData { if c.CurrentUser != nil { return c.Redirect(hmnurl.BuildHomepage(), http.StatusSeeOther) @@ -717,7 +748,11 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro defer c.Perf.EndBlock() hashed, err := auth.ParsePasswordString(user.Password) if err != nil { - return false, oops.New(err, "failed to parse password string") + if user.Password == "" { + return false, nil + } else { + return false, oops.New(err, "failed to parse password string") + } } passwordsMatch, err := auth.CheckPassword(password, hashed) diff --git a/src/website/discord.go b/src/website/discord.go index b9db312..b814878 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -2,7 +2,9 @@ package website import ( "errors" + "fmt" "net/http" + "strings" "time" "git.handmade.network/hmn/hmn/src/config" @@ -11,29 +13,94 @@ import ( "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/utils" ) +// This callback handles Discord account linking whether the user is signed in +// or not. In all cases, the end state is that the user is signed into a +// Handmade Network account with a linked Discord account. HMN accounts will be +// created as necessary. +// +// If we initiate OAuth while logged in, we will use the current session's CSRF +// token as the OAuth state. Otherwise, we will generate a new entry in the +// pending_login table with an equivalently random token and use that token for +// the state. +// +// Considerations: +// +// | | Already signed in | Not signed in | +// |-----------------------|----------------------|-------------------------------| +// | No matching info | Link Discord account | Create HMN account | +// |-----------------------| to current HMN |-------------------------------| +// | Matching Discord user | account (stealing is | Log into HMN account and link | +// |-----------------------| ok, but make sure | Discord user to it. (Double- | +// | One matching email | any other accounts | check Discord link settings.) | +// |-----------------------| are unlinked) |-------------------------------| +// | More than one | | Fail login | +// | matching email | | | +// |-----------------------|----------------------|-------------------------------| func DiscordOAuthCallback(c *RequestContext) ResponseData { query := c.Req.URL.Query() - // Check the state + var destinationUrl string + + tx, err := c.Conn.Begin(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start transaction for Discord OAuth")) + } + defer tx.Rollback(c) + + // Check the state, figure out where we're going state := query.Get("state") - if state != c.CurrentSession.CSRFToken { - // CSRF'd!!!! - - c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") - - res := c.Redirect("/", http.StatusSeeOther) - logoutUser(c, &res) - - return res + if c.CurrentUser == nil { + // Check the state against all our pending signins - if none is found, + // then CSRF'd!!!! (or the login just expired) + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + SELECT $columns + FROM pending_login + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + state, + ) + if err == db.NotFound { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up pending login")) + } + destinationUrl = pendingLogin.DestinationUrl + } else { + // Check the state against the current session - if it does not match, + // then CSRF'd!!!! + if state != c.CurrentSession.CSRFToken { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } + // The only way into OAuth when logged in is when linking your Discord + // account in settings. + destinationUrl = hmnurl.BuildUserSettings("discord") } - // Check for error values and redirect back to user settings + // Check for error values and redirect back to from whence they came if errCode := query.Get("error"); errCode != "" { if errCode == "access_denied" { - // This occurs when the user cancels. Just go back to the profile page. - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + // This occurs when the user cancels. Just go back so they can try again. + var dest string + if c.CurrentUser == nil { + // Send 'em back to the login page for another go, with the + // same destination + dest = hmnurl.BuildLoginPage(destinationUrl) + } else { + dest = hmnurl.BuildUserSettings("discord") + } + return c.Redirect(dest, http.StatusSeeOther) } else { return c.RejectRequest("Failed to authenticate with Discord.") } @@ -41,60 +108,165 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { // Do the actual token exchange code := query.Get("code") - res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) + authRes, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) } - expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) + expiry := time.Now().Add(time.Duration(authRes.ExpiresIn) * time.Second) - user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken) + user, err := discord.GetCurrentUserAsOAuth(c, authRes.AccessToken) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) } + // Make the necessary updates in our database (see table above) + + // Determine which HMN user to associate this Discord login with. This + // may not turn anything up, in which case we need to make an account. + var hmnUser *models.User + if c.CurrentUser != nil { + hmnUser = c.CurrentUser + } else { + utils.Assert(user.Email, "didn't get an email from Discord! bad scopes?") + + userFromDiscordID, err := db.QueryOne[models.User](c, tx, + ` + SELECT $columns{hmn_user} + FROM + discord_user + JOIN hmn_user ON discord_user.hmn_user_id = hmn_user.id + WHERE userid = $1 + `, + user.ID, + ) + if err == nil { + hmnUser = userFromDiscordID + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user from Discord OAuth")) + } + + if hmnUser == nil { + usersFromDiscordEmail, err := db.Query[models.User](c, tx, + ` + SELECT $columns + FROM hmn_user + WHERE + LOWER(email) = LOWER($1) + `, + user.Email, + ) + if err == nil { + if len(usersFromDiscordEmail) > 1 { + // oh no why don't we have a unique constraint on emails + return c.RejectRequest("There are multiple Handmade Network accounts with this email address. Please sign into one of them separately.") + } else if len(usersFromDiscordEmail) == 1 { + hmnUser = usersFromDiscordEmail[0] + } + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user by email")) + } + } + } + + // Create a new HMN account if no existing account matches + if hmnUser == nil { + // Check if an HMN account already has this username. We don't link + // in this case because usernames can be changed and we don't want + // account takeovers. + usernameTaken, err := db.QueryOneScalar[bool](c, tx, + `SELECT COUNT(*) > 0 FROM hmn_user WHERE LOWER(username) = LOWER($1)`, + user.Username, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check if username was taken when logging in with Discord")) + } + if usernameTaken { + return c.RejectRequest(fmt.Sprintf("There is already a Handmade Network account with the username \"%s\".", user.Username)) + } + + newHMNUser, err := db.QueryOne[models.User](c, tx, + ` + INSERT INTO hmn_user ( + username, email, password, date_joined, registration_ip + ) VALUES ( + $1, $2, '', $3, $4 + ) + RETURNING $columns + `, + user.Username, strings.ToLower(user.Email), time.Now(), c.GetIP(), + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new HMN user for Discord login")) + } + hmnUser = newHMNUser + } + + // Add the Discord user data to our database + _, err = tx.Exec(c, + ` + INSERT INTO + discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (userid) DO UPDATE SET + username = EXCLUDED.username, + discriminator = EXCLUDED.discriminator, + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + avatar = EXCLUDED.avatar, + locale = EXCLUDED.locale, + expiry = EXCLUDED.expiry, + hmn_user_id = EXCLUDED.hmn_user_id + `, + user.Username, + user.Discriminator, + authRes.AccessToken, + authRes.RefreshToken, + user.Avatar, + user.Locale, + user.ID, + expiry, + hmnUser.ID, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info")) + } + + // Mark the HMN user as confirmed - Discord is good enough auth for us + _, err = tx.Exec(c, + ` + UPDATE hmn_user + SET status = $1 + WHERE id = $2 + `, + models.UserStatusApproved, + hmnUser.ID, + ) + if err != nil { + c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") + // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. + } + + err = tx.Commit(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save updates from Discord OAuth")) + } + // Add the role on Discord err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) } - // Add the user to our database - _, err = c.Conn.Exec(c, - ` - INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - `, - user.Username, - user.Discriminator, - res.AccessToken, - res.RefreshToken, - user.Avatar, - user.Locale, - user.ID, - expiry, - c.CurrentUser.ID, - ) + res := c.Redirect(destinationUrl, http.StatusSeeOther) + err = loginUser(c, hmnUser, &res) if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info")) + return c.ErrorResponse(http.StatusInternalServerError, err) } - - if c.CurrentUser.Status == models.UserStatusConfirmed { - _, err = c.Conn.Exec(c, - ` - UPDATE hmn_user - SET status = $1 - WHERE id = $2 - `, - models.UserStatusApproved, - c.CurrentUser.ID, - ) - if err != nil { - c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") - // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. - } - } - - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + return res } func DiscordUnlink(c *RequestContext) ResponseData { diff --git a/src/website/routes.go b/src/website/routes.go index 219da3d..ff3d336 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -75,6 +75,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.POST(hmnurl.RegexLoginAction, securityTimerMiddleware(time.Millisecond*100, Login)) hmnOnly.GET(hmnurl.RegexLogoutAction, Logout) hmnOnly.GET(hmnurl.RegexLoginPage, LoginPage) + hmnOnly.GET(hmnurl.RegexLoginWithDiscord, LoginWithDiscord) hmnOnly.GET(hmnurl.RegexRegister, RegisterNewUser) hmnOnly.POST(hmnurl.RegexRegister, securityTimerMiddleware(email.ExpectedEmailSendDuration, RegisterNewUserSubmit)) @@ -104,7 +105,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew)) hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit))) - hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback)) + hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, DiscordOAuthCallback) hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink))) hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog))) diff --git a/src/website/user.go b/src/website/user.go index d364606..f2c76f3 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -298,7 +298,7 @@ func UserSettings(c *RequestContext) ResponseData { DiscordUser: tduser, DiscordNumUnsavedMessages: numUnsavedMessages, - DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken), + DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken, false), DiscordUnlinkUrl: hmnurl.BuildDiscordUnlink(), DiscordShowcaseBacklogUrl: hmnurl.BuildDiscordShowcaseBacklog(), }, c.Perf)