diff --git a/src/hmnurl/hmnurl.go b/src/hmnurl/hmnurl.go index c5103b7..a502944 100644 --- a/src/hmnurl/hmnurl.go +++ b/src/hmnurl/hmnurl.go @@ -84,6 +84,8 @@ func UrlWithFragment(path string, query []Q, fragment string) string { return HMNProjectContext.UrlWithFragment(path, query, fragment) } +// Takes a project URL and rewrites it using the current URL context. This can be used +// to convert a personal project URL to official and vice versa. func (c *UrlContext) RewriteProjectUrl(u *url.URL) string { // we need to strip anything matching the personal project regex to get the base path match := RegexPersonalProject.FindString(u.Path) diff --git a/src/website/admin.go b/src/website/admin.go index 43cfca4..b315fbb 100644 --- a/src/website/admin.go +++ b/src/website/admin.go @@ -69,7 +69,7 @@ func AdminAtomFeed(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() @@ -134,7 +134,7 @@ type unapprovedUserData struct { func AdminApprovalQueue(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() @@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { userIds = append(userIds, u.User.ID) } - userLinks, err := db.Query[models.Link](c.Context(), c.Conn, + userLinks, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM @@ -253,13 +253,13 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { userIdStr := c.Req.Form.Get("user_id") userId, err := strconv.Atoi(userIdStr) if err != nil { - return RejectRequest(c, "User id can't be parsed") + return c.RejectRequest("User id can't be parsed") } - user, err := hmndata.FetchUser(c.Context(), c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{}) + user, err := hmndata.FetchUser(c, c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{}) if err != nil { if errors.Is(err, db.NotFound) { - return RejectRequest(c, "User not found") + return c.RejectRequest("User not found") } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) } @@ -267,7 +267,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { whatHappened := "" if action == ApprovalQueueActionApprove { - _, err := c.Conn.Exec(c.Context(), + _, err := c.Conn.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -281,7 +281,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { } whatHappened = fmt.Sprintf("%s approved successfully", user.Username) } else if action == ApprovalQueueActionSpammer { - _, err := c.Conn.Exec(c.Context(), + _, err := c.Conn.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -293,15 +293,15 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to set user to banned")) } - err = auth.DeleteSessionForUser(c.Context(), c.Conn, user.Username) + err = auth.DeleteSessionForUser(c, c.Conn, user.Username) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user")) } - err = deleteAllPostsForUser(c.Context(), c.Conn, user.ID) + err = deleteAllPostsForUser(c, c.Conn, user.ID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's posts")) } - err = deleteAllProjectsForUser(c.Context(), c.Conn, user.ID) + err = deleteAllProjectsForUser(c, c.Conn, user.ID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's projects")) } @@ -324,7 +324,7 @@ type UnapprovedPost struct { } func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { - posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn, + posts, err := db.Query[UnapprovedPost](c, c.Conn, ` SELECT $columns FROM @@ -355,7 +355,7 @@ type UnapprovedProject struct { } func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { - ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn, + ownerIDs, err := db.QueryScalar[int](c, c.Conn, ` SELECT id FROM @@ -369,7 +369,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { return nil, oops.New(err, "failed to fetch unapproved users") } - projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: ownerIDs, IncludeHidden: true, }) @@ -382,7 +382,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { projectIDs = append(projectIDs, p.Project.ID) } - projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, + projectLinks, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM diff --git a/src/website/api.go b/src/website/api.go index 4a8710e..1bbc29e 100644 --- a/src/website/api.go +++ b/src/website/api.go @@ -19,7 +19,7 @@ func APICheckUsername(c *RequestContext) ResponseData { requestedUsername := usernameArgs[0] found = true c.Perf.StartBlock("SQL", "Fetch user") - user, err := db.QueryOne[models.User](c.Context(), c.Conn, + user, err := db.QueryOne[models.User](c, c.Conn, ` SELECT $columns FROM @@ -45,7 +45,7 @@ func APICheckUsername(c *RequestContext) ResponseData { var res ResponseData res.Header().Set("Content-Type", "application/json") - AddCORSHeaders(c, &res) + addCORSHeaders(c, &res) if found { res.Write([]byte(fmt.Sprintf(`{ "found": true, "canonical": "%s" }`, canonicalUsername))) } else { diff --git a/src/website/assets.go b/src/website/assets.go index 541cd60..7bedc23 100644 --- a/src/website/assets.go +++ b/src/website/assets.go @@ -85,7 +85,7 @@ func AssetUpload(c *RequestContext) ResponseData { } } - asset, err := assets.Create(c.Context(), c.Conn, assets.CreateInput{ + asset, err := assets.Create(c, c.Conn, assets.CreateInput{ Content: data, Filename: originalFilename, ContentType: mimeType, diff --git a/src/website/auth.go b/src/website/auth.go index 6ca8e09..2412d51 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -28,7 +28,7 @@ type LoginPageData struct { func LoginPage(c *RequestContext) ResponseData { if c.CurrentUser != nil { - return RejectRequest(c, "You are already logged in.") + return c.RejectRequest("You are already logged in.") } var res ResponseData @@ -75,7 +75,7 @@ func Login(c *RequestContext) ResponseData { return res } - user, err := db.QueryOne[models.User](c.Context(), c.Conn, + user, err := db.QueryOne[models.User](c, c.Conn, ` SELECT $columns FROM hmn_user @@ -102,7 +102,7 @@ func Login(c *RequestContext) ResponseData { } if user.Status == models.UserStatusInactive { - return RejectRequest(c, "You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.") + return c.RejectRequest("You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.") } res := c.Redirect(redirect, http.StatusSeeOther) @@ -136,7 +136,7 @@ func RegisterNewUser(c *RequestContext) ResponseData { func RegisterNewUserSubmit(c *RequestContext) ResponseData { if c.CurrentUser != nil { - return RejectRequest(c, "Can't register new user. You are already logged in") + return c.RejectRequest("Can't register new user. You are already logged in") } c.Req.ParseForm() @@ -146,16 +146,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { password := c.Req.Form.Get("password") password2 := c.Req.Form.Get("password2") if !UsernameRegex.Match([]byte(username)) { - return RejectRequest(c, "Invalid username") + return c.RejectRequest("Invalid username") } if !email.IsEmail(emailAddress) { - return RejectRequest(c, "Invalid email address") + return c.RejectRequest("Invalid email address") } if len(password) < 8 { - return RejectRequest(c, "Password too short") + return c.RejectRequest("Password too short") } if password != password2 { - return RejectRequest(c, "Password confirmation doesn't match password") + return c.RejectRequest("Password confirmation doesn't match password") } c.Perf.StartBlock("SQL", "Check blacklist") @@ -169,7 +169,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Check for existing usernames and emails") userAlreadyExists := true - _, err := db.QueryOneScalar[int](c.Context(), c.Conn, + _, err := db.QueryOneScalar[int](c, c.Conn, ` SELECT id FROM hmn_user @@ -186,11 +186,11 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { } if userAlreadyExists { - return RejectRequest(c, fmt.Sprintf("Username (%s) already exists.", username)) + return c.RejectRequest(fmt.Sprintf("Username (%s) already exists.", username)) } emailAlreadyExists := true - _, err = db.QueryOneScalar[int](c.Context(), c.Conn, + _, err = db.QueryOneScalar[int](c, c.Conn, ` SELECT id FROM hmn_user @@ -215,16 +215,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { hashed := auth.HashPassword(password) c.Perf.StartBlock("SQL", "Create user and one time token") - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) now := time.Now() var newUserId int - err = tx.QueryRow(c.Context(), + err = tx.QueryRow(c, ` INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip) VALUES ($1, $2, $3, $4, $5, $6) @@ -237,7 +237,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { } ott := models.GenerateToken() - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id) VALUES($1, $2, $3, $4, $5) @@ -263,7 +263,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Commit user") - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit user to the db")) } @@ -302,7 +302,7 @@ func EmailConfirmation(c *RequestContext) ResponseData { username, hasUsername := c.PathParams["username"] if !hasUsername { - return RejectRequest(c, "Bad validation url") + return c.RejectRequest("Bad validation url") } token := "" @@ -319,7 +319,7 @@ func EmailConfirmation(c *RequestContext) ResponseData { } if !hasToken { - return RejectRequest(c, "Bad validation url") + return c.RejectRequest("Bad validation url") } validationResult := validateUsernameAndToken(c, username, token, models.TokenTypeRegistration) @@ -366,13 +366,13 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Updating user status and deleting token") - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -385,7 +385,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status")) } - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM one_time_token WHERE id = $1 `, @@ -395,7 +395,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete one time token")) } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit transaction")) } @@ -413,7 +413,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { // NOTE(asaf): Only call this when validationResult.Match is false. func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, validationResult validateUserAndTokenResult) ResponseData { if validationResult.User == nil { - return RejectRequest(c, "You haven't validated your email in time and your user was deleted. You may try registering again with the same username.") + return c.RejectRequest("You haven't validated your email in time and your user was deleted. You may try registering again with the same username.") } if validationResult.OneTimeToken == nil { @@ -422,7 +422,7 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali return c.Redirect(hmnurl.BuildLoginPage(""), http.StatusSeeOther) } - return RejectRequest(c, "Bad token. If you are having problems registering or logging in, please contact the staff.") + return c.RejectRequest("Bad token. If you are having problems registering or logging in, please contact the staff.") } // NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email, @@ -446,14 +446,14 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { emailAddress := strings.TrimSpace(c.Req.Form.Get("email")) if username == "" && emailAddress == "" { - return RejectRequest(c, "You must provide a username and an email address.") + return c.RejectRequest("You must provide a username and an email address.") } c.Perf.StartBlock("SQL", "Fetching user") type userQuery struct { User models.User `db:"hmn_user"` } - user, err := db.QueryOne[models.User](c.Context(), c.Conn, + user, err := db.QueryOne[models.User](c, c.Conn, ` SELECT $columns FROM hmn_user @@ -473,7 +473,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if user != nil { c.Perf.StartBlock("SQL", "Fetching existing token") - resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, + resetToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn, ` SELECT $columns FROM one_time_token @@ -495,7 +495,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if resetToken != nil { if resetToken.Expires.Before(now.Add(time.Minute * 30)) { // NOTE(asaf): Expired or about to expire c.Perf.StartBlock("SQL", "Deleting expired token") - _, err = c.Conn.Exec(c.Context(), + _, err = c.Conn.Exec(c, ` DELETE FROM one_time_token WHERE id = $1 @@ -512,7 +512,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if resetToken == nil { c.Perf.StartBlock("SQL", "Creating new token") - newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, + newToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn, ` INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id) VALUES ($1, $2, $3, $4, $5) @@ -567,12 +567,12 @@ func DoPasswordReset(c *RequestContext) ResponseData { token, hasToken := c.PathParams["token"] if !hasToken || !hasUsername { - return RejectRequest(c, "Bad validation url.") + return c.RejectRequest("Bad validation url.") } validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset) if !validationResult.Match { - return RejectRequest(c, "Bad validation url.") + return c.RejectRequest("Bad validation url.") } var res ResponseData @@ -601,30 +601,30 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset) if !validationResult.Match { - return RejectRequest(c, "Bad validation url.") + return c.RejectRequest("Bad validation url.") } if c.CurrentUser != nil && c.CurrentUser.ID != validationResult.User.ID { - return RejectRequest(c, fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username)) + return c.RejectRequest(fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username)) } if len(password) < 8 { - return RejectRequest(c, "Password too short") + return c.RejectRequest("Password too short") } if password != password2 { - return RejectRequest(c, "Password confirmation doesn't match password") + return c.RejectRequest("Password confirmation doesn't match password") } hashed := auth.HashPassword(password) c.Perf.StartBlock("SQL", "Update user's password and delete reset token") - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - tag, err := tx.Exec(c.Context(), + tag, err := tx.Exec(c, ` UPDATE hmn_user SET password = $1 @@ -638,7 +638,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { } if validationResult.User.Status == models.UserStatusInactive { - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -652,7 +652,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { } } - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM one_time_token WHERE id = $1 @@ -663,7 +663,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete onetimetoken")) } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit password reset to the db")) } @@ -698,7 +698,7 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro // re-hash and save the user's password if necessary if hashed.IsOutdated() { newHashed := auth.HashPassword(password) - err := auth.UpdatePassword(c.Context(), c.Conn, user.Username, newHashed) + err := auth.UpdatePassword(c, c.Conn, user.Username, newHashed) if err != nil { c.Logger.Error().Err(err).Msg("failed to update user's password") } @@ -711,15 +711,15 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) error { c.Perf.StartBlock("SQL", "Setting last login and creating session") defer c.Perf.EndBlock() - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return oops.New(err, "failed to start db transaction") } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) now := time.Now() - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE hmn_user SET last_login = $1 @@ -732,12 +732,12 @@ func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) return oops.New(err, "failed to update last_login for user") } - session, err := auth.CreateSession(c.Context(), c.Conn, user.Username) + session, err := auth.CreateSession(c, c.Conn, user.Username) if err != nil { return oops.New(err, "failed to create session") } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return oops.New(err, "failed to commit transaction") } @@ -749,7 +749,7 @@ func logoutUser(c *RequestContext, res *ResponseData) { sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) if err == nil { // clear the session from the db immediately, no expiration - err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value) + err := auth.DeleteSession(c, c.Conn, sessionCookie.Value) if err != nil { logging.Error().Err(err).Msg("failed to delete session on logout") } @@ -772,7 +772,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, User models.User `db:"hmn_user"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"` } - data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, + data, err := db.QueryOne[userAndTokenQuery](c, c.Conn, ` SELECT $columns FROM hmn_user diff --git a/src/website/blogs.go b/src/website/blogs.go index 7278734..9b18a39 100644 --- a/src/website/blogs.go +++ b/src/website/blogs.go @@ -37,7 +37,7 @@ func BlogIndex(c *RequestContext) ResponseData { const postsPerPage = 20 - numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -51,7 +51,7 @@ func BlogIndex(c *RequestContext) ResponseData { return c.Redirect(c.UrlContext.BuildBlog(page), http.StatusSeeOther) } - threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, Limit: postsPerPage, @@ -78,7 +78,7 @@ func BlogIndex(c *RequestContext) ResponseData { canCreate := false if c.CurrentProject.HasBlog() && c.CurrentUser != nil { isProjectOwner := false - owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID) + owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project owners")) } @@ -128,7 +128,7 @@ func BlogThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread, posts, err := hmndata.FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{ + thread, posts, err := hmndata.FetchThreadPosts(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -155,7 +155,7 @@ func BlogThread(c *RequestContext) ResponseData { // Update thread last read info if c.CurrentUser != nil { c.Perf.StartBlock("SQL", "Update TLRI") - _, err := c.Conn.Exec(c.Context(), + _, err := c.Conn.Exec(c, ` INSERT INTO thread_last_read_info (thread_id, user_id, lastread) VALUES ($1, $2, $3) @@ -196,7 +196,7 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ + thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -227,11 +227,11 @@ func BlogNewThread(c *RequestContext) ResponseData { } func BlogNewThreadSubmit(c *RequestContext) ResponseData { - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) err = c.Req.ParseForm() if err != nil { @@ -240,15 +240,15 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData { title := c.Req.Form.Get("title") unparsed := c.Req.Form.Get("body") if title == "" { - return RejectRequest(c, "You must provide a title for your post.") + return c.RejectRequest("You must provide a title for your post.") } if unparsed == "" { - return RejectRequest(c, "You must provide a body for your post.") + return c.RejectRequest("You must provide a body for your post.") } // Create thread var threadId int - err = tx.QueryRow(c.Context(), + err = tx.QueryRow(c, ` INSERT INTO thread (title, type, project_id, first_id, last_id) VALUES ($1, $2, $3, $4, $5) @@ -265,9 +265,9 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData { } // Create everything else - hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) + hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new blog post")) } @@ -282,11 +282,11 @@ func BlogPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -326,17 +326,17 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -351,16 +351,16 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData { unparsed := c.Req.Form.Get("body") editReason := c.Req.Form.Get("editreason") if title != "" && post.Thread.FirstID != post.Post.ID { - return RejectRequest(c, "You can only edit the title by editing the first post.") + return c.RejectRequest("You can only edit the title by editing the first post.") } if unparsed == "" { - return RejectRequest(c, "You must provide a post body.") + return c.RejectRequest("You must provide a post body.") } - hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) + hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) if title != "" { - _, err := tx.Exec(c.Context(), + _, err := tx.Exec(c, ` UPDATE thread SET title = $1 WHERE id = $2 `, @@ -372,7 +372,7 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData { } } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post")) } @@ -387,7 +387,7 @@ func BlogPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -421,11 +421,11 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) err = c.Req.ParseForm() if err != nil { @@ -433,12 +433,12 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData { } unparsed := c.Req.Form.Get("body") if unparsed == "" { - return RejectRequest(c, "Your reply cannot be empty.") + return c.RejectRequest("Your reply cannot be empty.") } - newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host) + newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to blog post")) } @@ -453,11 +453,11 @@ func BlogPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -503,19 +503,19 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID) + threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post")) } @@ -523,7 +523,7 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData { if threadDeleted { return c.Redirect(c.UrlContext.BuildHomepage(), http.StatusSeeOther) } else { - thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ + thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, }) @@ -560,7 +560,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.ThreadID = threadId c.Perf.StartBlock("SQL", "Verify that the thread exists") - threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, + threadExists, err := db.QueryOneScalar[bool](c, c.Conn, ` SELECT COUNT(*) > 0 FROM thread @@ -588,7 +588,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.PostID = postId c.Perf.StartBlock("SQL", "Verify that the post exists") - postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, + postExists, err := db.QueryOneScalar[bool](c, c.Conn, ` SELECT COUNT(*) > 0 FROM post diff --git a/src/website/common.go b/src/website/common.go new file mode 100644 index 0000000..b37914f --- /dev/null +++ b/src/website/common.go @@ -0,0 +1,138 @@ +package website + +import ( + "errors" + "net/http" + "net/url" + "strings" + + "git.handmade.network/hmn/hmn/src/auth" + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/hmndata" + "git.handmade.network/hmn/hmn/src/hmnurl" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/models" + "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/templates" +) + +func loadCommonData(h Handler) Handler { + return func(c *RequestContext) ResponseData { + c.Perf.StartBlock("MIDDLEWARE", "Load common website data") + { + // get user + { + sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) + if err == nil { + user, session, err := getCurrentUserAndSession(c, sessionCookie.Value) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user")) + } + + c.CurrentUser = user + c.CurrentSession = session + } + // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. + } + + // get current official project (HMN or otherwise, by subdomain) + { + hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost()) + slug := strings.TrimRight(hostPrefix, ".") + var owners []*models.User + + if len(slug) > 0 { + dbProject, err := hmndata.FetchProjectBySlug(c, c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{ + Lifecycles: models.AllProjectLifecycles, + IncludeHidden: true, + }) + if err == nil { + c.CurrentProject = &dbProject.Project + c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme) + owners = dbProject.Owners + } else { + if errors.Is(err, db.NotFound) { + // do nothing, this is fine + } else { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) + } + } + } + + if c.CurrentProject == nil { + dbProject, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{ + Lifecycles: models.AllProjectLifecycles, + IncludeHidden: true, + }) + if err != nil { + panic(oops.New(err, "failed to fetch HMN project")) + } + c.CurrentProject = &dbProject.Project + c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme) + } + + if c.CurrentProject == nil { + panic("failed to load project data") + } + + c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners) + + c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject) + } + + c.Theme = "light" + if c.CurrentUser != nil && c.CurrentUser.DarkTheme { + c.Theme = "dark" + } + } + c.Perf.EndBlock() + + return h(c) + } +} + +// Given a session id, fetches user data from the database. Will return nil if +// the user cannot be found, and will only return an error if it's serious. +func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) { + session, err := auth.GetSession(c, c.Conn, sessionId) + if err != nil { + if errors.Is(err, auth.ErrNoSession) { + return nil, nil, nil + } else { + return nil, nil, oops.New(err, "failed to get current session") + } + } + + user, err := hmndata.FetchUserByUsername(c, c.Conn, nil, session.Username, hmndata.UsersQuery{ + AnyStatus: true, + }) + if err != nil { + if errors.Is(err, db.NotFound) { + logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found") + return nil, nil, nil // user was deleted or something + } else { + return nil, nil, oops.New(err, "failed to get user for session") + } + } + + return user, session, nil +} + +func addCORSHeaders(c *RequestContext, res *ResponseData) { + parsed, err := url.Parse(config.Config.BaseUrl) + if err != nil { + c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers") + return + } + origin := "" + origins, found := c.Req.Header["Origin"] + if found { + origin = origins[0] + } + if strings.HasSuffix(origin, parsed.Host) { + res.Header().Add("Access-Control-Allow-Origin", origin) + res.Header().Add("Access-Control-Allow-Credentials", "true") + res.Header().Add("Vary", "Origin") + } +} diff --git a/src/website/discord.go b/src/website/discord.go index 5fdc29d..b9db312 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -35,31 +35,31 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { // This occurs when the user cancels. Just go back to the profile page. return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) } else { - return RejectRequest(c, "Failed to authenticate with Discord.") + return c.RejectRequest("Failed to authenticate with Discord.") } } // Do the actual token exchange code := query.Get("code") - res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) + res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) } expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) - user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken) + user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) } // Add the role on Discord - err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID) + err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) } // Add the user to our database - _, err = c.Conn.Exec(c.Context(), + _, err = c.Conn.Exec(c, ` INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) @@ -79,7 +79,7 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { } if c.CurrentUser.Status == models.UserStatusConfirmed { - _, err = c.Conn.Exec(c.Context(), + _, err = c.Conn.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -98,13 +98,13 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData { } func DiscordUnlink(c *RequestContext) ResponseData { - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx, + discordUser, err := db.QueryOne[models.DiscordUser](c, tx, ` SELECT $columns FROM discord_user @@ -120,7 +120,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } } - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM discord_user WHERE id = $1 @@ -131,12 +131,12 @@ func DiscordUnlink(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete Discord user")) } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit Discord user delete")) } - err = discord.RemoveGuildMemberRole(c.Context(), discordUser.UserID, config.Config.Discord.MemberRoleID) + err = discord.RemoveGuildMemberRole(c, discordUser.UserID, config.Config.Discord.MemberRoleID) if err != nil { c.Logger.Warn().Err(err).Msg("failed to remove member role on unlink") } @@ -145,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } func DiscordShowcaseBacklog(c *RequestContext) ResponseData { - duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, + duser, err := db.QueryOne[models.DiscordUser](c, c.Conn, `SELECT $columns FROM discord_user WHERE hmn_user_id = $1`, c.CurrentUser.ID, ) @@ -157,7 +157,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user")) } - msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn, + msgIDs, err := db.QueryScalar[string](c, c.Conn, ` SELECT msg.id FROM @@ -174,12 +174,12 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { } for _, msgID := range msgIDs { - interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID) + interned, err := discord.FetchInternedMessage(c, c.Conn, msgID) if err != nil && !errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, err) } else if err == nil { // NOTE(asaf): Creating snippet even if the checkbox is off because the user asked us to. - err = discord.HandleSnippetForInternedMessage(c.Context(), c.Conn, interned, true) + err = discord.HandleSnippetForInternedMessage(c, c.Conn, interned, true) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, err) } diff --git a/src/website/errors.go b/src/website/errors.go index 71eeae8..43c0cba 100644 --- a/src/website/errors.go +++ b/src/website/errors.go @@ -1,6 +1,31 @@ package website -import "fmt" +import ( + "fmt" + "net/http" + "strings" + + "git.handmade.network/hmn/hmn/src/templates" +) + +func FourOhFour(c *RequestContext) ResponseData { + var res ResponseData + res.StatusCode = http.StatusNotFound + + if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") { + templateData := struct { + templates.BaseData + Wanted string + }{ + BaseData: getBaseData(c, "Page not found", nil), + Wanted: c.FullUrl(), + } + res.MustWriteTemplate("404.html", templateData, c.Perf) + } else { + res.Write([]byte("Not Found")) + } + return res +} // A SafeError can be used to wrap another error and explicitly provide // an error message that is safe to show to a user. This allows the original diff --git a/src/website/feed.go b/src/website/feed.go index 99a6bdd..f93c72d 100644 --- a/src/website/feed.go +++ b/src/website/feed.go @@ -33,7 +33,7 @@ var feedThreadTypes = []models.ThreadType{ } func Feed(c *RequestContext) ResponseData { - numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ThreadTypes: feedThreadTypes, }) if err != nil { @@ -156,7 +156,7 @@ func AtomFeed(c *RequestContext) ResponseData { if hasAll { itemsPerFeed = 100000 } - projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, nil, hmndata.ProjectsQuery{ + projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, nil, hmndata.ProjectsQuery{ Limit: itemsPerFeed, Types: hmndata.OfficialProjects, OrderBy: "date_approved DESC", @@ -188,7 +188,7 @@ func AtomFeed(c *RequestContext) ResponseData { feedData.AtomFeedUrl = hmnurl.BuildAtomFeedForShowcase() feedData.FeedUrl = hmnurl.BuildShowcase() - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{ Limit: itemsPerFeed, }) if err != nil { @@ -215,7 +215,7 @@ func AtomFeed(c *RequestContext) ResponseData { } func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostListItem, error) { - postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ThreadTypes: feedThreadTypes, Limit: limit, Offset: offset, @@ -225,7 +225,7 @@ func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostLi return nil, err } c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() diff --git a/src/website/fishbowl.go b/src/website/fishbowl.go index cf222b7..db44954 100644 --- a/src/website/fishbowl.go +++ b/src/website/fishbowl.go @@ -206,7 +206,7 @@ func Fishbowl(c *RequestContext) ResponseData { func FishbowlFiles(c *RequestContext) ResponseData { var res ResponseData fishbowlHTTPFS.ServeHTTP(&res, c.Req) - AddCORSHeaders(c, &res) + addCORSHeaders(c, &res) return res } @@ -224,7 +224,7 @@ func linkifyDiscordContent(c *RequestContext, dbConn db.ConnOrTx, content string discordUserIds = append(discordUserIds, id) } - hmnUsers, err := hmndata.FetchUsers(c.Context(), dbConn, c.CurrentUser, hmndata.UsersQuery{ + hmnUsers, err := hmndata.FetchUsers(c, dbConn, c.CurrentUser, hmndata.UsersQuery{ DiscordUserIDs: discordUserIds, }) if err != nil { diff --git a/src/website/forums.go b/src/website/forums.go index 49a403e..18e06b0 100644 --- a/src/website/forums.go +++ b/src/website/forums.go @@ -91,7 +91,7 @@ func Forum(c *RequestContext) ResponseData { currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) - numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, SubforumIDs: []int{cd.SubforumID}, @@ -107,7 +107,7 @@ func Forum(c *RequestContext) ResponseData { } howManyThreadsToSkip := (page - 1) * threadsPerPage - mainThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + mainThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, SubforumIDs: []int{cd.SubforumID}, @@ -141,7 +141,7 @@ func Forum(c *RequestContext) ResponseData { subforumNodes := cd.SubforumTree[cd.SubforumID].Children for _, sfNode := range subforumNodes { - numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, SubforumIDs: []int{sfNode.ID}, @@ -150,7 +150,7 @@ func Forum(c *RequestContext) ResponseData { panic(oops.New(err, "failed to get count of threads")) } - subforumThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + subforumThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, SubforumIDs: []int{sfNode.ID}, @@ -203,7 +203,7 @@ func Forum(c *RequestContext) ResponseData { func ForumMarkRead(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() @@ -212,16 +212,16 @@ func ForumMarkRead(c *RequestContext) ResponseData { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) sfIds := []int{sfId} if sfId == 0 { // Mark literally everything as read - _, err := tx.Exec(c.Context(), + _, err := tx.Exec(c, ` UPDATE hmn_user SET marked_all_read_at = NOW() @@ -234,7 +234,7 @@ func ForumMarkRead(c *RequestContext) ResponseData { } // Delete thread unread info - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM thread_last_read_info WHERE user_id = $1; @@ -246,7 +246,7 @@ func ForumMarkRead(c *RequestContext) ResponseData { } // Delete subforum unread info - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM subforum_last_read_info WHERE user_id = $1; @@ -258,7 +258,7 @@ func ForumMarkRead(c *RequestContext) ResponseData { } } else { c.Perf.StartBlock("SQL", "Update SLRIs") - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` INSERT INTO subforum_last_read_info (subforum_id, user_id, lastread) SELECT id, $2, $3 @@ -277,7 +277,7 @@ func ForumMarkRead(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Delete TLRIs") - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` DELETE FROM thread_last_read_info WHERE @@ -298,7 +298,7 @@ func ForumMarkRead(c *RequestContext) ResponseData { } } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit SLRI/TLRI updates")) } @@ -332,7 +332,7 @@ func ForumThread(c *RequestContext) ResponseData { return FourOhFour(c) } - threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadIDs: []int{cd.ThreadID}, }) @@ -351,7 +351,7 @@ func ForumThread(c *RequestContext) ResponseData { return c.Redirect(correctThreadUrl, http.StatusSeeOther) } - numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadIDs: []int{cd.ThreadID}, @@ -374,7 +374,7 @@ func ForumThread(c *RequestContext) ResponseData { PreviousUrl: c.UrlContext.BuildForumThread(currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)), } - postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadIDs: []int{thread.ID}, Limit: threadPostsPerPage, @@ -396,7 +396,7 @@ func ForumThread(c *RequestContext) ResponseData { post.ReplyPost = &reply } - addAuthorCountsToPost(c.Context(), c.Conn, &post) + addAuthorCountsToPost(c, c.Conn, &post) posts = append(posts, post) } @@ -404,7 +404,7 @@ func ForumThread(c *RequestContext) ResponseData { // Update thread last read info if c.CurrentUser != nil { c.Perf.StartBlock("SQL", "Update TLRI") - _, err = c.Conn.Exec(c.Context(), + _, err = c.Conn.Exec(c, ` INSERT INTO thread_last_read_info (thread_id, user_id, lastread) VALUES ($1, $2, $3) @@ -445,7 +445,7 @@ func ForumPostRedirect(c *RequestContext) ResponseData { return FourOhFour(c) } - posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadIDs: []int{cd.ThreadID}, @@ -495,11 +495,11 @@ func ForumNewThread(c *RequestContext) ResponseData { } func ForumNewThreadSubmit(c *RequestContext) ResponseData { - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) cd, ok := getCommonForumData(c) if !ok { @@ -517,15 +517,15 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData { sticky = true } if title == "" { - return RejectRequest(c, "You must provide a title for your post.") + return c.RejectRequest("You must provide a title for your post.") } if unparsed == "" { - return RejectRequest(c, "You must provide a body for your post.") + return c.RejectRequest("You must provide a body for your post.") } // Create thread var threadId int - err = tx.QueryRow(c.Context(), + err = tx.QueryRow(c, ` INSERT INTO thread (title, sticky, type, project_id, subforum_id, first_id, last_id) VALUES ($1, $2, $3, $4, $5, $6, $7) @@ -544,9 +544,9 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData { } // Create everything else - hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) + hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new forum thread")) } @@ -561,7 +561,7 @@ func ForumPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, }) @@ -600,11 +600,11 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) err = c.Req.ParseForm() if err != nil { @@ -612,10 +612,10 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData { } unparsed := c.Req.Form.Get("body") if unparsed == "" { - return RejectRequest(c, "Your reply cannot be empty.") + return c.RejectRequest("Your reply cannot be empty.") } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, }) @@ -629,9 +629,9 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData { replyPostId = &post.Post.ID } - newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host) + newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to forum post")) } @@ -646,11 +646,11 @@ func ForumPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, }) @@ -688,17 +688,17 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, }) @@ -713,16 +713,16 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData { unparsed := c.Req.Form.Get("body") editReason := c.Req.Form.Get("editreason") if title != "" && post.Thread.FirstID != post.Post.ID { - return RejectRequest(c, "You can only edit the title by editing the first post.") + return c.RejectRequest("You can only edit the title by editing the first post.") } if unparsed == "" { - return RejectRequest(c, "You must provide a body for your post.") + return c.RejectRequest("You must provide a body for your post.") } - hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) + hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) if title != "" { - _, err := tx.Exec(c.Context(), + _, err := tx.Exec(c, ` UPDATE thread SET title = $1 WHERE id = $2 `, @@ -734,7 +734,7 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData { } } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit forum post")) } @@ -749,11 +749,11 @@ func ForumPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ + post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, }) @@ -798,19 +798,19 @@ func ForumPostDeleteSubmit(c *RequestContext) ResponseData { return FourOhFour(c) } - if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { + if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) { return FourOhFour(c) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID) + threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID) - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post")) } @@ -831,7 +831,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData { panic(err) } - thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{ + thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, // This is the rare query where we want all thread types! }) @@ -842,7 +842,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() @@ -874,7 +874,7 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { defer c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() diff --git a/src/website/imagefile_helper.go b/src/website/imagefile_helper.go index dd433b8..38c7e65 100644 --- a/src/website/imagefile_helper.go +++ b/src/website/imagefile_helper.go @@ -89,7 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string, img.Seek(0, io.SeekStart) io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs sha1sum := hasher.Sum(nil) - imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn, + imageFile, err := db.QueryOne[models.ImageFile](c, dbConn, ` INSERT INTO image_file (file, size, sha1sum, protected, width, height) VALUES ($1, $2, $3, $4, $5, $6) diff --git a/src/website/jam.go b/src/website/jam.go index 8464641..963767a 100644 --- a/src/website/jam.go +++ b/src/website/jam.go @@ -51,7 +51,7 @@ func JamIndex2021(c *RequestContext) ResponseData { } tagId := -1 - jamTag, err := hmndata.FetchTag(c.Context(), c.Conn, hmndata.TagQuery{ + jamTag, err := hmndata.FetchTag(c, c.Conn, hmndata.TagQuery{ Text: []string{"wheeljam"}, }) if err == nil { @@ -60,7 +60,7 @@ func JamIndex2021(c *RequestContext) ResponseData { c.Logger.Warn().Err(err).Msg("failed to fetch jam tag; will fetch all snippets as a result") } - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{ Tags: []int{tagId}, }) if err != nil { diff --git a/src/website/landing.go b/src/website/landing.go index 2bef43b..0d994c0 100644 --- a/src/website/landing.go +++ b/src/website/landing.go @@ -34,13 +34,13 @@ type LandingTemplateData struct { func Index(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() var timelineItems []templates.TimelineItem - numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ThreadTypes: feedThreadTypes, }) if err != nil { @@ -65,7 +65,7 @@ func Index(c *RequestContext) ResponseData { } // This is essentially an alternate for feed page 1. - posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ThreadTypes: feedThreadTypes, Limit: feedPostsPerPage, SortDescending: true, @@ -84,7 +84,7 @@ func Index(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Get news") - newsThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ + newsThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ ProjectIDs: []int{models.HMNProjectID}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, Limit: 1, @@ -106,7 +106,7 @@ func Index(c *RequestContext) ResponseData { } c.Perf.EndBlock() - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{ Limit: 40, }) if err != nil { diff --git a/src/website/middlewares.go b/src/website/middlewares.go new file mode 100644 index 0000000..12939af --- /dev/null +++ b/src/website/middlewares.go @@ -0,0 +1,123 @@ +package website + +import ( + "fmt" + "math/rand" + "net/http" + "time" + + "git.handmade.network/hmn/hmn/src/auth" + "git.handmade.network/hmn/hmn/src/hmnurl" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/oops" + "git.handmade.network/hmn/hmn/src/perf" + "git.handmade.network/hmn/hmn/src/utils" +) + +func panicCatcherMiddleware(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + defer func() { + if recovered := recover(); recovered != nil { + maybeError, ok := recovered.(*error) + var err error + if ok { + err = *maybeError + } else { + err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered)) + } + res = c.ErrorResponse(http.StatusInternalServerError, err) + } + }() + + return h(c) + } +} + +func trackRequestPerf(h Handler) Handler { + return func(c *RequestContext) ResponseData { + c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path) + defer func() { + c.Perf.EndRequest() + log := logging.Info() + blockStack := make([]time.Time, 0) + for i, block := range c.Perf.Blocks { + for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) { + blockStack = blockStack[:len(blockStack)-1] + } + log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs())) + blockStack = append(blockStack, block.End) + } + log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000)) + // perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this + }() + + return h(c) + } +} + +func needsAuth(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + if c.CurrentUser == nil { + return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther) + } + + return h(c) + } +} + +func adminsOnly(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + if c.CurrentUser == nil || !c.CurrentUser.IsStaff { + return FourOhFour(c) + } + + return h(c) + } +} + +func csrfMiddleware(h Handler) Handler { + // CSRF mitigation actions per the OWASP cheat sheet: + // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html + return func(c *RequestContext) ResponseData { + c.Req.ParseMultipartForm(100 * 1024 * 1024) + csrfToken := c.Req.Form.Get(auth.CSRFFieldName) + if csrfToken != c.CurrentSession.CSRFToken { + c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?") + + res := c.Redirect("/", http.StatusSeeOther) + logoutUser(c, &res) + + return res + } + + return h(c) + } +} + +func securityTimerMiddleware(duration time.Duration, h Handler) Handler { + // NOTE(asaf): Will make sure that the request takes at least `duration` to finish. Adds a 10% random duration. + return func(c *RequestContext) ResponseData { + additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10))) + timer := time.NewTimer(duration + additionalDuration) + res := h(c) + select { + case <-c.Done(): + case <-timer.C: + } + return res + } +} + +func logContextErrors(c *RequestContext, errs ...error) { + for _, err := range errs { + c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request") + } +} + +func logContextErrorsMiddleware(h Handler) Handler { + return func(c *RequestContext) ResponseData { + res := h(c) + logContextErrors(c, res.Errors...) + return res + } +} diff --git a/src/website/notices.go b/src/website/notices.go new file mode 100644 index 0000000..fc394e5 --- /dev/null +++ b/src/website/notices.go @@ -0,0 +1,100 @@ +package website + +import ( + "errors" + "html/template" + "net/http" + "strings" + "time" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/templates" +) + +const NoticesCookieName = "hmn_notices" + +func getNoticesFromCookie(c *RequestContext) []templates.Notice { + cookie, err := c.Req.Cookie(NoticesCookieName) + if err != nil { + if !errors.Is(err, http.ErrNoCookie) { + c.Logger.Warn().Err(err).Msg("failed to get notices cookie") + } + return nil + } + return deserializeNoticesFromCookie(cookie.Value) +} + +func storeNoticesInCookie(c *RequestContext, res *ResponseData) { + serialized := serializeNoticesForCookie(c, res.FutureNotices) + if serialized != "" { + noticesCookie := http.Cookie{ + Name: NoticesCookieName, + Value: serialized, + Path: "/", + Domain: config.Config.Auth.CookieDomain, + Expires: time.Now().Add(time.Minute * 5), + Secure: config.Config.Auth.CookieSecure, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + res.SetCookie(¬icesCookie) + } else if !(res.StatusCode >= 300 && res.StatusCode < 400) { + // NOTE(asaf): Don't clear on redirect + noticesCookie := http.Cookie{ + Name: NoticesCookieName, + Path: "/", + Domain: config.Config.Auth.CookieDomain, + MaxAge: -1, + } + res.SetCookie(¬icesCookie) + } +} + +func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string { + var builder strings.Builder + maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices. + size := 0 + for i, notice := range notices { + sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1 + if i != 0 { + sizeIncrease += 1 + } + if size+sizeIncrease > maxSize { + c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie") + break + } + + if i != 0 { + builder.WriteString("\t") + } + builder.WriteString(notice.Class) + builder.WriteString("|") + builder.WriteString(string(notice.Content)) + + size += sizeIncrease + } + return builder.String() +} + +func deserializeNoticesFromCookie(cookieVal string) []templates.Notice { + var result []templates.Notice + notices := strings.Split(cookieVal, "\t") + for _, notice := range notices { + parts := strings.SplitN(notice, "|", 2) + if len(parts) == 2 { + result = append(result, templates.Notice{ + Class: parts[0], + Content: template.HTML(parts[1]), + }) + } + } + return result +} + +func storeNoticesInCookieMiddleware(h Handler) Handler { + return func(c *RequestContext) ResponseData { + res := h(c) + storeNoticesInCookie(c, &res) + return res + } +} diff --git a/src/website/podcast.go b/src/website/podcast.go index a521d4e..04ec6c2 100644 --- a/src/website/podcast.go +++ b/src/website/podcast.go @@ -126,29 +126,29 @@ func PodcastEditSubmit(c *RequestContext) ResponseData { title := c.Req.Form.Get("title") if len(strings.TrimSpace(title)) == 0 { - return RejectRequest(c, "Podcast title is empty") + return c.RejectRequest("Podcast title is empty") } description := c.Req.Form.Get("description") if len(strings.TrimSpace(description)) == 0 { - return RejectRequest(c, "Podcast description is empty") + return c.RejectRequest("Podcast description is empty") } c.Perf.StartBlock("SQL", "Updating podcast") - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) imageSaveResult := SaveImageFile(c, tx, "podcast_image", maxFileSize, fmt.Sprintf("podcast/%s/logo%d", c.CurrentProject.Slug, time.Now().UTC().Unix())) if imageSaveResult.ValidationError != "" { - return RejectRequest(c, imageSaveResult.ValidationError) + return c.RejectRequest(imageSaveResult.ValidationError) } else if imageSaveResult.FatalError != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(imageSaveResult.FatalError, "Failed to save podcast image")) } if imageSaveResult.ImageFile != nil { - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE podcast SET @@ -166,7 +166,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to update podcast")) } } else { - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE podcast SET @@ -179,7 +179,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData { podcastResult.Podcast.ID, ) } - err = tx.Commit(c.Context()) + err = tx.Commit(c) c.Perf.EndBlock() if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to commit db transaction")) @@ -357,16 +357,16 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData { c.Req.ParseForm() title := c.Req.Form.Get("title") if len(strings.TrimSpace(title)) == 0 { - return RejectRequest(c, "Episode title is empty") + return c.RejectRequest("Episode title is empty") } description := c.Req.Form.Get("description") if len(strings.TrimSpace(description)) == 0 { - return RejectRequest(c, "Episode description is empty") + return c.RejectRequest("Episode description is empty") } episodeNumberStr := c.Req.Form.Get("episode_number") episodeNumber, err := strconv.Atoi(episodeNumberStr) if err != nil { - return RejectRequest(c, "Episode number can't be parsed") + return c.RejectRequest("Episode number can't be parsed") } episodeFile := c.Req.Form.Get("episode_file") found = false @@ -378,7 +378,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData { } if !found { - return RejectRequest(c, "Requested episode file not found") + return c.RejectRequest("Requested episode file not found") } c.Perf.StartBlock("MP3", "Parsing mp3 file for duration") @@ -417,7 +417,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData { if isEdit { guidStr = podcastResult.Episodes[0].GUID.String() c.Perf.StartBlock("SQL", "Updating podcast episode") - _, err := c.Conn.Exec(c.Context(), + _, err := c.Conn.Exec(c, ` UPDATE podcast_episode SET @@ -446,7 +446,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData { guid := uuid.New() guidStr = guid.String() c.Perf.StartBlock("SQL", "Creating new podcast episode") - _, err := c.Conn.Exec(c.Context(), + _, err := c.Conn.Exec(c, ` INSERT INTO podcast_episode (guid, title, description, description_rendered, audio_filename, duration, pub_date, episode_number, podcast_id) @@ -532,7 +532,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG Podcast models.Podcast `db:"podcast"` ImageFilename string `db:"imagefile.file"` } - podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn, + podcastQueryResult, err := db.QueryOne[podcastQuery](c, c.Conn, ` SELECT $columns FROM @@ -558,7 +558,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG if fetchEpisodes { if episodeGUID == "" { c.Perf.StartBlock("SQL", "Fetch podcast episodes") - episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn, + episodes, err := db.Query[models.PodcastEpisode](c, c.Conn, ` SELECT $columns FROM podcast_episode AS episode @@ -578,7 +578,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG return result, err } c.Perf.StartBlock("SQL", "Fetch podcast episode") - episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn, + episode, err := db.QueryOne[models.PodcastEpisode](c, c.Conn, ` SELECT $columns FROM podcast_episode AS episode diff --git a/src/website/projects.go b/src/website/projects.go index 878d92a..ed3dfa8 100644 --- a/src/website/projects.go +++ b/src/website/projects.go @@ -26,11 +26,52 @@ import ( "git.handmade.network/hmn/hmn/src/utils" "github.com/google/uuid" "github.com/jackc/pgx/v4" + "github.com/teacat/noire" ) const maxPersonalProjects = 5 const maxProjectOwners = 5 +func ProjectCSS(c *RequestContext) ResponseData { + color := c.URL().Query().Get("color") + if color == "" { + return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) + } + + baseData := getBaseData(c, "", nil) + + bgColor := noire.NewHex(color) + h, s, l := bgColor.HSL() + if baseData.Theme == "dark" { + l = 15 + } else { + l = 95 + } + if s > 20 { + s = 20 + } + bgColor = noire.NewHSL(h, s, l) + + templateData := struct { + templates.BaseData + Color string + PostBgColor string + }{ + BaseData: baseData, + Color: color, + PostBgColor: bgColor.HTML(), + } + + var res ResponseData + res.Header().Add("Content-Type", "text/css") + err := res.WriteTemplate("project.css", templateData, c.Perf) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS")) + } + + return res +} + type ProjectTemplateData struct { templates.BaseData @@ -48,7 +89,7 @@ func ProjectIndex(c *RequestContext) ResponseData { const maxCarouselProjects = 10 const maxPersonalProjects = 10 - officialProjects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + officialProjects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ Types: hmndata.OfficialProjects, }) if err != nil { @@ -123,7 +164,7 @@ func ProjectIndex(c *RequestContext) ResponseData { // Fetch and highlight a random selection of personal projects var personalProjects []templates.Project { - projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ Types: hmndata.PersonalProjects, }) if err != nil { @@ -181,13 +222,13 @@ func ProjectHomepage(c *RequestContext) ResponseData { // There are no further permission checks to do, because permissions are // checked whatever way we fetch the project. - owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID) + owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, err) } c.Perf.StartBlock("SQL", "Fetching screenshots") - screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn, + screenshotFilenames, err := db.QueryScalar[string](c, c.Conn, ` SELECT screenshot.file FROM @@ -204,7 +245,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetching project links") - projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, + projectLinks, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM @@ -221,12 +262,12 @@ func ProjectHomepage(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetching project timeline") - posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ ProjectIDs: []int{c.CurrentProject.ID}, Limit: maxRecentActivity, SortDescending: true, @@ -241,7 +282,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { Value: c.CurrentProject.Blurb, }) - p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{ + p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{ Lifecycles: models.AllProjectLifecycles, IncludeHidden: true, }) @@ -317,7 +358,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { tagId = *c.CurrentProject.TagID } - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{ Tags: []int{tagId}, }) if err != nil { @@ -364,7 +405,7 @@ type ProjectEditData struct { } func ProjectNew(c *RequestContext) ResponseData { - numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: []int{c.CurrentUser.ID}, Types: hmndata.PersonalProjects, }) @@ -372,7 +413,7 @@ func ProjectNew(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects")) } if numProjects >= maxPersonalProjects { - return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) + return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) } var project templates.ProjectSettings @@ -397,16 +438,16 @@ func ProjectNewSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, formResult.Error) } if len(formResult.RejectionReason) != 0 { - return RejectRequest(c, formResult.RejectionReason) + return c.RejectRequest(formResult.RejectionReason) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) - numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: []int{c.CurrentUser.ID}, Types: hmndata.PersonalProjects, }) @@ -414,11 +455,11 @@ func ProjectNewSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects")) } if numProjects >= maxPersonalProjects { - return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) + return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) } var projectId int - err = tx.QueryRow(c.Context(), + err = tx.QueryRow(c, ` INSERT INTO project (name, blurb, description, descparsed, lifecycle, date_created, all_last_updated) @@ -439,12 +480,12 @@ func ProjectNewSubmit(c *RequestContext) ResponseData { formResult.Payload.ProjectID = projectId - err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload) + err = updateProject(c, tx, c.CurrentUser, &formResult.Payload) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, err) } - tx.Commit(c.Context()) + tx.Commit(c) urlContext := &hmnurl.UrlContext{ PersonalProject: true, @@ -461,7 +502,7 @@ func ProjectEdit(c *RequestContext) ResponseData { } p, err := hmndata.FetchProject( - c.Context(), c.Conn, + c, c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{ Lifecycles: models.AllProjectLifecycles, @@ -473,7 +514,7 @@ func ProjectEdit(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetching project links") - projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, + projectLinks, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM @@ -524,23 +565,23 @@ func ProjectEditSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, formResult.Error) } if len(formResult.RejectionReason) != 0 { - return RejectRequest(c, formResult.RejectionReason) + return c.RejectRequest(formResult.RejectionReason) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) formResult.Payload.ProjectID = c.CurrentProject.ID - err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload) + err = updateProject(c, tx, c.CurrentUser, &formResult.Payload) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, err) } - tx.Commit(c.Context()) + tx.Commit(c) urlContext := &hmnurl.UrlContext{ PersonalProject: formResult.Payload.Personal, diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index b653606..70f2ed2 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -13,6 +13,7 @@ import ( "path" "regexp" "strings" + "time" "git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/logging" @@ -43,15 +44,22 @@ func (r *Route) String() string { } type RouteBuilder struct { - Router *Router - Prefixes []*regexp.Regexp - Middleware Middleware + Router *Router + Prefixes []*regexp.Regexp + Middlewares []Middleware } type Handler func(c *RequestContext) ResponseData - type Middleware func(h Handler) Handler +func applyMiddlewares(h Handler, ms []Middleware) Handler { + result := h + for i := len(ms) - 1; i >= 0; i-- { + result = ms[i](result) + } + return result +} + func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler) { // Ensure that this regex matches the start of the string regexStr := regex.String() @@ -59,7 +67,7 @@ func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler panic("All routing regexes must begin with '^'") } - h = rb.Middleware(h) + h = applyMiddlewares(h, rb.Middlewares) for _, method := range methods { rb.Router.Routes = append(rb.Router.Routes, Route{ Method: method, @@ -81,10 +89,19 @@ func (rb *RouteBuilder) POST(regex *regexp.Regexp, h Handler) { rb.Handle([]string{http.MethodPost}, regex, h) } -func (rb *RouteBuilder) Group(regex *regexp.Regexp, addRoutes func(rb *RouteBuilder)) { +func (rb *RouteBuilder) WithMiddleware(ms ...Middleware) RouteBuilder { + newRb := *rb + newRb.Middlewares = append(rb.Middlewares, ms...) + + return newRb +} + +func (rb *RouteBuilder) Group(regex *regexp.Regexp, ms ...Middleware) RouteBuilder { newRb := *rb newRb.Prefixes = append(newRb.Prefixes, regex) - addRoutes(&newRb) + newRb.Middlewares = append(rb.Middlewares, ms...) + + return newRb } func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -138,6 +155,8 @@ nextroute: Req: req, Res: rw, PathParams: params, + + ctx: req.Context(), } c.PathParams = params @@ -174,13 +193,33 @@ type RequestContext struct { ctx context.Context } -func (c *RequestContext) Context() context.Context { - if c.ctx == nil { - c.ctx = c.Req.Context() - } - return c.ctx +// Our RequestContext is a context.Context + +var _ context.Context = &RequestContext{} + +func (c *RequestContext) Deadline() (time.Time, bool) { + return c.ctx.Deadline() } +func (c *RequestContext) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *RequestContext) Err() error { + return c.ctx.Err() +} + +func (c *RequestContext) Value(key any) any { + switch key { + case perf.PerfContextKey: + return c.Perf + default: + return c.ctx.Value(key) + } +} + +// Plus it does many other things specific to us + func (c *RequestContext) URL() *url.URL { return c.Req.URL } @@ -325,7 +364,7 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData { func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData { defer func() { if r := recover(); r != nil { - LogContextErrors(c, errs...) + logContextErrors(c, errs...) panic(r) } }() @@ -338,6 +377,23 @@ func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData { return res } +func (c *RequestContext) RejectRequest(reason string) ResponseData { + type RejectData struct { + templates.BaseData + RejectReason string + } + + var res ResponseData + err := res.WriteTemplate("reject.html", RejectData{ + BaseData: getBaseData(c, "Rejected", nil), + RejectReason: reason, + }, c.Perf) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template")) + } + return res +} + type ResponseData struct { StatusCode int Body *bytes.Buffer diff --git a/src/website/routes.go b/src/website/routes.go index 88393b0..942fd1b 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -1,159 +1,46 @@ package website import ( - "context" "errors" "fmt" - "html/template" - "math/rand" "net/http" - "net/url" - "regexp" "strconv" - "strings" "time" - "git.handmade.network/hmn/hmn/src/auth" - "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/email" "git.handmade.network/hmn/hmn/src/hmndata" "git.handmade.network/hmn/hmn/src/hmnurl" - "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" - "git.handmade.network/hmn/hmn/src/perf" - "git.handmade.network/hmn/hmn/src/templates" "git.handmade.network/hmn/hmn/src/utils" "github.com/jackc/pgx/v4/pgxpool" - "github.com/teacat/noire" ) -func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) http.Handler { +func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { router := &Router{} routes := RouteBuilder{ Router: router, - Middleware: func(h Handler) Handler { - return func(c *RequestContext) (res ResponseData) { - c.Conn = conn - - logPerf := TrackRequestPerf(c) - defer logPerf() - - defer LogContextErrorsFromResponse(c, &res) - defer MiddlewarePanicCatcher(c, &res) - - return h(c) - } + Middlewares: []Middleware{ + setDBConn(conn), + trackRequestPerf, + logContextErrorsMiddleware, + panicCatcherMiddleware, }, } - anyProject := routes - anyProject.Middleware = func(h Handler) Handler { - return func(c *RequestContext) (res ResponseData) { - c.Conn = conn - - logPerf := TrackRequestPerf(c) - defer logPerf() - - defer LogContextErrorsFromResponse(c, &res) - defer MiddlewarePanicCatcher(c, &res) - - defer storeNoticesInCookie(c, &res) - - ok, errRes := LoadCommonWebsiteData(c) - if !ok { - return errRes - } - - return h(c) - } - } - - hmnOnly := routes - hmnOnly.Middleware = func(h Handler) Handler { - return func(c *RequestContext) (res ResponseData) { - c.Conn = conn - - logPerf := TrackRequestPerf(c) - defer logPerf() - - defer LogContextErrorsFromResponse(c, &res) - defer MiddlewarePanicCatcher(c, &res) - - defer storeNoticesInCookie(c, &res) - - ok, errRes := LoadCommonWebsiteData(c) - if !ok { - return errRes - } - - if !c.CurrentProject.IsHMN() { - return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently) - } - - return h(c) - } - } - - authMiddleware := func(h Handler) Handler { - return func(c *RequestContext) (res ResponseData) { - if c.CurrentUser == nil { - return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther) - } - - return h(c) - } - } - - adminMiddleware := func(h Handler) Handler { - return func(c *RequestContext) (res ResponseData) { - if c.CurrentUser == nil || !c.CurrentUser.IsStaff { - return FourOhFour(c) - } - - return h(c) - } - } - - csrfMiddleware := func(h Handler) Handler { - // CSRF mitigation actions per the OWASP cheat sheet: - // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html - return func(c *RequestContext) ResponseData { - c.Req.ParseMultipartForm(100 * 1024 * 1024) - csrfToken := c.Req.Form.Get(auth.CSRFFieldName) - if csrfToken != c.CurrentSession.CSRFToken { - c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?") - - res := c.Redirect("/", http.StatusSeeOther) - logoutUser(c, &res) - - return res - } - - return h(c) - } - } - - securityTimerMiddleware := func(duration time.Duration, h Handler) Handler { - // NOTE(asaf): Will make sure that the request takes at least `delayMs` to finish. Adds a 10% random duration. - return func(c *RequestContext) ResponseData { - additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10))) - timer := time.NewTimer(duration + additionalDuration) - res := h(c) - select { - case <-longRequestContext.Done(): - case <-c.Context().Done(): - case <-timer.C: - } - return res - } - } + anyProject := routes.WithMiddleware( + storeNoticesInCookieMiddleware, + loadCommonData, + ) + hmnOnly := anyProject.WithMiddleware( + redirectToHMN, + ) routes.GET(hmnurl.RegexPublic, func(c *RequestContext) ResponseData { var res ResponseData http.StripPrefix("/public/", http.FileServer(http.Dir("public"))).ServeHTTP(&res, c.Req) - AddCORSHeaders(c, &res) + addCORSHeaders(c, &res) return res }) routes.GET(hmnurl.RegexFishbowlFiles, FishbowlFiles) @@ -189,10 +76,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht hmnOnly.POST(hmnurl.RegexDoPasswordReset, DoPasswordResetSubmit) hmnOnly.GET(hmnurl.RegexAdminAtomFeed, AdminAtomFeed) - hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminMiddleware(AdminApprovalQueue)) - hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminMiddleware(csrfMiddleware(AdminApprovalQueueSubmit))) - hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminMiddleware(csrfMiddleware(UserProfileAdminSetStatus))) - hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminMiddleware(csrfMiddleware(UserProfileAdminNuke))) + hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminsOnly(AdminApprovalQueue)) + hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminsOnly(csrfMiddleware(AdminApprovalQueueSubmit))) + hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminsOnly(csrfMiddleware(UserProfileAdminSetStatus))) + hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminsOnly(csrfMiddleware(UserProfileAdminNuke))) hmnOnly.GET(hmnurl.RegexFeed, Feed) hmnOnly.GET(hmnurl.RegexAtomFeed, AtomFeed) @@ -200,19 +87,19 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht hmnOnly.GET(hmnurl.RegexSnippet, Snippet) hmnOnly.GET(hmnurl.RegexProjectIndex, ProjectIndex) - hmnOnly.GET(hmnurl.RegexProjectNew, authMiddleware(ProjectNew)) - hmnOnly.POST(hmnurl.RegexProjectNew, authMiddleware(csrfMiddleware(ProjectNewSubmit))) + hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew)) + hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit))) - hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback)) - hmnOnly.POST(hmnurl.RegexDiscordUnlink, authMiddleware(csrfMiddleware(DiscordUnlink))) - hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, authMiddleware(csrfMiddleware(DiscordShowcaseBacklog))) + hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback)) + hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink))) + hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog))) hmnOnly.POST(hmnurl.RegexTwitchEventSubCallback, TwitchEventSubCallback) hmnOnly.GET(hmnurl.RegexTwitchDebugPage, TwitchDebugPage) hmnOnly.GET(hmnurl.RegexUserProfile, UserProfile) - hmnOnly.GET(hmnurl.RegexUserSettings, authMiddleware(UserSettings)) - hmnOnly.POST(hmnurl.RegexUserSettings, authMiddleware(csrfMiddleware(UserSettingsSave))) + hmnOnly.GET(hmnurl.RegexUserSettings, needsAuth(UserSettings)) + hmnOnly.POST(hmnurl.RegexUserSettings, needsAuth(csrfMiddleware(UserSettingsSave))) hmnOnly.GET(hmnurl.RegexPodcast, PodcastIndex) hmnOnly.GET(hmnurl.RegexPodcastEdit, PodcastEdit) @@ -231,6 +118,9 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht hmnOnly.GET(hmnurl.RegexLibraryAny, LibraryNotPortedYet) + // Project routes can appear either at the root (e.g. hero.handmade.network/edit) + // or on a personal project path (e.g. handmade.network/p/123/hero/edit). So, we + // have pulled all those routes into this function. attachProjectRoutes := func(rb *RouteBuilder) { rb.GET(hmnurl.RegexHomepage, func(c *RequestContext) ResponseData { if c.CurrentProject.IsHMN() { @@ -240,8 +130,8 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht } }) - rb.GET(hmnurl.RegexProjectEdit, authMiddleware(ProjectEdit)) - rb.POST(hmnurl.RegexProjectEdit, authMiddleware(csrfMiddleware(ProjectEditSubmit))) + rb.GET(hmnurl.RegexProjectEdit, needsAuth(ProjectEdit)) + rb.POST(hmnurl.RegexProjectEdit, needsAuth(csrfMiddleware(ProjectEditSubmit))) // Middleware used for forum action routes - anything related to actually creating or editing forum content needsForums := func(h Handler) Handler { @@ -251,14 +141,14 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht return FourOhFour(c) } // Require auth if forums are enabled - return authMiddleware(h)(c) + return needsAuth(h)(c) } } rb.POST(hmnurl.RegexForumNewThreadSubmit, needsForums(csrfMiddleware(ForumNewThreadSubmit))) rb.GET(hmnurl.RegexForumNewThread, needsForums(ForumNewThread)) rb.GET(hmnurl.RegexForumThread, ForumThread) rb.GET(hmnurl.RegexForum, Forum) - rb.POST(hmnurl.RegexForumMarkRead, authMiddleware(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled + rb.POST(hmnurl.RegexForumMarkRead, needsAuth(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled rb.GET(hmnurl.RegexForumPost, ForumPostRedirect) rb.GET(hmnurl.RegexForumPostReply, needsForums(ForumPostReply)) rb.POST(hmnurl.RegexForumPostReply, needsForums(csrfMiddleware(ForumPostReplySubmit))) @@ -276,7 +166,7 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht return FourOhFour(c) } // Require auth if blogs are enabled - return authMiddleware(h)(c) + return needsAuth(h)(c) } } rb.GET(hmnurl.RegexBlog, BlogIndex) @@ -296,64 +186,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht ), http.StatusMovedPermanently) }) } - hmnOnly.Group(hmnurl.RegexPersonalProject, func(rb *RouteBuilder) { - // TODO(ben): Perhaps someday we can make this middleware modification feel better? It seems - // pretty common to run the outermost middleware first before doing other stuff, but having - // to nest functions this way feels real bad. - rb.Middleware = func(h Handler) Handler { - return hmnOnly.Middleware(func(c *RequestContext) ResponseData { - // At this point we are definitely on the plain old HMN subdomain. - - // Fetch personal project and do whatever - id, err := strconv.Atoi(c.PathParams["projectid"]) - if err != nil { - panic(oops.New(err, "project id was not numeric (bad regex in routing)")) - } - p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{ - Lifecycles: models.AllProjectLifecycles, - IncludeHidden: true, - }) - if err != nil { - if errors.Is(err, db.NotFound) { - return FourOhFour(c) - } else { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project")) - } - } - - c.CurrentProject = &p.Project - c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject) - c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners) - - if !p.Project.Personal { - return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) - } - - if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(p.Project.Name) { - return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) - } - - return h(c) - }) - } - attachProjectRoutes(rb) - }) - anyProject.Group(regexp.MustCompile("^"), func(rb *RouteBuilder) { - rb.Middleware = func(h Handler) Handler { - return anyProject.Middleware(func(c *RequestContext) ResponseData { - // We could be on any project's subdomain. - - // Check if the current project (matched by subdomain) is actually no longer official - // and therefore needs to be redirected to the personal project version of the route. - if c.CurrentProject.Personal { - return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) - } - - return h(c) - }) - } - attachProjectRoutes(rb) - }) + officialProjectRoutes := anyProject.WithMiddleware(officialProjectMiddleware) + personalProjectRoutes := hmnOnly.Group(hmnurl.RegexPersonalProject, personalProjectMiddleware) + attachProjectRoutes(&officialProjectRoutes) + attachProjectRoutes(&personalProjectRoutes) anyProject.POST(hmnurl.RegexAssetUpload, AssetUpload) @@ -375,318 +211,69 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht return router } -func ProjectCSS(c *RequestContext) ResponseData { - color := c.URL().Query().Get("color") - if color == "" { - return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) - } - - baseData := getBaseData(c, "", nil) - - bgColor := noire.NewHex(color) - h, s, l := bgColor.HSL() - if baseData.Theme == "dark" { - l = 15 - } else { - l = 95 - } - if s > 20 { - s = 20 - } - bgColor = noire.NewHSL(h, s, l) - - templateData := struct { - templates.BaseData - Color string - PostBgColor string - }{ - BaseData: baseData, - Color: color, - PostBgColor: bgColor.HTML(), - } - - var res ResponseData - res.Header().Add("Content-Type", "text/css") - err := res.WriteTemplate("project.css", templateData, c.Perf) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS")) - } - - return res -} - -func FourOhFour(c *RequestContext) ResponseData { - var res ResponseData - res.StatusCode = http.StatusNotFound - - if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") { - templateData := struct { - templates.BaseData - Wanted string - }{ - BaseData: getBaseData(c, "Page not found", nil), - Wanted: c.FullUrl(), +func setDBConn(conn *pgxpool.Pool) Middleware { + return func(h Handler) Handler { + return func(c *RequestContext) ResponseData { + c.Conn = conn + return h(c) } - res.MustWriteTemplate("404.html", templateData, c.Perf) - } else { - res.Write([]byte("Not Found")) } - return res } -type RejectData struct { - templates.BaseData - RejectReason string -} - -func RejectRequest(c *RequestContext, reason string) ResponseData { - var res ResponseData - err := res.WriteTemplate("reject.html", RejectData{ - BaseData: getBaseData(c, "Rejected", nil), - RejectReason: reason, - }, c.Perf) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template")) - } - return res -} - -func LoadCommonWebsiteData(c *RequestContext) (bool, ResponseData) { - c.Perf.StartBlock("MIDDLEWARE", "Load common website data") - defer c.Perf.EndBlock() - - // get user - { - sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) - if err == nil { - user, session, err := getCurrentUserAndSession(c, sessionCookie.Value) - if err != nil { - return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user")) - } - - c.CurrentUser = user - c.CurrentSession = session +func redirectToHMN(h Handler) Handler { + return func(c *RequestContext) ResponseData { + if !c.CurrentProject.IsHMN() { + return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently) } - // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. + + return h(c) } +} - // get official project - { - hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost()) - slug := strings.TrimRight(hostPrefix, ".") - var owners []*models.User +func officialProjectMiddleware(h Handler) Handler { + return func(c *RequestContext) ResponseData { + // Check if the current project (matched by subdomain) is actually no longer official + // and therefore needs to be redirected to the personal project version of the route. + if c.CurrentProject.Personal { + return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) + } - if len(slug) > 0 { - dbProject, err := hmndata.FetchProjectBySlug(c.Context(), c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{ - Lifecycles: models.AllProjectLifecycles, - IncludeHidden: true, - }) - if err == nil { - c.CurrentProject = &dbProject.Project - c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme) - owners = dbProject.Owners + return h(c) + } +} + +func personalProjectMiddleware(h Handler) Handler { + return func(c *RequestContext) ResponseData { + hmnProject := c.CurrentProject + + id := utils.Must1(strconv.Atoi(c.PathParams["projectid"])) + p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{ + Lifecycles: models.AllProjectLifecycles, + IncludeHidden: true, + }) + if err != nil { + if errors.Is(err, db.NotFound) { + return FourOhFour(c) } else { - if errors.Is(err, db.NotFound) { - // do nothing, this is fine - } else { - return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) - } + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project")) } } - if c.CurrentProject == nil { - dbProject, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{ - Lifecycles: models.AllProjectLifecycles, - IncludeHidden: true, - }) - if err != nil { - panic(oops.New(err, "failed to fetch HMN project")) - } - c.CurrentProject = &dbProject.Project - c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme) - } - - if c.CurrentProject == nil { - panic("failed to load project data") - } - - c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners) + c.CurrentProject = &p.Project + c.CurrentProject.Color1 = hmnProject.Color1 + c.CurrentProject.Color2 = hmnProject.Color2 c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject) - } + c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners) - c.Theme = "light" - if c.CurrentUser != nil && c.CurrentUser.DarkTheme { - c.Theme = "dark" - } - - return true, ResponseData{} -} - -func AddCORSHeaders(c *RequestContext, res *ResponseData) { - parsed, err := url.Parse(config.Config.BaseUrl) - if err != nil { - c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers") - return - } - origin := "" - origins, found := c.Req.Header["Origin"] - if found { - origin = origins[0] - } - if strings.HasSuffix(origin, parsed.Host) { - res.Header().Add("Access-Control-Allow-Origin", origin) - res.Header().Add("Access-Control-Allow-Credentials", "true") - res.Header().Add("Vary", "Origin") - } -} - -// Given a session id, fetches user data from the database. Will return nil if -// the user cannot be found, and will only return an error if it's serious. -func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) { - session, err := auth.GetSession(c.Context(), c.Conn, sessionId) - if err != nil { - if errors.Is(err, auth.ErrNoSession) { - return nil, nil, nil - } else { - return nil, nil, oops.New(err, "failed to get current session") - } - } - - user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, nil, session.Username, hmndata.UsersQuery{ - AnyStatus: true, - }) - if err != nil { - if errors.Is(err, db.NotFound) { - logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found") - return nil, nil, nil // user was deleted or something - } else { - return nil, nil, oops.New(err, "failed to get user for session") - } - } - - return user, session, nil -} - -func TrackRequestPerf(c *RequestContext) (after func()) { - c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path) - c.ctx = context.WithValue(c.Context(), perf.PerfContextKey, c.Perf) - - return func() { - c.Perf.EndRequest() - log := logging.Info() - blockStack := make([]time.Time, 0) - for i, block := range c.Perf.Blocks { - for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) { - blockStack = blockStack[:len(blockStack)-1] - } - log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs())) - blockStack = append(blockStack, block.End) - } - log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000)) - // perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this - } -} - -func LogContextErrors(c *RequestContext, errs ...error) { - for _, err := range errs { - c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request") - } -} - -func LogContextErrorsFromResponse(c *RequestContext, res *ResponseData) { - LogContextErrors(c, res.Errors...) -} - -func MiddlewarePanicCatcher(c *RequestContext, res *ResponseData) { - if recovered := recover(); recovered != nil { - maybeError, ok := recovered.(*error) - var err error - if ok { - err = *maybeError - } else { - err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered)) - } - *res = c.ErrorResponse(http.StatusInternalServerError, err) - } -} - -const NoticesCookieName = "hmn_notices" - -func getNoticesFromCookie(c *RequestContext) []templates.Notice { - cookie, err := c.Req.Cookie(NoticesCookieName) - if err != nil { - if !errors.Is(err, http.ErrNoCookie) { - c.Logger.Warn().Err(err).Msg("failed to get notices cookie") - } - return nil - } - return deserializeNoticesFromCookie(cookie.Value) -} - -func storeNoticesInCookie(c *RequestContext, res *ResponseData) { - serialized := serializeNoticesForCookie(c, res.FutureNotices) - if serialized != "" { - noticesCookie := http.Cookie{ - Name: NoticesCookieName, - Value: serialized, - Path: "/", - Domain: config.Config.Auth.CookieDomain, - Expires: time.Now().Add(time.Minute * 5), - Secure: config.Config.Auth.CookieSecure, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - } - res.SetCookie(¬icesCookie) - } else if !(res.StatusCode >= 300 && res.StatusCode < 400) { - // NOTE(asaf): Don't clear on redirect - noticesCookie := http.Cookie{ - Name: NoticesCookieName, - Path: "/", - Domain: config.Config.Auth.CookieDomain, - MaxAge: -1, - } - res.SetCookie(¬icesCookie) - } -} - -func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string { - var builder strings.Builder - maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices. - size := 0 - for i, notice := range notices { - sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1 - if i != 0 { - sizeIncrease += 1 - } - if size+sizeIncrease > maxSize { - c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie") - break + if !c.CurrentProject.Personal { + return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) } - if i != 0 { - builder.WriteString("\t") + if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(c.CurrentProject.Name) { + return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther) } - builder.WriteString(notice.Class) - builder.WriteString("|") - builder.WriteString(string(notice.Content)) - size += sizeIncrease + return h(c) } - return builder.String() -} - -func deserializeNoticesFromCookie(cookieVal string) []templates.Notice { - var result []templates.Notice - notices := strings.Split(cookieVal, "\t") - for _, notice := range notices { - parts := strings.SplitN(notice, "|", 2) - if len(parts) == 2 { - result = append(result, templates.Notice{ - Class: parts[0], - Content: template.HTML(parts[1]), - }) - } - } - return result } diff --git a/src/website/routes_test.go b/src/website/routes_test.go index a452ef4..959c33d 100644 --- a/src/website/routes_test.go +++ b/src/website/routes_test.go @@ -31,7 +31,7 @@ func TestLogContextErrors(t *testing.T) { Middleware: func(h Handler) Handler { return func(c *RequestContext) (res ResponseData) { c.Logger = &logger - defer LogContextErrorsFromResponse(c, &res) + defer logContextErrorsMiddleware(c, &res) return h(c) } }, diff --git a/src/website/showcase.go b/src/website/showcase.go index fe3eff0..c98d2a8 100644 --- a/src/website/showcase.go +++ b/src/website/showcase.go @@ -16,7 +16,7 @@ type ShowcaseData struct { } func Showcase(c *RequestContext) ResponseData { - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{}) + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{}) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippets")) } diff --git a/src/website/snippet.go b/src/website/snippet.go index d3c4fde..83ce20d 100644 --- a/src/website/snippet.go +++ b/src/website/snippet.go @@ -30,7 +30,7 @@ func Snippet(c *RequestContext) ResponseData { return FourOhFour(c) } - s, err := hmndata.FetchSnippet(c.Context(), c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{}) + s, err := hmndata.FetchSnippet(c, c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{}) if err != nil { if errors.Is(err, db.NotFound) { return FourOhFour(c) diff --git a/src/website/twitch.go b/src/website/twitch.go index cd8ca7d..0af15c2 100644 --- a/src/website/twitch.go +++ b/src/website/twitch.go @@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData { } func TwitchDebugPage(c *RequestContext) ResponseData { - streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn, + streams, err := db.Query[models.TwitchStream](c, c.Conn, ` SELECT $columns FROM diff --git a/src/website/user.go b/src/website/user.go index 681e928..f749e2a 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -52,7 +52,7 @@ func UserProfile(c *RequestContext) ResponseData { if c.CurrentUser != nil && strings.ToLower(c.CurrentUser.Username) == username { profileUser = c.CurrentUser } else { - user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, c.CurrentUser, username, hmndata.UsersQuery{}) + user, err := hmndata.FetchUserByUsername(c, c.Conn, c.CurrentUser, username, hmndata.UsersQuery{}) if err != nil { if errors.Is(err, db.NotFound) { return FourOhFour(c) @@ -72,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch user links") - userLinks, err := db.Query[models.Link](c.Context(), c.Conn, + userLinks, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM @@ -92,7 +92,7 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.EndBlock() - projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ + projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: []int{profileUser.ID}, Lifecycles: models.AllProjectLifecycles, IncludeHidden: true, @@ -111,13 +111,13 @@ func UserProfile(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetch posts") - posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ + posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{ UserIDs: []int{profileUser.ID}, SortDescending: true, }) c.Perf.EndBlock() - snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ + snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{ OwnerIDs: []int{profileUser.ID}, }) if err != nil { @@ -125,7 +125,7 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch subforum tree") - subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) + subforumTree := models.GetFullSubforumTree(c, c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() @@ -213,7 +213,7 @@ func UserSettings(c *RequestContext) ResponseData { DiscordShowcaseBacklogUrl string } - links, err := db.Query[models.Link](c.Context(), c.Conn, + links, err := db.Query[models.Link](c, c.Conn, ` SELECT $columns FROM link @@ -230,7 +230,7 @@ func UserSettings(c *RequestContext) ResponseData { var tduser *templates.DiscordUser var numUnsavedMessages int - duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, + duser, err := db.QueryOne[models.DiscordUser](c, c.Conn, ` SELECT $columns FROM discord_user @@ -246,7 +246,7 @@ func UserSettings(c *RequestContext) ResponseData { tmp := templates.DiscordUserToTemplate(duser) tduser = &tmp - numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn, + numUnsavedMessages, err = db.QueryOneScalar[int](c, c.Conn, ` SELECT COUNT(*) FROM @@ -299,11 +299,11 @@ func UserSettingsSave(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse form")) } - tx, err := c.Conn.Begin(c.Context()) + tx, err := c.Conn.Begin(c) if err != nil { panic(err) } - defer tx.Rollback(c.Context()) + defer tx.Rollback(c) form, err := c.GetFormValues() if err != nil { @@ -315,7 +315,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { email := form.Get("email") if !hmnemail.IsEmail(email) { - return RejectRequest(c, "Your email was not valid.") + return c.RejectRequest("Your email was not valid.") } showEmail := form.Get("showemail") != "" @@ -328,7 +328,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { discordShowcaseAuto := form.Get("discord-showcase-auto") != "" discordDeleteSnippetOnMessageDelete := form.Get("discord-snippet-keep") == "" - _, err = tx.Exec(c.Context(), + _, err = tx.Exec(c, ` UPDATE hmn_user SET @@ -360,15 +360,15 @@ func UserSettingsSave(c *RequestContext) ResponseData { } // Process links - twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil) + twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil) linksText := form.Get("links") links := ParseLinks(linksText) - _, err = tx.Exec(c.Context(), `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID) + _, err = tx.Exec(c, `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID) if err != nil { c.Logger.Warn().Err(err).Msg("failed to delete old links") } else { for i, link := range links { - _, err := tx.Exec(c.Context(), + _, err := tx.Exec(c, ` INSERT INTO link (name, url, ordering, user_id) VALUES ($1, $2, $3, $4) @@ -384,7 +384,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { } } } - twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil) + twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil) if preErr == nil && postErr == nil { twitch.UserOrProjectLinksUpdated(twitchLoginsPreChange, twitchLoginsPostChange) } @@ -407,7 +407,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { } var avatarUUID *uuid.UUID if newAvatar.Exists { - avatarAsset, err := assets.Create(c.Context(), tx, assets.CreateInput{ + avatarAsset, err := assets.Create(c, tx, assets.CreateInput{ Content: newAvatar.Content, Filename: newAvatar.Filename, ContentType: newAvatar.Mime, @@ -421,7 +421,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { avatarUUID = &avatarAsset.ID } if newAvatar.Exists || newAvatar.Remove { - _, err := tx.Exec(c.Context(), + _, err := tx.Exec(c, ` UPDATE hmn_user SET @@ -437,7 +437,7 @@ func UserSettingsSave(c *RequestContext) ResponseData { } } - err = tx.Commit(c.Context()) + err = tx.Commit(c) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save user settings")) } @@ -454,7 +454,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData { userIdStr := c.Req.Form.Get("user_id") userId, err := strconv.Atoi(userIdStr) if err != nil { - return RejectRequest(c, "No user id provided") + return c.RejectRequest("No user id provided") } status := c.Req.Form.Get("status") @@ -469,10 +469,10 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData { case "banned": desiredStatus = models.UserStatusBanned default: - return RejectRequest(c, "No legal user status provided") + return c.RejectRequest("No legal user status provided") } - _, err = c.Conn.Exec(c.Context(), + _, err = c.Conn.Exec(c, ` UPDATE hmn_user SET status = $1 @@ -485,7 +485,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status")) } if desiredStatus == models.UserStatusBanned { - err = auth.DeleteSessionForUser(c.Context(), c.Conn, c.Req.Form.Get("username")) + err = auth.DeleteSessionForUser(c, c.Conn, c.Req.Form.Get("username")) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user")) } @@ -500,10 +500,10 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData { userIdStr := c.Req.Form.Get("user_id") userId, err := strconv.Atoi(userIdStr) if err != nil { - return RejectRequest(c, "No user id provided") + return c.RejectRequest("No user id provided") } - err = deleteAllPostsForUser(c.Context(), c.Conn, userId) + err = deleteAllPostsForUser(c, c.Conn, userId) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete user posts")) } @@ -514,7 +514,7 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData { func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *ResponseData { if new != confirm { - res := RejectRequest(c, "Your password and password confirmation did not match.") + res := c.RejectRequest("Your password and password confirmation did not match.") return &res } @@ -531,12 +531,12 @@ func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *Res } if !ok { - res := RejectRequest(c, "The old password you provided was not correct.") + res := c.RejectRequest("The old password you provided was not correct.") return &res } newHashedPassword := auth.HashPassword(new) - err = auth.UpdatePassword(c.Context(), tx, c.CurrentUser.Username, newHashedPassword) + err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword) if err != nil { res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password")) return &res diff --git a/src/website/website.go b/src/website/website.go index 2ed8337..6b30627 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -33,14 +33,13 @@ var WebsiteCommand = &cobra.Command{ logging.Info().Msg("Hello, HMN!") backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background()) - longRequestContext, cancelLongRequests := context.WithCancel(context.Background()) conn := db.NewConnPool() perfCollector := perf.RunPerfCollector(backgroundJobContext) server := http.Server{ Addr: config.Config.Addr, - Handler: NewWebsiteRoutes(longRequestContext, conn), + Handler: NewWebsiteRoutes(conn), } backgroundJobsDone := jobs.Zip( @@ -59,8 +58,6 @@ var WebsiteCommand = &cobra.Command{ <-signals logging.Info().Msg("Shutting down the website") go func() { - logging.Info().Msg("cancelling long requests") - cancelLongRequests() timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() logging.Info().Msg("shutting down web server")