diff --git a/public/discord-login.svg b/public/discord-login.svg new file mode 100644 index 0000000..0e1c8fb --- /dev/null +++ b/public/discord-login.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/public/style.css b/public/style.css index e2f0530..574382b 100644 --- a/public/style.css +++ b/public/style.css @@ -4602,7 +4602,7 @@ code, .code { .pa2, .tab, figure, header .root-item > a, header .submenu > a { padding: 0.5rem; } -.pa3, header #login-popup { +.pa3 { padding: 1rem; } .pa4 { @@ -7422,7 +7422,7 @@ article code { color: #ccc; color: var(--theme-color-dimmest); } -.b--dimmest, .optionbar, blockquote, .post-content th, .post-content td, header #login-popup { +.b--dimmest, .optionbar, blockquote, .post-content th, .post-content td { border-color: #bbb; border-color: var(--dimmest-color); } @@ -8936,25 +8936,19 @@ header #login-popup { background-color: var(--login-popup-background); color: black; color: var(--fg-font-color); - border-width: 1px; - border-style: dashed; visibility: hidden; position: absolute; z-index: 12; - margin-top: 10px; - right: 0px; - top: 20px; - width: 290px; - max-height: 0px; overflow: hidden; - opacity: 0; - transition: all 0.2s; } + right: 0; + top: 100%; + width: 100%; } header #login-popup.open { - max-height: 170px; - opacity: 1; visibility: visible; } - header #login-popup label { - padding-right: 10px; } + @media screen and (min-width: 35em) { + header #login-popup { + top: 2.2rem; + width: 17rem; } } @font-face { font-family: icons; diff --git a/src/auth/session.go b/src/auth/session.go index 997704d..61c0648 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -24,7 +24,7 @@ const CSRFFieldName = "csrf_token" const sessionDuration = time.Hour * 24 * 14 -func makeSessionId() string { +func MakeSessionId() string { idBytes := make([]byte, 40) _, err := io.ReadFull(rand.Reader, idBytes) if err != nil { @@ -47,7 +47,16 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM session WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, + ` + SELECT $columns + FROM session + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + id, + ) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -61,7 +70,7 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses func CreateSession(ctx context.Context, conn *pgxpool.Pool, username string) (*models.Session, error) { session := models.Session{ - ID: makeSessionId(), + ID: MakeSessionId(), Username: username, ExpiresAt: time.Now().Add(sessionDuration), CSRFToken: makeCSRFToken(), @@ -134,7 +143,16 @@ func DeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) (int64, erro return tag.RowsAffected(), nil } -func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) jobs.Job { +func DeleteExpiredPendingLogins(ctx context.Context, conn *pgxpool.Pool) (int64, error) { + tag, err := conn.Exec(ctx, "DELETE FROM pending_login WHERE expires_at <= CURRENT_TIMESTAMP") + if err != nil { + return 0, oops.New(err, "failed to delete expired pending logins") + } + + return tag.RowsAffected(), nil +} + +func PeriodicallyDeleteExpiredStuff(ctx context.Context, conn *pgxpool.Pool) jobs.Job { job := jobs.New() go func() { defer job.Done() @@ -145,6 +163,7 @@ func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) case <-t.C: err := func() (err error) { defer utils.RecoverPanicAsError(&err) + n, err := DeleteExpiredSessions(ctx, conn) if err == nil { if n > 0 { @@ -153,10 +172,20 @@ func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) } else { logging.Error().Err(err).Msg("Failed to delete expired sessions") } + + n, err = DeleteExpiredPendingLogins(ctx, conn) + if err == nil { + if n > 0 { + logging.Info().Int64("num deleted pending logins", n).Msg("Deleted expired pending logins") + } + } else { + logging.Error().Err(err).Msg("Failed to delete expired pending logins") + } + return nil }() if err != nil { - logging.Error().Err(err).Msg("Panicked in PeriodicallyDeleteExpiredSessions") + logging.Error().Err(err).Msg("Panicked in PeriodicallyDeleteExpiredStuff") } case <-ctx.Done(): return diff --git a/src/discord/message_handling.go b/src/discord/message_handling.go index cfd9df1..6cbc2da 100644 --- a/src/discord/message_handling.go +++ b/src/discord/message_handling.go @@ -431,7 +431,7 @@ var discordDownloadClient = &http.Client{ type DiscordResourceBadStatusCode error -func downloadDiscordResource(ctx context.Context, url string) ([]byte, string, error) { +func DownloadDiscordResource(ctx context.Context, url string) ([]byte, string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, "", oops.New(err, "failed to make Discord download request") @@ -491,7 +491,7 @@ func saveAttachment( height = *attachment.Height } - content, _, err := downloadDiscordResource(ctx, attachment.Url) + content, _, err := DownloadDiscordResource(ctx, attachment.Url) if err != nil { return nil, oops.New(err, "failed to download Discord attachment") } @@ -561,7 +561,7 @@ func saveEmbed( } maybeSaveImageish := func(i EmbedImageish, contentTypeCheck func(string) bool) (*uuid.UUID, error) { - content, contentType, err := downloadDiscordResource(ctx, *i.Url) + content, contentType, err := DownloadDiscordResource(ctx, *i.Url) if err != nil { var statusError DiscordResourceBadStatusCode if errors.As(err, &statusError) { @@ -838,7 +838,8 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in } // TODO(asaf): I believe this will also match https://example.com?hello=1&whatever=5 -// Probably need to add word boundaries. +// +// Probably need to add word boundaries. var REDiscordTag = regexp.MustCompile(`&([a-zA-Z0-9]+(-[a-zA-Z0-9]+)*)`) func getDiscordTags(content string) []string { diff --git a/src/discord/payloads.go b/src/discord/payloads.go index d477d22..c103894 100644 --- a/src/discord/payloads.go +++ b/src/discord/payloads.go @@ -346,6 +346,7 @@ type User struct { Avatar *string `json:"avatar"` IsBot bool `json:"bot"` Locale string `json:"locale"` + Email string `json:"email"` } func UserFromMap(m interface{}, k string) *User { @@ -387,8 +388,9 @@ func GuildFromMap(m interface{}, k string) *Guild { // https://discord.com/developers/docs/resources/guild#guild-member-object type GuildMember struct { - User *User `json:"user"` - Nick *string `json:"nick"` + User *User `json:"user"` + Nick *string `json:"nick"` + Avatar *string `json:"avatar"` // more fields not yet handled here } @@ -409,8 +411,9 @@ func GuildMemberFromMap(m interface{}, k string) *GuildMember { } gm := &GuildMember{ - User: UserFromMap(m, "user"), - Nick: maybeStringP(mmap, "nick"), + User: UserFromMap(m, "user"), + Nick: maybeStringP(mmap, "nick"), + Avatar: maybeStringP(mmap, "avatar"), } return gm diff --git a/src/discord/ratelimiting.go b/src/discord/ratelimiting.go index 8d74e1d..1f9e44b 100644 --- a/src/discord/ratelimiting.go +++ b/src/discord/ratelimiting.go @@ -110,7 +110,7 @@ func createLimiter(headers rateLimitHeaders, routeName string) { buckets.Store(routeName, headers.Bucket) ilimiter, loaded := rateLimiters.LoadOrStore(headers.Bucket, &restRateLimiter{ - requests: make(chan struct{}, 100), // presumably this is big enough to handle bursts + requests: make(chan struct{}, 200), // presumably this is big enough to handle bursts refills: make(chan rateLimiterRefill), }) if !loaded { @@ -124,7 +124,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("remaining", headers.Remaining). + Msg("rate limiting channel was too small; you should increase the default capacity") break prefillloop } } @@ -158,7 +160,9 @@ func createLimiter(headers rateLimitHeaders, routeName string) { select { case limiter.requests <- struct{}{}: default: - log.Warn().Msg("rate limiting channel was too small; you should increase the default capacity") + log.Warn(). + Int("maxRequests", refill.maxRequests). + Msg("rate limiting channel was too small; you should increase the default capacity") break refillloop } } diff --git a/src/discord/rest.go b/src/discord/rest.go index 9124858..417be53 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -652,11 +652,17 @@ func EditOriginalInteractionResponse(ctx context.Context, interactionToken strin return &msg, nil } -func GetAuthorizeUrl(state string) string { +func GetAuthorizeUrl(state string, includeEmail bool) string { + scope := "identify" + if includeEmail { + scope = "identify email" + } + params := make(url.Values) params.Set("response_type", "code") params.Set("client_id", config.Config.Discord.OAuthClientID) - params.Set("scope", "identify") + params.Set("scope", scope) + params.Set("prompt", "none") // immediately redirect back to HMN if already authorized params.Set("state", state) params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback()) return fmt.Sprintf("%s?%s", buildUrl("/oauth2/authorize"), params.Encode()) diff --git a/src/hmnurl/urls.go b/src/hmnurl/urls.go index f175a7d..a1fae05 100644 --- a/src/hmnurl/urls.go +++ b/src/hmnurl/urls.go @@ -121,6 +121,13 @@ func BuildLoginPage(redirectTo string) string { return Url("/login", []Q{{Name: "redirect", Value: redirectTo}}) } +var RegexLoginWithDiscord = regexp.MustCompile("^/login-with-discord$") + +func BuildLoginWithDiscord(redirectTo string) string { + defer CatchPanic() + return Url("/login-with-discord", []Q{{Name: "redirect", Value: redirectTo}}) +} + var RegexLogoutAction = regexp.MustCompile("^/logout$") func BuildLogoutAction(redir string) string { diff --git a/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go new file mode 100644 index 0000000..ed59719 --- /dev/null +++ b/src/migration/migrations/2023-05-04T024712Z_AddPendingSignups.go @@ -0,0 +1,53 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "github.com/jackc/pgx/v5" +) + +func init() { + registerMigration(AddPendingSignups{}) +} + +type AddPendingSignups struct{} + +func (m AddPendingSignups) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2023, 5, 4, 2, 47, 12, 0, time.UTC)) +} + +func (m AddPendingSignups) Name() string { + return "AddPendingSignups" +} + +func (m AddPendingSignups) Description() string { + return "Adds the pending login table" +} + +func (m AddPendingSignups) Up(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, + ` + CREATE TABLE pending_login ( + id VARCHAR(40) NOT NULL PRIMARY KEY, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL, + destination_url VARCHAR(999) NOT NULL + ) + `, + ) + if err != nil { + return err + } + + return nil +} + +func (m AddPendingSignups) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, `DROP TABLE pending_login`) + if err != nil { + return err + } + + return nil +} diff --git a/src/migration/seed.go b/src/migration/seed.go index 0b6aff7..4d89e6c 100644 --- a/src/migration/seed.go +++ b/src/migration/seed.go @@ -189,7 +189,7 @@ func seedUser(ctx context.Context, conn db.ConnOrTx, input models.User) *models. $6, $7, $8, $9, TRUE, $10, - '2017-01-01T00:00:00Z', '192.168.2.1', null + '2017-01-01T00:00:00Z', '192.168.2.1', NULL ) RETURNING $columns `, diff --git a/src/models/user.go b/src/models/user.go index 5caaf6d..d838324 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -71,3 +71,9 @@ func (u *User) CanSeeUnpublishedEducationContent() bool { func (u *User) CanAuthorEducation() bool { return u.IsStaff || u.EducationRole == EduRoleAuthor } + +type PendingLogin struct { + ID string `db:"id"` + ExpiresAt time.Time `db:"expires_at"` + DestinationUrl string `db:"destination_url"` +} diff --git a/src/rawdata/scss/_header.scss b/src/rawdata/scss/_header.scss index 3a403e0..1460410 100644 --- a/src/rawdata/scss/_header.scss +++ b/src/rawdata/scss/_header.scss @@ -130,33 +130,21 @@ header { @include usevar(background-color, login-popup-background); @include usevar(color, fg-font-color); - @extend .pa3; - - border-width: 1px; - border-style: dashed; - @extend .b--dimmest; - visibility: hidden; position: absolute; z-index: 12; - margin-top: 10px; - right: 0px; - top: 20px; - width: 290px; - max-height: 0px; overflow: hidden; - opacity: 0; - - transition: all 0.2s; + right: 0; + top: 100%; + width: 100%; &.open { - max-height: 170px; - opacity: 1; visibility: visible; } - - label { - padding-right:10px; + + @media #{$breakpoint-not-small} { + top: 2.2rem; + width: 17rem; } } } diff --git a/src/templates/src/auth_login.html b/src/templates/src/auth_login.html index 666611f..4a40738 100644 --- a/src/templates/src/auth_login.html +++ b/src/templates/src/auth_login.html @@ -37,6 +37,18 @@
Need an account? Sign up.
+ +
+
Third-party login
+
+ + Log in with Discord + +
+
diff --git a/src/templates/src/include/header.html b/src/templates/src/include/header.html index 60f4532..751dfaa 100644 --- a/src/templates/src/include/header.html +++ b/src/templates/src/include/header.html @@ -13,14 +13,25 @@ {{ else }} Register Log in -
-
+
+ - Forgot your password? + Forgot your password? -
- +
+ +
+
+
Third-party login
+
+ + Log in with Discord + +
@@ -166,7 +177,6 @@ const loginLink = document.getElementById("login-link"); if (loginPopup !== null) { - loginLink.onclick = (e) => { e.preventDefault(); loginPopup.classList.toggle("open"); diff --git a/src/templates/src/user_settings.html b/src/templates/src/user_settings.html index 68ef0fc..7b82eca 100644 --- a/src/templates/src/user_settings.html +++ b/src/templates/src/user_settings.html @@ -81,18 +81,20 @@
-
-
Old password:
-
- + {{ if .HasPassword }} +
+
Old password:
+
+ +
-
+ {{ end }}
New password:
- Your password must be 8 or more characters, and must differ from your username and current password. + Your password must be 8 or more characters, and must differ from your username{{ if .HasPassword }} and current password{{ end }}. Other than that, please follow best practices.
diff --git a/src/templates/types.go b/src/templates/types.go index acbf187..d72acf0 100644 --- a/src/templates/types.go +++ b/src/templates/types.go @@ -38,13 +38,14 @@ func (bd *BaseData) AddImmediateNotice(class, content string) { } type Header struct { - AdminUrl string - UserProfileUrl string - UserSettingsUrl string - LoginActionUrl string - LogoutActionUrl string - ForgotPasswordUrl string - RegisterUrl string + AdminUrl string + UserProfileUrl string + UserSettingsUrl string + LoginActionUrl string + LogoutActionUrl string + ForgotPasswordUrl string + RegisterUrl string + LoginWithDiscordUrl string HMNHomepageUrl string ProjectIndexUrl string diff --git a/src/twitch/twitch.go b/src/twitch/twitch.go index 90b6adc..938ada1 100644 --- a/src/twitch/twitch.go +++ b/src/twitch/twitch.go @@ -931,7 +931,7 @@ func updateStreamHistory(ctx context.Context, dbConn db.ConnOrTx, status *models ` INSERT INTO twitch_stream_history (stream_id, twitch_id, twitch_login, started_at, stream_ended, ended_at, end_approximated, title, category_id, tags, discord_needs_update) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (stream_id) DO UPDATE SET stream_ended = EXCLUDED.stream_ended, ended_at = EXCLUDED.ended_at, diff --git a/src/utils/utils.go b/src/utils/utils.go index 6df2c0d..ea2fd3d 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -104,3 +104,19 @@ func SleepContext(ctx context.Context, d time.Duration) error { return nil } } + +// Panics if the provided value is falsy (so, zero). This works for booleans +// but also normal values, through the magic of generics. +func Assert[T comparable](value T, msg ...any) { + var zero T + if value == zero { + finalMsg := "" + for i, arg := range msg { + if i > 0 { + finalMsg += " " + } + finalMsg += fmt.Sprintf("%v", arg) + } + panic(finalMsg) + } +} diff --git a/src/website/auth.go b/src/website/auth.go index 6a308ee..c8ab4f3 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -11,6 +11,7 @@ import ( "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/email" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/logging" @@ -23,9 +24,10 @@ var UsernameRegex = regexp.MustCompile(`^[0-9a-zA-Z][\w-]{2,29}$`) type LoginPageData struct { templates.BaseData - RedirectUrl string - RegisterUrl string - ForgotPasswordUrl string + RedirectUrl string + RegisterUrl string + ForgotPasswordUrl string + LoginWithDiscordUrl string } func LoginPage(c *RequestContext) ResponseData { @@ -37,10 +39,11 @@ func LoginPage(c *RequestContext) ResponseData { var res ResponseData res.MustWriteTemplate("auth_login.html", LoginPageData{ - BaseData: getBaseData(c, "Log in", nil), - RedirectUrl: redirect, - RegisterUrl: hmnurl.BuildRegister(redirect), - ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + BaseData: getBaseData(c, "Log in", nil), + RedirectUrl: redirect, + RegisterUrl: hmnurl.BuildRegister(redirect), + ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + LoginWithDiscordUrl: hmnurl.BuildLoginWithDiscord(redirect), }, c.Perf) return res } @@ -120,6 +123,28 @@ func Login(c *RequestContext) ResponseData { return res } +func LoginWithDiscord(c *RequestContext) ResponseData { + destinationUrl := c.URL().Query().Get("redirect") + if c.CurrentUser != nil { + return c.Redirect(destinationUrl, http.StatusSeeOther) + } + + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + INSERT INTO pending_login (id, expires_at, destination_url) + VALUES ($1, $2, $3) + RETURNING $columns + `, + auth.MakeSessionId(), time.Now().Add(time.Minute*10), destinationUrl, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save pending login")) + } + + discordAuthUrl := discord.GetAuthorizeUrl(pendingLogin.ID, true) + return c.Redirect(discordAuthUrl, http.StatusSeeOther) +} + func Logout(c *RequestContext) ResponseData { redirect := c.Req.URL.Query().Get("redirect") @@ -460,7 +485,8 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali } // NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email, -// not to changing your password through the user settings page. +// +// not to changing your password through the user settings page. func RequestPasswordReset(c *RequestContext) ResponseData { if c.CurrentUser != nil { return c.Redirect(hmnurl.BuildHomepage(), http.StatusSeeOther) @@ -717,7 +743,11 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro defer c.Perf.EndBlock() hashed, err := auth.ParsePasswordString(user.Password) if err != nil { - return false, oops.New(err, "failed to parse password string") + if user.Password == "" { + return false, nil + } else { + return false, oops.New(err, "failed to parse password string") + } } passwordsMatch, err := auth.CheckPassword(password, hashed) diff --git a/src/website/base_data.go b/src/website/base_data.go index d6a95c6..8f391c5 100644 --- a/src/website/base_data.go +++ b/src/website/base_data.go @@ -63,12 +63,13 @@ func getBaseData(c *RequestContext, title string, breadcrumbs []templates.Breadc IsProjectPage: !project.IsHMN(), Header: templates.Header{ - AdminUrl: hmnurl.BuildAdminApprovalQueue(), // TODO(asaf): Replace with general-purpose admin page - UserSettingsUrl: hmnurl.BuildUserSettings(""), - LoginActionUrl: hmnurl.BuildLoginAction(c.FullUrl()), - LogoutActionUrl: hmnurl.BuildLogoutAction(c.FullUrl()), - ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), - RegisterUrl: hmnurl.BuildRegister(""), + AdminUrl: hmnurl.BuildAdminApprovalQueue(), // TODO(asaf): Replace with general-purpose admin page + UserSettingsUrl: hmnurl.BuildUserSettings(""), + LoginActionUrl: hmnurl.BuildLoginAction(c.FullUrl()), + LogoutActionUrl: hmnurl.BuildLogoutAction(c.FullUrl()), + ForgotPasswordUrl: hmnurl.BuildRequestPasswordReset(), + RegisterUrl: hmnurl.BuildRegister(""), + LoginWithDiscordUrl: hmnurl.BuildLoginWithDiscord(c.FullUrl()), HMNHomepageUrl: hmnurl.BuildHomepage(), ProjectIndexUrl: hmnurl.BuildProjectIndex(1), diff --git a/src/website/discord.go b/src/website/discord.go index b9db312..52345e5 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -1,39 +1,115 @@ package website import ( + "context" "errors" + "fmt" "net/http" + "strings" "time" + "git.handmade.network/hmn/hmn/src/assets" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/discord" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/utils" + "github.com/google/uuid" ) +// This callback handles Discord account linking whether the user is signed in +// or not. In all cases, the end state is that the user is signed into a +// Handmade Network account with a linked Discord account. HMN accounts will be +// created as necessary. +// +// If we initiate OAuth while logged in, we will use the current session's CSRF +// token as the OAuth state. Otherwise, we will generate a new entry in the +// pending_login table with an equivalently random token and use that token for +// the state. +// +// Considerations: +// +// | | Already signed in | Not signed in | +// |-----------------------|----------------------|-------------------------------| +// | No matching info | Link Discord account | Create HMN account | +// |-----------------------| to current HMN |-------------------------------| +// | Matching Discord user | account (stealing is | Log into HMN account and link | +// |-----------------------| ok, but make sure | Discord user to it. (Double- | +// | One matching email | any other accounts | check Discord link settings.) | +// |-----------------------| are unlinked) |-------------------------------| +// | More than one | | Fail login | +// | matching email | | | +// |-----------------------|----------------------|-------------------------------| func DiscordOAuthCallback(c *RequestContext) ResponseData { query := c.Req.URL.Query() - // Check the state + var destinationUrl string + + tx, err := c.Conn.Begin(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start transaction for Discord OAuth")) + } + defer tx.Rollback(c) + + // Check the state, figure out where we're going state := query.Get("state") - if state != c.CurrentSession.CSRFToken { - // CSRF'd!!!! + if c.CurrentUser == nil { + // Check the state against all our pending signins - if none is found, + // then CSRF'd!!!! (or the login just expired) + pendingLogin, err := db.QueryOne[models.PendingLogin](c, c.Conn, + ` + SELECT $columns + FROM pending_login + WHERE + id = $1 + AND expires_at > CURRENT_TIMESTAMP + `, + state, + ) + if err == db.NotFound { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up pending login")) + } + destinationUrl = pendingLogin.DestinationUrl - c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") - - res := c.Redirect("/", http.StatusSeeOther) - logoutUser(c, &res) - - return res + // Delete the pending login; we're done with it + _, err = tx.Exec(c, `DELETE FROM pending_login WHERE id = $1`, pendingLogin.ID) + if err != nil { + c.Logger.Warn().Str("id", pendingLogin.ID).Err(err).Msg("failed to delete pending login") + } + } else { + // Check the state against the current session - if it does not match, + // then CSRF'd!!!! + if state != c.CurrentSession.CSRFToken { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?") + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + return res + } + // The only way into OAuth when logged in is when linking your Discord + // account in settings. + destinationUrl = hmnurl.BuildUserSettings("discord") } - // Check for error values and redirect back to user settings + // Check for error values and redirect back to from whence they came if errCode := query.Get("error"); errCode != "" { if errCode == "access_denied" { - // This occurs when the user cancels. Just go back to the profile page. - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + // This occurs when the user cancels. Just go back so they can try again. + var dest string + if c.CurrentUser == nil { + // Send 'em back to the login page for another go, with the + // same destination + dest = hmnurl.BuildLoginPage(destinationUrl) + } else { + dest = hmnurl.BuildUserSettings("discord") + } + return c.Redirect(dest, http.StatusSeeOther) } else { return c.RejectRequest("Failed to authenticate with Discord.") } @@ -41,60 +117,198 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { // Do the actual token exchange code := query.Get("code") - res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) + authRes, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) } - expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) + expiry := time.Now().Add(time.Duration(authRes.ExpiresIn) * time.Second) - user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken) + user, err := discord.GetCurrentUserAsOAuth(c, authRes.AccessToken) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) } - // Add the role on Discord - err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) + hmnMember, err := discord.GetGuildMember(c, config.Config.Discord.GuildID, user.ID) if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) + if err == discord.NotFound { + // nothing, this is fine + } else { + c.Logger.Error().Err(err).Msg("failed to get HMN Discord member for Discord user") + } } - // Add the user to our database - _, err = c.Conn.Exec(c, + // Make the necessary updates in our database (see table above) + + // Determine which HMN user to associate this Discord login with. This + // may not turn anything up, in which case we need to make an account. + var hmnUser *models.User + if c.CurrentUser != nil { + hmnUser = c.CurrentUser + } else { + utils.Assert(user.Email, "didn't get an email from Discord! bad scopes?") + + userFromDiscordID, err := db.QueryOne[models.User](c, tx, + ` + SELECT $columns{hmn_user} + FROM + discord_user + JOIN hmn_user ON discord_user.hmn_user_id = hmn_user.id + WHERE userid = $1 + `, + user.ID, + ) + if err == nil { + hmnUser = userFromDiscordID + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user from Discord OAuth")) + } + + if hmnUser == nil { + usersFromDiscordEmail, err := db.Query[models.User](c, tx, + ` + SELECT $columns + FROM hmn_user + WHERE + LOWER(email) = LOWER($1) + `, + user.Email, + ) + if err == nil { + if len(usersFromDiscordEmail) > 1 { + // oh no why don't we have a unique constraint on emails + return c.RejectRequest("There are multiple Handmade Network accounts with this email address. Please sign into one of them separately.") + } else if len(usersFromDiscordEmail) == 1 { + hmnUser = usersFromDiscordEmail[0] + } + } else if err == db.NotFound { + // no problem + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up existing HMN user by email")) + } + } + } + + // Create a new HMN account if no existing account matches + if hmnUser == nil { + // Check if an HMN account already has this username. We don't link + // in this case because usernames can be changed and we don't want + // account takeovers. + usernameTaken, err := db.QueryOneScalar[bool](c, tx, + `SELECT COUNT(*) > 0 FROM hmn_user WHERE LOWER(username) = LOWER($1)`, + user.Username, + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check if username was taken when logging in with Discord")) + } + if usernameTaken { + return c.RejectRequest(fmt.Sprintf("There is already a Handmade Network account with the username \"%s\".", user.Username)) + } + + var displayName string + if hmnMember != nil && hmnMember.Nick != nil { + displayName = *hmnMember.Nick + } + + var avatarHash *string + if hmnMember != nil && hmnMember.Avatar != nil { + avatarHash = hmnMember.Avatar + } else if user.Avatar != nil { + avatarHash = user.Avatar + } + + var avatarAssetID *uuid.UUID + if avatarHash != nil { + // Note! Not using the transaction here. Don't want to fail the login due to avatars. + if avatarAsset, err := saveDiscordAvatar(c, c.Conn, user.ID, *user.Avatar); err == nil { + avatarAssetID = &avatarAsset.ID + } else { + c.Logger.Warn().Err(err).Msg("failed to save Discord avatar") + } + } + + newHMNUser, err := db.QueryOne[models.User](c, tx, + ` + INSERT INTO hmn_user ( + username, name, email, password, avatar_asset_id, date_joined, registration_ip + ) VALUES ( + $1, $2, $3, '', $4, $5, $6 + ) + RETURNING $columns + `, + user.Username, displayName, strings.ToLower(user.Email), avatarAssetID, time.Now(), c.GetIP(), + ) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new HMN user for Discord login")) + } + hmnUser = newHMNUser + } + + // Add the Discord user data to our database + _, err = tx.Exec(c, ` - INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + INSERT INTO + discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (userid) DO UPDATE SET + username = EXCLUDED.username, + discriminator = EXCLUDED.discriminator, + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + avatar = EXCLUDED.avatar, + locale = EXCLUDED.locale, + expiry = EXCLUDED.expiry, + hmn_user_id = EXCLUDED.hmn_user_id `, user.Username, user.Discriminator, - res.AccessToken, - res.RefreshToken, + authRes.AccessToken, + authRes.RefreshToken, user.Avatar, user.Locale, user.ID, expiry, - c.CurrentUser.ID, + hmnUser.ID, ) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info")) } - if c.CurrentUser.Status == models.UserStatusConfirmed { - _, err = c.Conn.Exec(c, - ` - UPDATE hmn_user - SET status = $1 - WHERE id = $2 - `, - models.UserStatusApproved, - c.CurrentUser.ID, - ) + // Mark the HMN user as confirmed - Discord is good enough auth for us + _, err = tx.Exec(c, + ` + UPDATE hmn_user + SET status = $1 + WHERE id = $2 + `, + models.UserStatusApproved, + hmnUser.ID, + ) + if err != nil { + c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") + // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. + } + + err = tx.Commit(c) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save updates from Discord OAuth")) + } + + // Add the role on Discord + if hmnMember != nil { + err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) if err != nil { - c.Logger.Error().Err(err).Msg("failed to set user status to approved after linking discord account") - // NOTE(asaf): It's not worth failing the request over this, so we're not returning an error to the user. + c.Logger.Error().Err(err).Msg("failed to add member role") } } - return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) + res := c.Redirect(destinationUrl, http.StatusSeeOther) + err = loginUser(c, hmnUser, &res) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, err) + } + return res } func DiscordUnlink(c *RequestContext) ResponseData { @@ -188,3 +402,29 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { return c.Redirect(hmnurl.BuildUserProfile(c.CurrentUser.Username), http.StatusSeeOther) } + +func saveDiscordAvatar(ctx context.Context, conn db.ConnOrTx, userID, avatarHash string) (*models.Asset, error) { + const size = 256 + + filename := fmt.Sprintf("%s.png", avatarHash) + url := fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s?size=%d", userID, filename, size) + + content, _, err := discord.DownloadDiscordResource(ctx, url) + if err != nil { + return nil, oops.New(err, "failed to download Discord avatar") + } + + asset, err := assets.Create(ctx, conn, assets.CreateInput{ + Content: content, + Filename: filename, + ContentType: "image/png", + + Width: size, + Height: size, + }) + if err != nil { + return nil, oops.New(err, "failed to save asset for Discord attachment") + } + + return asset, nil +} diff --git a/src/website/routes.go b/src/website/routes.go index 219da3d..ff3d336 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -75,6 +75,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.POST(hmnurl.RegexLoginAction, securityTimerMiddleware(time.Millisecond*100, Login)) hmnOnly.GET(hmnurl.RegexLogoutAction, Logout) hmnOnly.GET(hmnurl.RegexLoginPage, LoginPage) + hmnOnly.GET(hmnurl.RegexLoginWithDiscord, LoginWithDiscord) hmnOnly.GET(hmnurl.RegexRegister, RegisterNewUser) hmnOnly.POST(hmnurl.RegexRegister, securityTimerMiddleware(email.ExpectedEmailSendDuration, RegisterNewUserSubmit)) @@ -104,7 +105,7 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew)) hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit))) - hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback)) + hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, DiscordOAuthCallback) hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink))) hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog))) diff --git a/src/website/user.go b/src/website/user.go index d364606..d72e5ed 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -213,10 +213,11 @@ func UserSettings(c *RequestContext) ResponseData { AvatarMaxFileSize int DefaultAvatarUrl string - User templates.User - Email string // these fields are handled specially on templates.User - ShowEmail bool - LinksText string + User templates.User + Email string // these fields are handled specially on templates.User + ShowEmail bool + LinksText string + HasPassword bool SubmitUrl string ContactUrl string @@ -292,13 +293,14 @@ func UserSettings(c *RequestContext) ResponseData { Email: c.CurrentUser.Email, ShowEmail: c.CurrentUser.ShowEmail, LinksText: linksText, + HasPassword: c.CurrentUser.Password != "", SubmitUrl: hmnurl.BuildUserSettings(""), ContactUrl: hmnurl.BuildContactPage(), DiscordUser: tduser, DiscordNumUnsavedMessages: numUnsavedMessages, - DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken), + DiscordAuthorizeUrl: discord.GetAuthorizeUrl(c.CurrentSession.CSRFToken, false), DiscordUnlinkUrl: hmnurl.BuildDiscordUnlink(), DiscordShowcaseBacklogUrl: hmnurl.BuildDiscordShowcaseBacklog(), }, c.Perf) @@ -424,7 +426,13 @@ func UserSettingsSave(c *RequestContext) ResponseData { // Update password oldPassword := form.Get("old_password") newPassword := form.Get("new_password") - if oldPassword != "" && newPassword != "" { + var doChangePassword bool + if c.CurrentUser.Password == "" { + doChangePassword = newPassword != "" + } else { + doChangePassword = oldPassword != "" && newPassword != "" + } + if doChangePassword { errorRes := updatePassword(c, tx, oldPassword, newPassword) if errorRes != nil { return *errorRes @@ -558,25 +566,27 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData { } func updatePassword(c *RequestContext, tx pgx.Tx, old, new string) *ResponseData { - oldHashedPassword, err := auth.ParsePasswordString(c.CurrentUser.Password) - if err != nil { - c.Logger.Warn().Err(err).Msg("failed to parse user's password string") - return nil - } + if c.CurrentUser.Password != "" { + oldHashedPassword, err := auth.ParsePasswordString(c.CurrentUser.Password) + if err != nil { + c.Logger.Warn().Err(err).Msg("failed to parse user's password string") + return nil + } - ok, err := auth.CheckPassword(old, oldHashedPassword) - if err != nil { - res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check user's password")) - return &res - } + ok, err := auth.CheckPassword(old, oldHashedPassword) + if err != nil { + res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check user's password")) + return &res + } - if !ok { - res := c.RejectRequest("The old password you provided was not correct.") - return &res + if !ok { + res := c.RejectRequest("The old password you provided was not correct.") + return &res + } } newHashedPassword := auth.HashPassword(new) - err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword) + err := auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword) if err != nil { res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password")) return &res diff --git a/src/website/website.go b/src/website/website.go index 6b30627..aaca3ef 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -43,7 +43,7 @@ var WebsiteCommand = &cobra.Command{ } backgroundJobsDone := jobs.Zip( - auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), + auth.PeriodicallyDeleteExpiredStuff(backgroundJobContext, conn), auth.PeriodicallyDeleteInactiveUsers(backgroundJobContext, conn), perfCollector.Job, discord.RunDiscordBot(backgroundJobContext, conn),