Rework requests and middleware (#57)

o boy

Resolves #10 (hopefully!)

Co-authored-by: Ben Visness <bvisness@gmail.com>
Reviewed-on: #57
This commit is contained in:
bvisness 2022-06-24 21:38:11 +00:00
parent 32db9b1843
commit e9d4300100
27 changed files with 851 additions and 782 deletions

View File

@ -84,6 +84,8 @@ func UrlWithFragment(path string, query []Q, fragment string) string {
return HMNProjectContext.UrlWithFragment(path, query, fragment) return HMNProjectContext.UrlWithFragment(path, query, fragment)
} }
// Takes a project URL and rewrites it using the current URL context. This can be used
// to convert a personal project URL to official and vice versa.
func (c *UrlContext) RewriteProjectUrl(u *url.URL) string { func (c *UrlContext) RewriteProjectUrl(u *url.URL) string {
// we need to strip anything matching the personal project regex to get the base path // we need to strip anything matching the personal project regex to get the base path
match := RegexPersonalProject.FindString(u.Path) match := RegexPersonalProject.FindString(u.Path)

View File

@ -69,7 +69,7 @@ func AdminAtomFeed(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
@ -134,7 +134,7 @@ type unapprovedUserData struct {
func AdminApprovalQueue(c *RequestContext) ResponseData { func AdminApprovalQueue(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID) userIds = append(userIds, u.User.ID)
} }
userLinks, err := db.Query[models.Link](c.Context(), c.Conn, userLinks, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -253,13 +253,13 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id") userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr) userId, err := strconv.Atoi(userIdStr)
if err != nil { if err != nil {
return RejectRequest(c, "User id can't be parsed") return c.RejectRequest("User id can't be parsed")
} }
user, err := hmndata.FetchUser(c.Context(), c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{}) user, err := hmndata.FetchUser(c, c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{})
if err != nil { if err != nil {
if errors.Is(err, db.NotFound) { if errors.Is(err, db.NotFound) {
return RejectRequest(c, "User not found") return c.RejectRequest("User not found")
} else { } else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
} }
@ -267,7 +267,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
whatHappened := "" whatHappened := ""
if action == ApprovalQueueActionApprove { if action == ApprovalQueueActionApprove {
_, err := c.Conn.Exec(c.Context(), _, err := c.Conn.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -281,7 +281,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
} }
whatHappened = fmt.Sprintf("%s approved successfully", user.Username) whatHappened = fmt.Sprintf("%s approved successfully", user.Username)
} else if action == ApprovalQueueActionSpammer { } else if action == ApprovalQueueActionSpammer {
_, err := c.Conn.Exec(c.Context(), _, err := c.Conn.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -293,15 +293,15 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to set user to banned")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to set user to banned"))
} }
err = auth.DeleteSessionForUser(c.Context(), c.Conn, user.Username) err = auth.DeleteSessionForUser(c, c.Conn, user.Username)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
} }
err = deleteAllPostsForUser(c.Context(), c.Conn, user.ID) err = deleteAllPostsForUser(c, c.Conn, user.ID)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's posts")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's posts"))
} }
err = deleteAllProjectsForUser(c.Context(), c.Conn, user.ID) err = deleteAllProjectsForUser(c, c.Conn, user.ID)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's projects")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's projects"))
} }
@ -324,7 +324,7 @@ type UnapprovedPost struct {
} }
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn, posts, err := db.Query[UnapprovedPost](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -355,7 +355,7 @@ type UnapprovedProject struct {
} }
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn, ownerIDs, err := db.QueryScalar[int](c, c.Conn,
` `
SELECT id SELECT id
FROM FROM
@ -369,7 +369,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
return nil, oops.New(err, "failed to fetch unapproved users") return nil, oops.New(err, "failed to fetch unapproved users")
} }
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: ownerIDs, OwnerIDs: ownerIDs,
IncludeHidden: true, IncludeHidden: true,
}) })
@ -382,7 +382,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
projectIDs = append(projectIDs, p.Project.ID) projectIDs = append(projectIDs, p.Project.ID)
} }
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, projectLinks, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM

View File

@ -19,7 +19,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
requestedUsername := usernameArgs[0] requestedUsername := usernameArgs[0]
found = true found = true
c.Perf.StartBlock("SQL", "Fetch user") c.Perf.StartBlock("SQL", "Fetch user")
user, err := db.QueryOne[models.User](c.Context(), c.Conn, user, err := db.QueryOne[models.User](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -45,7 +45,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
var res ResponseData var res ResponseData
res.Header().Set("Content-Type", "application/json") res.Header().Set("Content-Type", "application/json")
AddCORSHeaders(c, &res) addCORSHeaders(c, &res)
if found { if found {
res.Write([]byte(fmt.Sprintf(`{ "found": true, "canonical": "%s" }`, canonicalUsername))) res.Write([]byte(fmt.Sprintf(`{ "found": true, "canonical": "%s" }`, canonicalUsername)))
} else { } else {

View File

@ -85,7 +85,7 @@ func AssetUpload(c *RequestContext) ResponseData {
} }
} }
asset, err := assets.Create(c.Context(), c.Conn, assets.CreateInput{ asset, err := assets.Create(c, c.Conn, assets.CreateInput{
Content: data, Content: data,
Filename: originalFilename, Filename: originalFilename,
ContentType: mimeType, ContentType: mimeType,

View File

@ -28,7 +28,7 @@ type LoginPageData struct {
func LoginPage(c *RequestContext) ResponseData { func LoginPage(c *RequestContext) ResponseData {
if c.CurrentUser != nil { if c.CurrentUser != nil {
return RejectRequest(c, "You are already logged in.") return c.RejectRequest("You are already logged in.")
} }
var res ResponseData var res ResponseData
@ -75,7 +75,7 @@ func Login(c *RequestContext) ResponseData {
return res return res
} }
user, err := db.QueryOne[models.User](c.Context(), c.Conn, user, err := db.QueryOne[models.User](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM hmn_user FROM hmn_user
@ -102,7 +102,7 @@ func Login(c *RequestContext) ResponseData {
} }
if user.Status == models.UserStatusInactive { if user.Status == models.UserStatusInactive {
return RejectRequest(c, "You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.") return c.RejectRequest("You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.")
} }
res := c.Redirect(redirect, http.StatusSeeOther) res := c.Redirect(redirect, http.StatusSeeOther)
@ -136,7 +136,7 @@ func RegisterNewUser(c *RequestContext) ResponseData {
func RegisterNewUserSubmit(c *RequestContext) ResponseData { func RegisterNewUserSubmit(c *RequestContext) ResponseData {
if c.CurrentUser != nil { if c.CurrentUser != nil {
return RejectRequest(c, "Can't register new user. You are already logged in") return c.RejectRequest("Can't register new user. You are already logged in")
} }
c.Req.ParseForm() c.Req.ParseForm()
@ -146,16 +146,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
password := c.Req.Form.Get("password") password := c.Req.Form.Get("password")
password2 := c.Req.Form.Get("password2") password2 := c.Req.Form.Get("password2")
if !UsernameRegex.Match([]byte(username)) { if !UsernameRegex.Match([]byte(username)) {
return RejectRequest(c, "Invalid username") return c.RejectRequest("Invalid username")
} }
if !email.IsEmail(emailAddress) { if !email.IsEmail(emailAddress) {
return RejectRequest(c, "Invalid email address") return c.RejectRequest("Invalid email address")
} }
if len(password) < 8 { if len(password) < 8 {
return RejectRequest(c, "Password too short") return c.RejectRequest("Password too short")
} }
if password != password2 { if password != password2 {
return RejectRequest(c, "Password confirmation doesn't match password") return c.RejectRequest("Password confirmation doesn't match password")
} }
c.Perf.StartBlock("SQL", "Check blacklist") c.Perf.StartBlock("SQL", "Check blacklist")
@ -169,7 +169,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Check for existing usernames and emails") c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
userAlreadyExists := true userAlreadyExists := true
_, err := db.QueryOneScalar[int](c.Context(), c.Conn, _, err := db.QueryOneScalar[int](c, c.Conn,
` `
SELECT id SELECT id
FROM hmn_user FROM hmn_user
@ -186,11 +186,11 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
} }
if userAlreadyExists { if userAlreadyExists {
return RejectRequest(c, fmt.Sprintf("Username (%s) already exists.", username)) return c.RejectRequest(fmt.Sprintf("Username (%s) already exists.", username))
} }
emailAlreadyExists := true emailAlreadyExists := true
_, err = db.QueryOneScalar[int](c.Context(), c.Conn, _, err = db.QueryOneScalar[int](c, c.Conn,
` `
SELECT id SELECT id
FROM hmn_user FROM hmn_user
@ -215,16 +215,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
hashed := auth.HashPassword(password) hashed := auth.HashPassword(password)
c.Perf.StartBlock("SQL", "Create user and one time token") c.Perf.StartBlock("SQL", "Create user and one time token")
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
now := time.Now() now := time.Now()
var newUserId int var newUserId int
err = tx.QueryRow(c.Context(), err = tx.QueryRow(c,
` `
INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip) INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
@ -237,7 +237,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
} }
ott := models.GenerateToken() ott := models.GenerateToken()
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id) INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
VALUES($1, $2, $3, $4, $5) VALUES($1, $2, $3, $4, $5)
@ -263,7 +263,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Commit user") c.Perf.StartBlock("SQL", "Commit user")
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit user to the db")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit user to the db"))
} }
@ -302,7 +302,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
username, hasUsername := c.PathParams["username"] username, hasUsername := c.PathParams["username"]
if !hasUsername { if !hasUsername {
return RejectRequest(c, "Bad validation url") return c.RejectRequest("Bad validation url")
} }
token := "" token := ""
@ -319,7 +319,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
} }
if !hasToken { if !hasToken {
return RejectRequest(c, "Bad validation url") return c.RejectRequest("Bad validation url")
} }
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypeRegistration) validationResult := validateUsernameAndToken(c, username, token, models.TokenTypeRegistration)
@ -366,13 +366,13 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Updating user status and deleting token") c.Perf.StartBlock("SQL", "Updating user status and deleting token")
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -385,7 +385,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
} }
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM one_time_token WHERE id = $1 DELETE FROM one_time_token WHERE id = $1
`, `,
@ -395,7 +395,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete one time token")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete one time token"))
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit transaction"))
} }
@ -413,7 +413,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
// NOTE(asaf): Only call this when validationResult.Match is false. // NOTE(asaf): Only call this when validationResult.Match is false.
func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, validationResult validateUserAndTokenResult) ResponseData { func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, validationResult validateUserAndTokenResult) ResponseData {
if validationResult.User == nil { if validationResult.User == nil {
return RejectRequest(c, "You haven't validated your email in time and your user was deleted. You may try registering again with the same username.") return c.RejectRequest("You haven't validated your email in time and your user was deleted. You may try registering again with the same username.")
} }
if validationResult.OneTimeToken == nil { if validationResult.OneTimeToken == nil {
@ -422,7 +422,7 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali
return c.Redirect(hmnurl.BuildLoginPage(""), http.StatusSeeOther) return c.Redirect(hmnurl.BuildLoginPage(""), http.StatusSeeOther)
} }
return RejectRequest(c, "Bad token. If you are having problems registering or logging in, please contact the staff.") return c.RejectRequest("Bad token. If you are having problems registering or logging in, please contact the staff.")
} }
// NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email, // NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email,
@ -446,14 +446,14 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
emailAddress := strings.TrimSpace(c.Req.Form.Get("email")) emailAddress := strings.TrimSpace(c.Req.Form.Get("email"))
if username == "" && emailAddress == "" { if username == "" && emailAddress == "" {
return RejectRequest(c, "You must provide a username and an email address.") return c.RejectRequest("You must provide a username and an email address.")
} }
c.Perf.StartBlock("SQL", "Fetching user") c.Perf.StartBlock("SQL", "Fetching user")
type userQuery struct { type userQuery struct {
User models.User `db:"hmn_user"` User models.User `db:"hmn_user"`
} }
user, err := db.QueryOne[models.User](c.Context(), c.Conn, user, err := db.QueryOne[models.User](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM hmn_user FROM hmn_user
@ -473,7 +473,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if user != nil { if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token") c.Perf.StartBlock("SQL", "Fetching existing token")
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, resetToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM one_time_token FROM one_time_token
@ -495,7 +495,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken != nil { if resetToken != nil {
if resetToken.Expires.Before(now.Add(time.Minute * 30)) { // NOTE(asaf): Expired or about to expire if resetToken.Expires.Before(now.Add(time.Minute * 30)) { // NOTE(asaf): Expired or about to expire
c.Perf.StartBlock("SQL", "Deleting expired token") c.Perf.StartBlock("SQL", "Deleting expired token")
_, err = c.Conn.Exec(c.Context(), _, err = c.Conn.Exec(c,
` `
DELETE FROM one_time_token DELETE FROM one_time_token
WHERE id = $1 WHERE id = $1
@ -512,7 +512,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil { if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token") c.Perf.StartBlock("SQL", "Creating new token")
newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, newToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
` `
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id) INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
@ -567,12 +567,12 @@ func DoPasswordReset(c *RequestContext) ResponseData {
token, hasToken := c.PathParams["token"] token, hasToken := c.PathParams["token"]
if !hasToken || !hasUsername { if !hasToken || !hasUsername {
return RejectRequest(c, "Bad validation url.") return c.RejectRequest("Bad validation url.")
} }
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset) validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
if !validationResult.Match { if !validationResult.Match {
return RejectRequest(c, "Bad validation url.") return c.RejectRequest("Bad validation url.")
} }
var res ResponseData var res ResponseData
@ -601,30 +601,30 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset) validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
if !validationResult.Match { if !validationResult.Match {
return RejectRequest(c, "Bad validation url.") return c.RejectRequest("Bad validation url.")
} }
if c.CurrentUser != nil && c.CurrentUser.ID != validationResult.User.ID { if c.CurrentUser != nil && c.CurrentUser.ID != validationResult.User.ID {
return RejectRequest(c, fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username)) return c.RejectRequest(fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username))
} }
if len(password) < 8 { if len(password) < 8 {
return RejectRequest(c, "Password too short") return c.RejectRequest("Password too short")
} }
if password != password2 { if password != password2 {
return RejectRequest(c, "Password confirmation doesn't match password") return c.RejectRequest("Password confirmation doesn't match password")
} }
hashed := auth.HashPassword(password) hashed := auth.HashPassword(password)
c.Perf.StartBlock("SQL", "Update user's password and delete reset token") c.Perf.StartBlock("SQL", "Update user's password and delete reset token")
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
tag, err := tx.Exec(c.Context(), tag, err := tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET password = $1 SET password = $1
@ -638,7 +638,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
} }
if validationResult.User.Status == models.UserStatusInactive { if validationResult.User.Status == models.UserStatusInactive {
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -652,7 +652,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
} }
} }
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM one_time_token DELETE FROM one_time_token
WHERE id = $1 WHERE id = $1
@ -663,7 +663,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete onetimetoken")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete onetimetoken"))
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit password reset to the db")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit password reset to the db"))
} }
@ -698,7 +698,7 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
// re-hash and save the user's password if necessary // re-hash and save the user's password if necessary
if hashed.IsOutdated() { if hashed.IsOutdated() {
newHashed := auth.HashPassword(password) newHashed := auth.HashPassword(password)
err := auth.UpdatePassword(c.Context(), c.Conn, user.Username, newHashed) err := auth.UpdatePassword(c, c.Conn, user.Username, newHashed)
if err != nil { if err != nil {
c.Logger.Error().Err(err).Msg("failed to update user's password") c.Logger.Error().Err(err).Msg("failed to update user's password")
} }
@ -711,15 +711,15 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) error { func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) error {
c.Perf.StartBlock("SQL", "Setting last login and creating session") c.Perf.StartBlock("SQL", "Setting last login and creating session")
defer c.Perf.EndBlock() defer c.Perf.EndBlock()
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return oops.New(err, "failed to start db transaction") return oops.New(err, "failed to start db transaction")
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
now := time.Now() now := time.Now()
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET last_login = $1 SET last_login = $1
@ -732,12 +732,12 @@ func loginUser(c *RequestContext, user *models.User, responseData *ResponseData)
return oops.New(err, "failed to update last_login for user") return oops.New(err, "failed to update last_login for user")
} }
session, err := auth.CreateSession(c.Context(), c.Conn, user.Username) session, err := auth.CreateSession(c, c.Conn, user.Username)
if err != nil { if err != nil {
return oops.New(err, "failed to create session") return oops.New(err, "failed to create session")
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return oops.New(err, "failed to commit transaction") return oops.New(err, "failed to commit transaction")
} }
@ -749,7 +749,7 @@ func logoutUser(c *RequestContext, res *ResponseData) {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil { if err == nil {
// clear the session from the db immediately, no expiration // clear the session from the db immediately, no expiration
err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value) err := auth.DeleteSession(c, c.Conn, sessionCookie.Value)
if err != nil { if err != nil {
logging.Error().Err(err).Msg("failed to delete session on logout") logging.Error().Err(err).Msg("failed to delete session on logout")
} }
@ -772,7 +772,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"hmn_user"` User models.User `db:"hmn_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
} }
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, data, err := db.QueryOne[userAndTokenQuery](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM hmn_user FROM hmn_user

View File

@ -37,7 +37,7 @@ func BlogIndex(c *RequestContext) ResponseData {
const postsPerPage = 20 const postsPerPage = 20
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -51,7 +51,7 @@ func BlogIndex(c *RequestContext) ResponseData {
return c.Redirect(c.UrlContext.BuildBlog(page), http.StatusSeeOther) return c.Redirect(c.UrlContext.BuildBlog(page), http.StatusSeeOther)
} }
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
Limit: postsPerPage, Limit: postsPerPage,
@ -78,7 +78,7 @@ func BlogIndex(c *RequestContext) ResponseData {
canCreate := false canCreate := false
if c.CurrentProject.HasBlog() && c.CurrentUser != nil { if c.CurrentProject.HasBlog() && c.CurrentUser != nil {
isProjectOwner := false isProjectOwner := false
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID) owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project owners")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project owners"))
} }
@ -128,7 +128,7 @@ func BlogThread(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
thread, posts, err := hmndata.FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{ thread, posts, err := hmndata.FetchThreadPosts(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -155,7 +155,7 @@ func BlogThread(c *RequestContext) ResponseData {
// Update thread last read info // Update thread last read info
if c.CurrentUser != nil { if c.CurrentUser != nil {
c.Perf.StartBlock("SQL", "Update TLRI") c.Perf.StartBlock("SQL", "Update TLRI")
_, err := c.Conn.Exec(c.Context(), _, err := c.Conn.Exec(c,
` `
INSERT INTO thread_last_read_info (thread_id, user_id, lastread) INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
@ -196,7 +196,7 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -227,11 +227,11 @@ func BlogNewThread(c *RequestContext) ResponseData {
} }
func BlogNewThreadSubmit(c *RequestContext) ResponseData { func BlogNewThreadSubmit(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
err = c.Req.ParseForm() err = c.Req.ParseForm()
if err != nil { if err != nil {
@ -240,15 +240,15 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
title := c.Req.Form.Get("title") title := c.Req.Form.Get("title")
unparsed := c.Req.Form.Get("body") unparsed := c.Req.Form.Get("body")
if title == "" { if title == "" {
return RejectRequest(c, "You must provide a title for your post.") return c.RejectRequest("You must provide a title for your post.")
} }
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.") return c.RejectRequest("You must provide a body for your post.")
} }
// Create thread // Create thread
var threadId int var threadId int
err = tx.QueryRow(c.Context(), err = tx.QueryRow(c,
` `
INSERT INTO thread (title, type, project_id, first_id, last_id) INSERT INTO thread (title, type, project_id, first_id, last_id)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
@ -265,9 +265,9 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
} }
// Create everything else // Create everything else
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new blog post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new blog post"))
} }
@ -282,11 +282,11 @@ func BlogPostEdit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -326,17 +326,17 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -351,16 +351,16 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
unparsed := c.Req.Form.Get("body") unparsed := c.Req.Form.Get("body")
editReason := c.Req.Form.Get("editreason") editReason := c.Req.Form.Get("editreason")
if title != "" && post.Thread.FirstID != post.Post.ID { if title != "" && post.Thread.FirstID != post.Post.ID {
return RejectRequest(c, "You can only edit the title by editing the first post.") return c.RejectRequest("You can only edit the title by editing the first post.")
} }
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "You must provide a post body.") return c.RejectRequest("You must provide a post body.")
} }
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
if title != "" { if title != "" {
_, err := tx.Exec(c.Context(), _, err := tx.Exec(c,
` `
UPDATE thread SET title = $1 WHERE id = $2 UPDATE thread SET title = $1 WHERE id = $2
`, `,
@ -372,7 +372,7 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
} }
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post"))
} }
@ -387,7 +387,7 @@ func BlogPostReply(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -421,11 +421,11 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
err = c.Req.ParseForm() err = c.Req.ParseForm()
if err != nil { if err != nil {
@ -433,12 +433,12 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
} }
unparsed := c.Req.Form.Get("body") unparsed := c.Req.Form.Get("body")
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "Your reply cannot be empty.") return c.RejectRequest("Your reply cannot be empty.")
} }
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host) newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to blog post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to blog post"))
} }
@ -453,11 +453,11 @@ func BlogPostDelete(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -503,19 +503,19 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID) threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
} }
@ -523,7 +523,7 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
if threadDeleted { if threadDeleted {
return c.Redirect(c.UrlContext.BuildHomepage(), http.StatusSeeOther) return c.Redirect(c.UrlContext.BuildHomepage(), http.StatusSeeOther)
} else { } else {
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{ thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
}) })
@ -560,7 +560,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
res.ThreadID = threadId res.ThreadID = threadId
c.Perf.StartBlock("SQL", "Verify that the thread exists") c.Perf.StartBlock("SQL", "Verify that the thread exists")
threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, threadExists, err := db.QueryOneScalar[bool](c, c.Conn,
` `
SELECT COUNT(*) > 0 SELECT COUNT(*) > 0
FROM thread FROM thread
@ -588,7 +588,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
res.PostID = postId res.PostID = postId
c.Perf.StartBlock("SQL", "Verify that the post exists") c.Perf.StartBlock("SQL", "Verify that the post exists")
postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, postExists, err := db.QueryOneScalar[bool](c, c.Conn,
` `
SELECT COUNT(*) > 0 SELECT COUNT(*) > 0
FROM post FROM post

138
src/website/common.go Normal file
View File

@ -0,0 +1,138 @@
package website
import (
"errors"
"net/http"
"net/url"
"strings"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/hmndata"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates"
)
func loadCommonData(h Handler) Handler {
return func(c *RequestContext) ResponseData {
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
{
// get user
{
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
}
c.CurrentUser = user
c.CurrentSession = session
}
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
}
// get current official project (HMN or otherwise, by subdomain)
{
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost())
slug := strings.TrimRight(hostPrefix, ".")
var owners []*models.User
if len(slug) > 0 {
dbProject, err := hmndata.FetchProjectBySlug(c, c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err == nil {
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
owners = dbProject.Owners
} else {
if errors.Is(err, db.NotFound) {
// do nothing, this is fine
} else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
}
}
if c.CurrentProject == nil {
dbProject, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
panic(oops.New(err, "failed to fetch HMN project"))
}
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
}
if c.CurrentProject == nil {
panic("failed to load project data")
}
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
}
c.Theme = "light"
if c.CurrentUser != nil && c.CurrentUser.DarkTheme {
c.Theme = "dark"
}
}
c.Perf.EndBlock()
return h(c)
}
}
// Given a session id, fetches user data from the database. Will return nil if
// the user cannot be found, and will only return an error if it's serious.
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
session, err := auth.GetSession(c, c.Conn, sessionId)
if err != nil {
if errors.Is(err, auth.ErrNoSession) {
return nil, nil, nil
} else {
return nil, nil, oops.New(err, "failed to get current session")
}
}
user, err := hmndata.FetchUserByUsername(c, c.Conn, nil, session.Username, hmndata.UsersQuery{
AnyStatus: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
return nil, nil, nil // user was deleted or something
} else {
return nil, nil, oops.New(err, "failed to get user for session")
}
}
return user, session, nil
}
func addCORSHeaders(c *RequestContext, res *ResponseData) {
parsed, err := url.Parse(config.Config.BaseUrl)
if err != nil {
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
return
}
origin := ""
origins, found := c.Req.Header["Origin"]
if found {
origin = origins[0]
}
if strings.HasSuffix(origin, parsed.Host) {
res.Header().Add("Access-Control-Allow-Origin", origin)
res.Header().Add("Access-Control-Allow-Credentials", "true")
res.Header().Add("Vary", "Origin")
}
}

View File

@ -35,31 +35,31 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
// This occurs when the user cancels. Just go back to the profile page. // This occurs when the user cancels. Just go back to the profile page.
return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther)
} else { } else {
return RejectRequest(c, "Failed to authenticate with Discord.") return c.RejectRequest("Failed to authenticate with Discord.")
} }
} }
// Do the actual token exchange // Do the actual token exchange
code := query.Get("code") code := query.Get("code")
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback())
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
} }
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second) expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken) user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
} }
// Add the role on Discord // Add the role on Discord
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID) err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
} }
// Add the user to our database // Add the user to our database
_, err = c.Conn.Exec(c.Context(), _, err = c.Conn.Exec(c,
` `
INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id) INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
@ -79,7 +79,7 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
} }
if c.CurrentUser.Status == models.UserStatusConfirmed { if c.CurrentUser.Status == models.UserStatusConfirmed {
_, err = c.Conn.Exec(c.Context(), _, err = c.Conn.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -98,13 +98,13 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
} }
func DiscordUnlink(c *RequestContext) ResponseData { func DiscordUnlink(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx, discordUser, err := db.QueryOne[models.DiscordUser](c, tx,
` `
SELECT $columns SELECT $columns
FROM discord_user FROM discord_user
@ -120,7 +120,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
} }
} }
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM discord_user DELETE FROM discord_user
WHERE id = $1 WHERE id = $1
@ -131,12 +131,12 @@ func DiscordUnlink(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete Discord user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete Discord user"))
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit Discord user delete")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit Discord user delete"))
} }
err = discord.RemoveGuildMemberRole(c.Context(), discordUser.UserID, config.Config.Discord.MemberRoleID) err = discord.RemoveGuildMemberRole(c, discordUser.UserID, config.Config.Discord.MemberRoleID)
if err != nil { if err != nil {
c.Logger.Warn().Err(err).Msg("failed to remove member role on unlink") c.Logger.Warn().Err(err).Msg("failed to remove member role on unlink")
} }
@ -145,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
} }
func DiscordShowcaseBacklog(c *RequestContext) ResponseData { func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
`SELECT $columns FROM discord_user WHERE hmn_user_id = $1`, `SELECT $columns FROM discord_user WHERE hmn_user_id = $1`,
c.CurrentUser.ID, c.CurrentUser.ID,
) )
@ -157,7 +157,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
} }
msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn, msgIDs, err := db.QueryScalar[string](c, c.Conn,
` `
SELECT msg.id SELECT msg.id
FROM FROM
@ -174,12 +174,12 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
} }
for _, msgID := range msgIDs { for _, msgID := range msgIDs {
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID) interned, err := discord.FetchInternedMessage(c, c.Conn, msgID)
if err != nil && !errors.Is(err, db.NotFound) { if err != nil && !errors.Is(err, db.NotFound) {
return c.ErrorResponse(http.StatusInternalServerError, err) return c.ErrorResponse(http.StatusInternalServerError, err)
} else if err == nil { } else if err == nil {
// NOTE(asaf): Creating snippet even if the checkbox is off because the user asked us to. // NOTE(asaf): Creating snippet even if the checkbox is off because the user asked us to.
err = discord.HandleSnippetForInternedMessage(c.Context(), c.Conn, interned, true) err = discord.HandleSnippetForInternedMessage(c, c.Conn, interned, true)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err) return c.ErrorResponse(http.StatusInternalServerError, err)
} }

View File

@ -1,6 +1,31 @@
package website package website
import "fmt" import (
"fmt"
"net/http"
"strings"
"git.handmade.network/hmn/hmn/src/templates"
)
func FourOhFour(c *RequestContext) ResponseData {
var res ResponseData
res.StatusCode = http.StatusNotFound
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
templateData := struct {
templates.BaseData
Wanted string
}{
BaseData: getBaseData(c, "Page not found", nil),
Wanted: c.FullUrl(),
}
res.MustWriteTemplate("404.html", templateData, c.Perf)
} else {
res.Write([]byte("Not Found"))
}
return res
}
// A SafeError can be used to wrap another error and explicitly provide // A SafeError can be used to wrap another error and explicitly provide
// an error message that is safe to show to a user. This allows the original // an error message that is safe to show to a user. This allows the original

View File

@ -33,7 +33,7 @@ var feedThreadTypes = []models.ThreadType{
} }
func Feed(c *RequestContext) ResponseData { func Feed(c *RequestContext) ResponseData {
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes, ThreadTypes: feedThreadTypes,
}) })
if err != nil { if err != nil {
@ -156,7 +156,7 @@ func AtomFeed(c *RequestContext) ResponseData {
if hasAll { if hasAll {
itemsPerFeed = 100000 itemsPerFeed = 100000
} }
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, nil, hmndata.ProjectsQuery{ projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, nil, hmndata.ProjectsQuery{
Limit: itemsPerFeed, Limit: itemsPerFeed,
Types: hmndata.OfficialProjects, Types: hmndata.OfficialProjects,
OrderBy: "date_approved DESC", OrderBy: "date_approved DESC",
@ -188,7 +188,7 @@ func AtomFeed(c *RequestContext) ResponseData {
feedData.AtomFeedUrl = hmnurl.BuildAtomFeedForShowcase() feedData.AtomFeedUrl = hmnurl.BuildAtomFeedForShowcase()
feedData.FeedUrl = hmnurl.BuildShowcase() feedData.FeedUrl = hmnurl.BuildShowcase()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Limit: itemsPerFeed, Limit: itemsPerFeed,
}) })
if err != nil { if err != nil {
@ -215,7 +215,7 @@ func AtomFeed(c *RequestContext) ResponseData {
} }
func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostListItem, error) { func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostListItem, error) {
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes, ThreadTypes: feedThreadTypes,
Limit: limit, Limit: limit,
Offset: offset, Offset: offset,
@ -225,7 +225,7 @@ func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostLi
return nil, err return nil, err
} }
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()

View File

@ -206,7 +206,7 @@ func Fishbowl(c *RequestContext) ResponseData {
func FishbowlFiles(c *RequestContext) ResponseData { func FishbowlFiles(c *RequestContext) ResponseData {
var res ResponseData var res ResponseData
fishbowlHTTPFS.ServeHTTP(&res, c.Req) fishbowlHTTPFS.ServeHTTP(&res, c.Req)
AddCORSHeaders(c, &res) addCORSHeaders(c, &res)
return res return res
} }
@ -224,7 +224,7 @@ func linkifyDiscordContent(c *RequestContext, dbConn db.ConnOrTx, content string
discordUserIds = append(discordUserIds, id) discordUserIds = append(discordUserIds, id)
} }
hmnUsers, err := hmndata.FetchUsers(c.Context(), dbConn, c.CurrentUser, hmndata.UsersQuery{ hmnUsers, err := hmndata.FetchUsers(c, dbConn, c.CurrentUser, hmndata.UsersQuery{
DiscordUserIDs: discordUserIds, DiscordUserIDs: discordUserIds,
}) })
if err != nil { if err != nil {

View File

@ -91,7 +91,7 @@ func Forum(c *RequestContext) ResponseData {
currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID)
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{cd.SubforumID}, SubforumIDs: []int{cd.SubforumID},
@ -107,7 +107,7 @@ func Forum(c *RequestContext) ResponseData {
} }
howManyThreadsToSkip := (page - 1) * threadsPerPage howManyThreadsToSkip := (page - 1) * threadsPerPage
mainThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ mainThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{cd.SubforumID}, SubforumIDs: []int{cd.SubforumID},
@ -141,7 +141,7 @@ func Forum(c *RequestContext) ResponseData {
subforumNodes := cd.SubforumTree[cd.SubforumID].Children subforumNodes := cd.SubforumTree[cd.SubforumID].Children
for _, sfNode := range subforumNodes { for _, sfNode := range subforumNodes {
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{sfNode.ID}, SubforumIDs: []int{sfNode.ID},
@ -150,7 +150,7 @@ func Forum(c *RequestContext) ResponseData {
panic(oops.New(err, "failed to get count of threads")) panic(oops.New(err, "failed to get count of threads"))
} }
subforumThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ subforumThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{sfNode.ID}, SubforumIDs: []int{sfNode.ID},
@ -203,7 +203,7 @@ func Forum(c *RequestContext) ResponseData {
func ForumMarkRead(c *RequestContext) ResponseData { func ForumMarkRead(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
@ -212,16 +212,16 @@ func ForumMarkRead(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
sfIds := []int{sfId} sfIds := []int{sfId}
if sfId == 0 { if sfId == 0 {
// Mark literally everything as read // Mark literally everything as read
_, err := tx.Exec(c.Context(), _, err := tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET marked_all_read_at = NOW() SET marked_all_read_at = NOW()
@ -234,7 +234,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
} }
// Delete thread unread info // Delete thread unread info
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM thread_last_read_info DELETE FROM thread_last_read_info
WHERE user_id = $1; WHERE user_id = $1;
@ -246,7 +246,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
} }
// Delete subforum unread info // Delete subforum unread info
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM subforum_last_read_info DELETE FROM subforum_last_read_info
WHERE user_id = $1; WHERE user_id = $1;
@ -258,7 +258,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
} }
} else { } else {
c.Perf.StartBlock("SQL", "Update SLRIs") c.Perf.StartBlock("SQL", "Update SLRIs")
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
INSERT INTO subforum_last_read_info (subforum_id, user_id, lastread) INSERT INTO subforum_last_read_info (subforum_id, user_id, lastread)
SELECT id, $2, $3 SELECT id, $2, $3
@ -277,7 +277,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Delete TLRIs") c.Perf.StartBlock("SQL", "Delete TLRIs")
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
DELETE FROM thread_last_read_info DELETE FROM thread_last_read_info
WHERE WHERE
@ -298,7 +298,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
} }
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit SLRI/TLRI updates")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit SLRI/TLRI updates"))
} }
@ -332,7 +332,7 @@ func ForumThread(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadIDs: []int{cd.ThreadID}, ThreadIDs: []int{cd.ThreadID},
}) })
@ -351,7 +351,7 @@ func ForumThread(c *RequestContext) ResponseData {
return c.Redirect(correctThreadUrl, http.StatusSeeOther) return c.Redirect(correctThreadUrl, http.StatusSeeOther)
} }
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
ThreadIDs: []int{cd.ThreadID}, ThreadIDs: []int{cd.ThreadID},
@ -374,7 +374,7 @@ func ForumThread(c *RequestContext) ResponseData {
PreviousUrl: c.UrlContext.BuildForumThread(currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)), PreviousUrl: c.UrlContext.BuildForumThread(currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)),
} }
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadIDs: []int{thread.ID}, ThreadIDs: []int{thread.ID},
Limit: threadPostsPerPage, Limit: threadPostsPerPage,
@ -396,7 +396,7 @@ func ForumThread(c *RequestContext) ResponseData {
post.ReplyPost = &reply post.ReplyPost = &reply
} }
addAuthorCountsToPost(c.Context(), c.Conn, &post) addAuthorCountsToPost(c, c.Conn, &post)
posts = append(posts, post) posts = append(posts, post)
} }
@ -404,7 +404,7 @@ func ForumThread(c *RequestContext) ResponseData {
// Update thread last read info // Update thread last read info
if c.CurrentUser != nil { if c.CurrentUser != nil {
c.Perf.StartBlock("SQL", "Update TLRI") c.Perf.StartBlock("SQL", "Update TLRI")
_, err = c.Conn.Exec(c.Context(), _, err = c.Conn.Exec(c,
` `
INSERT INTO thread_last_read_info (thread_id, user_id, lastread) INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
@ -445,7 +445,7 @@ func ForumPostRedirect(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
ThreadIDs: []int{cd.ThreadID}, ThreadIDs: []int{cd.ThreadID},
@ -495,11 +495,11 @@ func ForumNewThread(c *RequestContext) ResponseData {
} }
func ForumNewThreadSubmit(c *RequestContext) ResponseData { func ForumNewThreadSubmit(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
cd, ok := getCommonForumData(c) cd, ok := getCommonForumData(c)
if !ok { if !ok {
@ -517,15 +517,15 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
sticky = true sticky = true
} }
if title == "" { if title == "" {
return RejectRequest(c, "You must provide a title for your post.") return c.RejectRequest("You must provide a title for your post.")
} }
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.") return c.RejectRequest("You must provide a body for your post.")
} }
// Create thread // Create thread
var threadId int var threadId int
err = tx.QueryRow(c.Context(), err = tx.QueryRow(c,
` `
INSERT INTO thread (title, sticky, type, project_id, subforum_id, first_id, last_id) INSERT INTO thread (title, sticky, type, project_id, subforum_id, first_id, last_id)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7)
@ -544,9 +544,9 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
} }
// Create everything else // Create everything else
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host) hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new forum thread")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new forum thread"))
} }
@ -561,7 +561,7 @@ func ForumPostReply(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
}) })
@ -600,11 +600,11 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
err = c.Req.ParseForm() err = c.Req.ParseForm()
if err != nil { if err != nil {
@ -612,10 +612,10 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
} }
unparsed := c.Req.Form.Get("body") unparsed := c.Req.Form.Get("body")
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "Your reply cannot be empty.") return c.RejectRequest("Your reply cannot be empty.")
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
}) })
@ -629,9 +629,9 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
replyPostId = &post.Post.ID replyPostId = &post.Post.ID
} }
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host) newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to forum post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to forum post"))
} }
@ -646,11 +646,11 @@ func ForumPostEdit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
}) })
@ -688,17 +688,17 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
}) })
@ -713,16 +713,16 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
unparsed := c.Req.Form.Get("body") unparsed := c.Req.Form.Get("body")
editReason := c.Req.Form.Get("editreason") editReason := c.Req.Form.Get("editreason")
if title != "" && post.Thread.FirstID != post.Post.ID { if title != "" && post.Thread.FirstID != post.Post.ID {
return RejectRequest(c, "You can only edit the title by editing the first post.") return c.RejectRequest("You can only edit the title by editing the first post.")
} }
if unparsed == "" { if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.") return c.RejectRequest("You must provide a body for your post.")
} }
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
if title != "" { if title != "" {
_, err := tx.Exec(c.Context(), _, err := tx.Exec(c,
` `
UPDATE thread SET title = $1 WHERE id = $2 UPDATE thread SET title = $1 WHERE id = $2
`, `,
@ -734,7 +734,7 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
} }
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit forum post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit forum post"))
} }
@ -749,11 +749,11 @@ func ForumPostDelete(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{ post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
}) })
@ -798,19 +798,19 @@ func ForumPostDeleteSubmit(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) { if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c) return FourOhFour(c)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID) threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
} }
@ -831,7 +831,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
panic(err) panic(err)
} }
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{ thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
// This is the rare query where we want all thread types! // This is the rare query where we want all thread types!
}) })
@ -842,7 +842,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
@ -874,7 +874,7 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) {
defer c.Perf.EndBlock() defer c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()

View File

@ -89,7 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
img.Seek(0, io.SeekStart) img.Seek(0, io.SeekStart)
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
sha1sum := hasher.Sum(nil) sha1sum := hasher.Sum(nil)
imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn, imageFile, err := db.QueryOne[models.ImageFile](c, dbConn,
` `
INSERT INTO image_file (file, size, sha1sum, protected, width, height) INSERT INTO image_file (file, size, sha1sum, protected, width, height)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)

View File

@ -51,7 +51,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
} }
tagId := -1 tagId := -1
jamTag, err := hmndata.FetchTag(c.Context(), c.Conn, hmndata.TagQuery{ jamTag, err := hmndata.FetchTag(c, c.Conn, hmndata.TagQuery{
Text: []string{"wheeljam"}, Text: []string{"wheeljam"},
}) })
if err == nil { if err == nil {
@ -60,7 +60,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
c.Logger.Warn().Err(err).Msg("failed to fetch jam tag; will fetch all snippets as a result") c.Logger.Warn().Err(err).Msg("failed to fetch jam tag; will fetch all snippets as a result")
} }
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Tags: []int{tagId}, Tags: []int{tagId},
}) })
if err != nil { if err != nil {

View File

@ -34,13 +34,13 @@ type LandingTemplateData struct {
func Index(c *RequestContext) ResponseData { func Index(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
var timelineItems []templates.TimelineItem var timelineItems []templates.TimelineItem
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes, ThreadTypes: feedThreadTypes,
}) })
if err != nil { if err != nil {
@ -65,7 +65,7 @@ func Index(c *RequestContext) ResponseData {
} }
// This is essentially an alternate for feed page 1. // This is essentially an alternate for feed page 1.
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes, ThreadTypes: feedThreadTypes,
Limit: feedPostsPerPage, Limit: feedPostsPerPage,
SortDescending: true, SortDescending: true,
@ -84,7 +84,7 @@ func Index(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Get news") c.Perf.StartBlock("SQL", "Get news")
newsThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{ newsThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{models.HMNProjectID}, ProjectIDs: []int{models.HMNProjectID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
Limit: 1, Limit: 1,
@ -106,7 +106,7 @@ func Index(c *RequestContext) ResponseData {
} }
c.Perf.EndBlock() c.Perf.EndBlock()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Limit: 40, Limit: 40,
}) })
if err != nil { if err != nil {

123
src/website/middlewares.go Normal file
View File

@ -0,0 +1,123 @@
package website
import (
"fmt"
"math/rand"
"net/http"
"time"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/perf"
"git.handmade.network/hmn/hmn/src/utils"
)
func panicCatcherMiddleware(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
defer func() {
if recovered := recover(); recovered != nil {
maybeError, ok := recovered.(*error)
var err error
if ok {
err = *maybeError
} else {
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
}
res = c.ErrorResponse(http.StatusInternalServerError, err)
}
}()
return h(c)
}
}
func trackRequestPerf(h Handler) Handler {
return func(c *RequestContext) ResponseData {
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
defer func() {
c.Perf.EndRequest()
log := logging.Info()
blockStack := make([]time.Time, 0)
for i, block := range c.Perf.Blocks {
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
blockStack = blockStack[:len(blockStack)-1]
}
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
blockStack = append(blockStack, block.End)
}
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
}()
return h(c)
}
}
func needsAuth(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil {
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
}
return h(c)
}
}
func adminsOnly(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
return FourOhFour(c)
}
return h(c)
}
}
func csrfMiddleware(h Handler) Handler {
// CSRF mitigation actions per the OWASP cheat sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
return func(c *RequestContext) ResponseData {
c.Req.ParseMultipartForm(100 * 1024 * 1024)
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
if csrfToken != c.CurrentSession.CSRFToken {
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
res := c.Redirect("/", http.StatusSeeOther)
logoutUser(c, &res)
return res
}
return h(c)
}
}
func securityTimerMiddleware(duration time.Duration, h Handler) Handler {
// NOTE(asaf): Will make sure that the request takes at least `duration` to finish. Adds a 10% random duration.
return func(c *RequestContext) ResponseData {
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
timer := time.NewTimer(duration + additionalDuration)
res := h(c)
select {
case <-c.Done():
case <-timer.C:
}
return res
}
}
func logContextErrors(c *RequestContext, errs ...error) {
for _, err := range errs {
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
}
}
func logContextErrorsMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
res := h(c)
logContextErrors(c, res.Errors...)
return res
}
}

100
src/website/notices.go Normal file
View File

@ -0,0 +1,100 @@
package website
import (
"errors"
"html/template"
"net/http"
"strings"
"time"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/templates"
)
const NoticesCookieName = "hmn_notices"
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
cookie, err := c.Req.Cookie(NoticesCookieName)
if err != nil {
if !errors.Is(err, http.ErrNoCookie) {
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
}
return nil
}
return deserializeNoticesFromCookie(cookie.Value)
}
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
serialized := serializeNoticesForCookie(c, res.FutureNotices)
if serialized != "" {
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Value: serialized,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
Expires: time.Now().Add(time.Minute * 5),
Secure: config.Config.Auth.CookieSecure,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
res.SetCookie(&noticesCookie)
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
// NOTE(asaf): Don't clear on redirect
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
MaxAge: -1,
}
res.SetCookie(&noticesCookie)
}
}
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
var builder strings.Builder
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
size := 0
for i, notice := range notices {
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
if i != 0 {
sizeIncrease += 1
}
if size+sizeIncrease > maxSize {
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
break
}
if i != 0 {
builder.WriteString("\t")
}
builder.WriteString(notice.Class)
builder.WriteString("|")
builder.WriteString(string(notice.Content))
size += sizeIncrease
}
return builder.String()
}
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
var result []templates.Notice
notices := strings.Split(cookieVal, "\t")
for _, notice := range notices {
parts := strings.SplitN(notice, "|", 2)
if len(parts) == 2 {
result = append(result, templates.Notice{
Class: parts[0],
Content: template.HTML(parts[1]),
})
}
}
return result
}
func storeNoticesInCookieMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
res := h(c)
storeNoticesInCookie(c, &res)
return res
}
}

View File

@ -126,29 +126,29 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
title := c.Req.Form.Get("title") title := c.Req.Form.Get("title")
if len(strings.TrimSpace(title)) == 0 { if len(strings.TrimSpace(title)) == 0 {
return RejectRequest(c, "Podcast title is empty") return c.RejectRequest("Podcast title is empty")
} }
description := c.Req.Form.Get("description") description := c.Req.Form.Get("description")
if len(strings.TrimSpace(description)) == 0 { if len(strings.TrimSpace(description)) == 0 {
return RejectRequest(c, "Podcast description is empty") return c.RejectRequest("Podcast description is empty")
} }
c.Perf.StartBlock("SQL", "Updating podcast") c.Perf.StartBlock("SQL", "Updating podcast")
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
imageSaveResult := SaveImageFile(c, tx, "podcast_image", maxFileSize, fmt.Sprintf("podcast/%s/logo%d", c.CurrentProject.Slug, time.Now().UTC().Unix())) imageSaveResult := SaveImageFile(c, tx, "podcast_image", maxFileSize, fmt.Sprintf("podcast/%s/logo%d", c.CurrentProject.Slug, time.Now().UTC().Unix()))
if imageSaveResult.ValidationError != "" { if imageSaveResult.ValidationError != "" {
return RejectRequest(c, imageSaveResult.ValidationError) return c.RejectRequest(imageSaveResult.ValidationError)
} else if imageSaveResult.FatalError != nil { } else if imageSaveResult.FatalError != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(imageSaveResult.FatalError, "Failed to save podcast image")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(imageSaveResult.FatalError, "Failed to save podcast image"))
} }
if imageSaveResult.ImageFile != nil { if imageSaveResult.ImageFile != nil {
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE podcast UPDATE podcast
SET SET
@ -166,7 +166,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to update podcast")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to update podcast"))
} }
} else { } else {
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE podcast UPDATE podcast
SET SET
@ -179,7 +179,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
podcastResult.Podcast.ID, podcastResult.Podcast.ID,
) )
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
c.Perf.EndBlock() c.Perf.EndBlock()
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to commit db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to commit db transaction"))
@ -357,16 +357,16 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
c.Req.ParseForm() c.Req.ParseForm()
title := c.Req.Form.Get("title") title := c.Req.Form.Get("title")
if len(strings.TrimSpace(title)) == 0 { if len(strings.TrimSpace(title)) == 0 {
return RejectRequest(c, "Episode title is empty") return c.RejectRequest("Episode title is empty")
} }
description := c.Req.Form.Get("description") description := c.Req.Form.Get("description")
if len(strings.TrimSpace(description)) == 0 { if len(strings.TrimSpace(description)) == 0 {
return RejectRequest(c, "Episode description is empty") return c.RejectRequest("Episode description is empty")
} }
episodeNumberStr := c.Req.Form.Get("episode_number") episodeNumberStr := c.Req.Form.Get("episode_number")
episodeNumber, err := strconv.Atoi(episodeNumberStr) episodeNumber, err := strconv.Atoi(episodeNumberStr)
if err != nil { if err != nil {
return RejectRequest(c, "Episode number can't be parsed") return c.RejectRequest("Episode number can't be parsed")
} }
episodeFile := c.Req.Form.Get("episode_file") episodeFile := c.Req.Form.Get("episode_file")
found = false found = false
@ -378,7 +378,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
} }
if !found { if !found {
return RejectRequest(c, "Requested episode file not found") return c.RejectRequest("Requested episode file not found")
} }
c.Perf.StartBlock("MP3", "Parsing mp3 file for duration") c.Perf.StartBlock("MP3", "Parsing mp3 file for duration")
@ -417,7 +417,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
if isEdit { if isEdit {
guidStr = podcastResult.Episodes[0].GUID.String() guidStr = podcastResult.Episodes[0].GUID.String()
c.Perf.StartBlock("SQL", "Updating podcast episode") c.Perf.StartBlock("SQL", "Updating podcast episode")
_, err := c.Conn.Exec(c.Context(), _, err := c.Conn.Exec(c,
` `
UPDATE podcast_episode UPDATE podcast_episode
SET SET
@ -446,7 +446,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
guid := uuid.New() guid := uuid.New()
guidStr = guid.String() guidStr = guid.String()
c.Perf.StartBlock("SQL", "Creating new podcast episode") c.Perf.StartBlock("SQL", "Creating new podcast episode")
_, err := c.Conn.Exec(c.Context(), _, err := c.Conn.Exec(c,
` `
INSERT INTO podcast_episode INSERT INTO podcast_episode
(guid, title, description, description_rendered, audio_filename, duration, pub_date, episode_number, podcast_id) (guid, title, description, description_rendered, audio_filename, duration, pub_date, episode_number, podcast_id)
@ -532,7 +532,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
Podcast models.Podcast `db:"podcast"` Podcast models.Podcast `db:"podcast"`
ImageFilename string `db:"imagefile.file"` ImageFilename string `db:"imagefile.file"`
} }
podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn, podcastQueryResult, err := db.QueryOne[podcastQuery](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -558,7 +558,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
if fetchEpisodes { if fetchEpisodes {
if episodeGUID == "" { if episodeGUID == "" {
c.Perf.StartBlock("SQL", "Fetch podcast episodes") c.Perf.StartBlock("SQL", "Fetch podcast episodes")
episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn, episodes, err := db.Query[models.PodcastEpisode](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM podcast_episode AS episode FROM podcast_episode AS episode
@ -578,7 +578,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
return result, err return result, err
} }
c.Perf.StartBlock("SQL", "Fetch podcast episode") c.Perf.StartBlock("SQL", "Fetch podcast episode")
episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn, episode, err := db.QueryOne[models.PodcastEpisode](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM podcast_episode AS episode FROM podcast_episode AS episode

View File

@ -26,11 +26,52 @@ import (
"git.handmade.network/hmn/hmn/src/utils" "git.handmade.network/hmn/hmn/src/utils"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/teacat/noire"
) )
const maxPersonalProjects = 5 const maxPersonalProjects = 5
const maxProjectOwners = 5 const maxProjectOwners = 5
func ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color")
if color == "" {
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
}
baseData := getBaseData(c, "", nil)
bgColor := noire.NewHex(color)
h, s, l := bgColor.HSL()
if baseData.Theme == "dark" {
l = 15
} else {
l = 95
}
if s > 20 {
s = 20
}
bgColor = noire.NewHSL(h, s, l)
templateData := struct {
templates.BaseData
Color string
PostBgColor string
}{
BaseData: baseData,
Color: color,
PostBgColor: bgColor.HTML(),
}
var res ResponseData
res.Header().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
return res
}
type ProjectTemplateData struct { type ProjectTemplateData struct {
templates.BaseData templates.BaseData
@ -48,7 +89,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
const maxCarouselProjects = 10 const maxCarouselProjects = 10
const maxPersonalProjects = 10 const maxPersonalProjects = 10
officialProjects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ officialProjects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
Types: hmndata.OfficialProjects, Types: hmndata.OfficialProjects,
}) })
if err != nil { if err != nil {
@ -123,7 +164,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
// Fetch and highlight a random selection of personal projects // Fetch and highlight a random selection of personal projects
var personalProjects []templates.Project var personalProjects []templates.Project
{ {
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
Types: hmndata.PersonalProjects, Types: hmndata.PersonalProjects,
}) })
if err != nil { if err != nil {
@ -181,13 +222,13 @@ func ProjectHomepage(c *RequestContext) ResponseData {
// There are no further permission checks to do, because permissions are // There are no further permission checks to do, because permissions are
// checked whatever way we fetch the project. // checked whatever way we fetch the project.
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID) owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err) return c.ErrorResponse(http.StatusInternalServerError, err)
} }
c.Perf.StartBlock("SQL", "Fetching screenshots") c.Perf.StartBlock("SQL", "Fetching screenshots")
screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn, screenshotFilenames, err := db.QueryScalar[string](c, c.Conn,
` `
SELECT screenshot.file SELECT screenshot.file
FROM FROM
@ -204,7 +245,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
c.Perf.EndBlock() c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetching project links") c.Perf.StartBlock("SQL", "Fetching project links")
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, projectLinks, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -221,12 +262,12 @@ func ProjectHomepage(c *RequestContext) ResponseData {
c.Perf.EndBlock() c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetching project timeline") c.Perf.StartBlock("SQL", "Fetching project timeline")
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID}, ProjectIDs: []int{c.CurrentProject.ID},
Limit: maxRecentActivity, Limit: maxRecentActivity,
SortDescending: true, SortDescending: true,
@ -241,7 +282,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
Value: c.CurrentProject.Blurb, Value: c.CurrentProject.Blurb,
}) })
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{ p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles, Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true, IncludeHidden: true,
}) })
@ -317,7 +358,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
tagId = *c.CurrentProject.TagID tagId = *c.CurrentProject.TagID
} }
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Tags: []int{tagId}, Tags: []int{tagId},
}) })
if err != nil { if err != nil {
@ -364,7 +405,7 @@ type ProjectEditData struct {
} }
func ProjectNew(c *RequestContext) ResponseData { func ProjectNew(c *RequestContext) ResponseData {
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{c.CurrentUser.ID}, OwnerIDs: []int{c.CurrentUser.ID},
Types: hmndata.PersonalProjects, Types: hmndata.PersonalProjects,
}) })
@ -372,7 +413,7 @@ func ProjectNew(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
} }
if numProjects >= maxPersonalProjects { if numProjects >= maxPersonalProjects {
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
} }
var project templates.ProjectSettings var project templates.ProjectSettings
@ -397,16 +438,16 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error) return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
} }
if len(formResult.RejectionReason) != 0 { if len(formResult.RejectionReason) != 0 {
return RejectRequest(c, formResult.RejectionReason) return c.RejectRequest(formResult.RejectionReason)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{c.CurrentUser.ID}, OwnerIDs: []int{c.CurrentUser.ID},
Types: hmndata.PersonalProjects, Types: hmndata.PersonalProjects,
}) })
@ -414,11 +455,11 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
} }
if numProjects >= maxPersonalProjects { if numProjects >= maxPersonalProjects {
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects)) return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
} }
var projectId int var projectId int
err = tx.QueryRow(c.Context(), err = tx.QueryRow(c,
` `
INSERT INTO project INSERT INTO project
(name, blurb, description, descparsed, lifecycle, date_created, all_last_updated) (name, blurb, description, descparsed, lifecycle, date_created, all_last_updated)
@ -439,12 +480,12 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
formResult.Payload.ProjectID = projectId formResult.Payload.ProjectID = projectId
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload) err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err) return c.ErrorResponse(http.StatusInternalServerError, err)
} }
tx.Commit(c.Context()) tx.Commit(c)
urlContext := &hmnurl.UrlContext{ urlContext := &hmnurl.UrlContext{
PersonalProject: true, PersonalProject: true,
@ -461,7 +502,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
} }
p, err := hmndata.FetchProject( p, err := hmndata.FetchProject(
c.Context(), c.Conn, c, c.Conn,
c.CurrentUser, c.CurrentProject.ID, c.CurrentUser, c.CurrentProject.ID,
hmndata.ProjectsQuery{ hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles, Lifecycles: models.AllProjectLifecycles,
@ -473,7 +514,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Fetching project links") c.Perf.StartBlock("SQL", "Fetching project links")
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, projectLinks, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -524,23 +565,23 @@ func ProjectEditSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error) return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
} }
if len(formResult.RejectionReason) != 0 { if len(formResult.RejectionReason) != 0 {
return RejectRequest(c, formResult.RejectionReason) return c.RejectRequest(formResult.RejectionReason)
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
formResult.Payload.ProjectID = c.CurrentProject.ID formResult.Payload.ProjectID = c.CurrentProject.ID
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload) err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err) return c.ErrorResponse(http.StatusInternalServerError, err)
} }
tx.Commit(c.Context()) tx.Commit(c)
urlContext := &hmnurl.UrlContext{ urlContext := &hmnurl.UrlContext{
PersonalProject: formResult.Payload.Personal, PersonalProject: formResult.Payload.Personal,

View File

@ -13,6 +13,7 @@ import (
"path" "path"
"regexp" "regexp"
"strings" "strings"
"time"
"git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/logging"
@ -43,15 +44,22 @@ func (r *Route) String() string {
} }
type RouteBuilder struct { type RouteBuilder struct {
Router *Router Router *Router
Prefixes []*regexp.Regexp Prefixes []*regexp.Regexp
Middleware Middleware Middlewares []Middleware
} }
type Handler func(c *RequestContext) ResponseData type Handler func(c *RequestContext) ResponseData
type Middleware func(h Handler) Handler type Middleware func(h Handler) Handler
func applyMiddlewares(h Handler, ms []Middleware) Handler {
result := h
for i := len(ms) - 1; i >= 0; i-- {
result = ms[i](result)
}
return result
}
func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler) { func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler) {
// Ensure that this regex matches the start of the string // Ensure that this regex matches the start of the string
regexStr := regex.String() regexStr := regex.String()
@ -59,7 +67,7 @@ func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler
panic("All routing regexes must begin with '^'") panic("All routing regexes must begin with '^'")
} }
h = rb.Middleware(h) h = applyMiddlewares(h, rb.Middlewares)
for _, method := range methods { for _, method := range methods {
rb.Router.Routes = append(rb.Router.Routes, Route{ rb.Router.Routes = append(rb.Router.Routes, Route{
Method: method, Method: method,
@ -81,10 +89,19 @@ func (rb *RouteBuilder) POST(regex *regexp.Regexp, h Handler) {
rb.Handle([]string{http.MethodPost}, regex, h) rb.Handle([]string{http.MethodPost}, regex, h)
} }
func (rb *RouteBuilder) Group(regex *regexp.Regexp, addRoutes func(rb *RouteBuilder)) { func (rb *RouteBuilder) WithMiddleware(ms ...Middleware) RouteBuilder {
newRb := *rb
newRb.Middlewares = append(rb.Middlewares, ms...)
return newRb
}
func (rb *RouteBuilder) Group(regex *regexp.Regexp, ms ...Middleware) RouteBuilder {
newRb := *rb newRb := *rb
newRb.Prefixes = append(newRb.Prefixes, regex) newRb.Prefixes = append(newRb.Prefixes, regex)
addRoutes(&newRb) newRb.Middlewares = append(rb.Middlewares, ms...)
return newRb
} }
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
@ -138,6 +155,8 @@ nextroute:
Req: req, Req: req,
Res: rw, Res: rw,
PathParams: params, PathParams: params,
ctx: req.Context(),
} }
c.PathParams = params c.PathParams = params
@ -174,13 +193,33 @@ type RequestContext struct {
ctx context.Context ctx context.Context
} }
func (c *RequestContext) Context() context.Context { // Our RequestContext is a context.Context
if c.ctx == nil {
c.ctx = c.Req.Context() var _ context.Context = &RequestContext{}
}
return c.ctx func (c *RequestContext) Deadline() (time.Time, bool) {
return c.ctx.Deadline()
} }
func (c *RequestContext) Done() <-chan struct{} {
return c.ctx.Done()
}
func (c *RequestContext) Err() error {
return c.ctx.Err()
}
func (c *RequestContext) Value(key any) any {
switch key {
case perf.PerfContextKey:
return c.Perf
default:
return c.ctx.Value(key)
}
}
// Plus it does many other things specific to us
func (c *RequestContext) URL() *url.URL { func (c *RequestContext) URL() *url.URL {
return c.Req.URL return c.Req.URL
} }
@ -325,7 +364,7 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData {
func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData { func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
LogContextErrors(c, errs...) logContextErrors(c, errs...)
panic(r) panic(r)
} }
}() }()
@ -338,6 +377,23 @@ func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
return res return res
} }
func (c *RequestContext) RejectRequest(reason string) ResponseData {
type RejectData struct {
templates.BaseData
RejectReason string
}
var res ResponseData
err := res.WriteTemplate("reject.html", RejectData{
BaseData: getBaseData(c, "Rejected", nil),
RejectReason: reason,
}, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
}
return res
}
type ResponseData struct { type ResponseData struct {
StatusCode int StatusCode int
Body *bytes.Buffer Body *bytes.Buffer

View File

@ -1,159 +1,46 @@
package website package website
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"html/template"
"math/rand"
"net/http" "net/http"
"net/url"
"regexp"
"strconv" "strconv"
"strings"
"time" "time"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/email" "git.handmade.network/hmn/hmn/src/email"
"git.handmade.network/hmn/hmn/src/hmndata" "git.handmade.network/hmn/hmn/src/hmndata"
"git.handmade.network/hmn/hmn/src/hmnurl" "git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops" "git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/perf"
"git.handmade.network/hmn/hmn/src/templates"
"git.handmade.network/hmn/hmn/src/utils" "git.handmade.network/hmn/hmn/src/utils"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/teacat/noire"
) )
func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) http.Handler { func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
router := &Router{} router := &Router{}
routes := RouteBuilder{ routes := RouteBuilder{
Router: router, Router: router,
Middleware: func(h Handler) Handler { Middlewares: []Middleware{
return func(c *RequestContext) (res ResponseData) { setDBConn(conn),
c.Conn = conn trackRequestPerf,
logContextErrorsMiddleware,
logPerf := TrackRequestPerf(c) panicCatcherMiddleware,
defer logPerf()
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
return h(c)
}
}, },
} }
anyProject := routes anyProject := routes.WithMiddleware(
anyProject.Middleware = func(h Handler) Handler { storeNoticesInCookieMiddleware,
return func(c *RequestContext) (res ResponseData) { loadCommonData,
c.Conn = conn )
hmnOnly := anyProject.WithMiddleware(
logPerf := TrackRequestPerf(c) redirectToHMN,
defer logPerf() )
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
defer storeNoticesInCookie(c, &res)
ok, errRes := LoadCommonWebsiteData(c)
if !ok {
return errRes
}
return h(c)
}
}
hmnOnly := routes
hmnOnly.Middleware = func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
c.Conn = conn
logPerf := TrackRequestPerf(c)
defer logPerf()
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
defer storeNoticesInCookie(c, &res)
ok, errRes := LoadCommonWebsiteData(c)
if !ok {
return errRes
}
if !c.CurrentProject.IsHMN() {
return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
}
return h(c)
}
}
authMiddleware := func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil {
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
}
return h(c)
}
}
adminMiddleware := func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
return FourOhFour(c)
}
return h(c)
}
}
csrfMiddleware := func(h Handler) Handler {
// CSRF mitigation actions per the OWASP cheat sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
return func(c *RequestContext) ResponseData {
c.Req.ParseMultipartForm(100 * 1024 * 1024)
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
if csrfToken != c.CurrentSession.CSRFToken {
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
res := c.Redirect("/", http.StatusSeeOther)
logoutUser(c, &res)
return res
}
return h(c)
}
}
securityTimerMiddleware := func(duration time.Duration, h Handler) Handler {
// NOTE(asaf): Will make sure that the request takes at least `delayMs` to finish. Adds a 10% random duration.
return func(c *RequestContext) ResponseData {
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
timer := time.NewTimer(duration + additionalDuration)
res := h(c)
select {
case <-longRequestContext.Done():
case <-c.Context().Done():
case <-timer.C:
}
return res
}
}
routes.GET(hmnurl.RegexPublic, func(c *RequestContext) ResponseData { routes.GET(hmnurl.RegexPublic, func(c *RequestContext) ResponseData {
var res ResponseData var res ResponseData
http.StripPrefix("/public/", http.FileServer(http.Dir("public"))).ServeHTTP(&res, c.Req) http.StripPrefix("/public/", http.FileServer(http.Dir("public"))).ServeHTTP(&res, c.Req)
AddCORSHeaders(c, &res) addCORSHeaders(c, &res)
return res return res
}) })
routes.GET(hmnurl.RegexFishbowlFiles, FishbowlFiles) routes.GET(hmnurl.RegexFishbowlFiles, FishbowlFiles)
@ -189,10 +76,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.POST(hmnurl.RegexDoPasswordReset, DoPasswordResetSubmit) hmnOnly.POST(hmnurl.RegexDoPasswordReset, DoPasswordResetSubmit)
hmnOnly.GET(hmnurl.RegexAdminAtomFeed, AdminAtomFeed) hmnOnly.GET(hmnurl.RegexAdminAtomFeed, AdminAtomFeed)
hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminMiddleware(AdminApprovalQueue)) hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminsOnly(AdminApprovalQueue))
hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminMiddleware(csrfMiddleware(AdminApprovalQueueSubmit))) hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminsOnly(csrfMiddleware(AdminApprovalQueueSubmit)))
hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminMiddleware(csrfMiddleware(UserProfileAdminSetStatus))) hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminsOnly(csrfMiddleware(UserProfileAdminSetStatus)))
hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminMiddleware(csrfMiddleware(UserProfileAdminNuke))) hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminsOnly(csrfMiddleware(UserProfileAdminNuke)))
hmnOnly.GET(hmnurl.RegexFeed, Feed) hmnOnly.GET(hmnurl.RegexFeed, Feed)
hmnOnly.GET(hmnurl.RegexAtomFeed, AtomFeed) hmnOnly.GET(hmnurl.RegexAtomFeed, AtomFeed)
@ -200,19 +87,19 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.GET(hmnurl.RegexSnippet, Snippet) hmnOnly.GET(hmnurl.RegexSnippet, Snippet)
hmnOnly.GET(hmnurl.RegexProjectIndex, ProjectIndex) hmnOnly.GET(hmnurl.RegexProjectIndex, ProjectIndex)
hmnOnly.GET(hmnurl.RegexProjectNew, authMiddleware(ProjectNew)) hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew))
hmnOnly.POST(hmnurl.RegexProjectNew, authMiddleware(csrfMiddleware(ProjectNewSubmit))) hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit)))
hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback)) hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback))
hmnOnly.POST(hmnurl.RegexDiscordUnlink, authMiddleware(csrfMiddleware(DiscordUnlink))) hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink)))
hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, authMiddleware(csrfMiddleware(DiscordShowcaseBacklog))) hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog)))
hmnOnly.POST(hmnurl.RegexTwitchEventSubCallback, TwitchEventSubCallback) hmnOnly.POST(hmnurl.RegexTwitchEventSubCallback, TwitchEventSubCallback)
hmnOnly.GET(hmnurl.RegexTwitchDebugPage, TwitchDebugPage) hmnOnly.GET(hmnurl.RegexTwitchDebugPage, TwitchDebugPage)
hmnOnly.GET(hmnurl.RegexUserProfile, UserProfile) hmnOnly.GET(hmnurl.RegexUserProfile, UserProfile)
hmnOnly.GET(hmnurl.RegexUserSettings, authMiddleware(UserSettings)) hmnOnly.GET(hmnurl.RegexUserSettings, needsAuth(UserSettings))
hmnOnly.POST(hmnurl.RegexUserSettings, authMiddleware(csrfMiddleware(UserSettingsSave))) hmnOnly.POST(hmnurl.RegexUserSettings, needsAuth(csrfMiddleware(UserSettingsSave)))
hmnOnly.GET(hmnurl.RegexPodcast, PodcastIndex) hmnOnly.GET(hmnurl.RegexPodcast, PodcastIndex)
hmnOnly.GET(hmnurl.RegexPodcastEdit, PodcastEdit) hmnOnly.GET(hmnurl.RegexPodcastEdit, PodcastEdit)
@ -231,6 +118,9 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.GET(hmnurl.RegexLibraryAny, LibraryNotPortedYet) hmnOnly.GET(hmnurl.RegexLibraryAny, LibraryNotPortedYet)
// Project routes can appear either at the root (e.g. hero.handmade.network/edit)
// or on a personal project path (e.g. handmade.network/p/123/hero/edit). So, we
// have pulled all those routes into this function.
attachProjectRoutes := func(rb *RouteBuilder) { attachProjectRoutes := func(rb *RouteBuilder) {
rb.GET(hmnurl.RegexHomepage, func(c *RequestContext) ResponseData { rb.GET(hmnurl.RegexHomepage, func(c *RequestContext) ResponseData {
if c.CurrentProject.IsHMN() { if c.CurrentProject.IsHMN() {
@ -240,8 +130,8 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
} }
}) })
rb.GET(hmnurl.RegexProjectEdit, authMiddleware(ProjectEdit)) rb.GET(hmnurl.RegexProjectEdit, needsAuth(ProjectEdit))
rb.POST(hmnurl.RegexProjectEdit, authMiddleware(csrfMiddleware(ProjectEditSubmit))) rb.POST(hmnurl.RegexProjectEdit, needsAuth(csrfMiddleware(ProjectEditSubmit)))
// Middleware used for forum action routes - anything related to actually creating or editing forum content // Middleware used for forum action routes - anything related to actually creating or editing forum content
needsForums := func(h Handler) Handler { needsForums := func(h Handler) Handler {
@ -251,14 +141,14 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return FourOhFour(c) return FourOhFour(c)
} }
// Require auth if forums are enabled // Require auth if forums are enabled
return authMiddleware(h)(c) return needsAuth(h)(c)
} }
} }
rb.POST(hmnurl.RegexForumNewThreadSubmit, needsForums(csrfMiddleware(ForumNewThreadSubmit))) rb.POST(hmnurl.RegexForumNewThreadSubmit, needsForums(csrfMiddleware(ForumNewThreadSubmit)))
rb.GET(hmnurl.RegexForumNewThread, needsForums(ForumNewThread)) rb.GET(hmnurl.RegexForumNewThread, needsForums(ForumNewThread))
rb.GET(hmnurl.RegexForumThread, ForumThread) rb.GET(hmnurl.RegexForumThread, ForumThread)
rb.GET(hmnurl.RegexForum, Forum) rb.GET(hmnurl.RegexForum, Forum)
rb.POST(hmnurl.RegexForumMarkRead, authMiddleware(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled rb.POST(hmnurl.RegexForumMarkRead, needsAuth(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled
rb.GET(hmnurl.RegexForumPost, ForumPostRedirect) rb.GET(hmnurl.RegexForumPost, ForumPostRedirect)
rb.GET(hmnurl.RegexForumPostReply, needsForums(ForumPostReply)) rb.GET(hmnurl.RegexForumPostReply, needsForums(ForumPostReply))
rb.POST(hmnurl.RegexForumPostReply, needsForums(csrfMiddleware(ForumPostReplySubmit))) rb.POST(hmnurl.RegexForumPostReply, needsForums(csrfMiddleware(ForumPostReplySubmit)))
@ -276,7 +166,7 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return FourOhFour(c) return FourOhFour(c)
} }
// Require auth if blogs are enabled // Require auth if blogs are enabled
return authMiddleware(h)(c) return needsAuth(h)(c)
} }
} }
rb.GET(hmnurl.RegexBlog, BlogIndex) rb.GET(hmnurl.RegexBlog, BlogIndex)
@ -296,64 +186,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
), http.StatusMovedPermanently) ), http.StatusMovedPermanently)
}) })
} }
hmnOnly.Group(hmnurl.RegexPersonalProject, func(rb *RouteBuilder) { officialProjectRoutes := anyProject.WithMiddleware(officialProjectMiddleware)
// TODO(ben): Perhaps someday we can make this middleware modification feel better? It seems personalProjectRoutes := hmnOnly.Group(hmnurl.RegexPersonalProject, personalProjectMiddleware)
// pretty common to run the outermost middleware first before doing other stuff, but having attachProjectRoutes(&officialProjectRoutes)
// to nest functions this way feels real bad. attachProjectRoutes(&personalProjectRoutes)
rb.Middleware = func(h Handler) Handler {
return hmnOnly.Middleware(func(c *RequestContext) ResponseData {
// At this point we are definitely on the plain old HMN subdomain.
// Fetch personal project and do whatever
id, err := strconv.Atoi(c.PathParams["projectid"])
if err != nil {
panic(oops.New(err, "project id was not numeric (bad regex in routing)"))
}
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)
} else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
}
}
c.CurrentProject = &p.Project
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
if !p.Project.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(p.Project.Name) {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
return h(c)
})
}
attachProjectRoutes(rb)
})
anyProject.Group(regexp.MustCompile("^"), func(rb *RouteBuilder) {
rb.Middleware = func(h Handler) Handler {
return anyProject.Middleware(func(c *RequestContext) ResponseData {
// We could be on any project's subdomain.
// Check if the current project (matched by subdomain) is actually no longer official
// and therefore needs to be redirected to the personal project version of the route.
if c.CurrentProject.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
return h(c)
})
}
attachProjectRoutes(rb)
})
anyProject.POST(hmnurl.RegexAssetUpload, AssetUpload) anyProject.POST(hmnurl.RegexAssetUpload, AssetUpload)
@ -375,318 +211,69 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return router return router
} }
func ProjectCSS(c *RequestContext) ResponseData { func setDBConn(conn *pgxpool.Pool) Middleware {
color := c.URL().Query().Get("color") return func(h Handler) Handler {
if color == "" { return func(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) c.Conn = conn
} return h(c)
baseData := getBaseData(c, "", nil)
bgColor := noire.NewHex(color)
h, s, l := bgColor.HSL()
if baseData.Theme == "dark" {
l = 15
} else {
l = 95
}
if s > 20 {
s = 20
}
bgColor = noire.NewHSL(h, s, l)
templateData := struct {
templates.BaseData
Color string
PostBgColor string
}{
BaseData: baseData,
Color: color,
PostBgColor: bgColor.HTML(),
}
var res ResponseData
res.Header().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
return res
}
func FourOhFour(c *RequestContext) ResponseData {
var res ResponseData
res.StatusCode = http.StatusNotFound
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
templateData := struct {
templates.BaseData
Wanted string
}{
BaseData: getBaseData(c, "Page not found", nil),
Wanted: c.FullUrl(),
} }
res.MustWriteTemplate("404.html", templateData, c.Perf)
} else {
res.Write([]byte("Not Found"))
} }
return res
} }
type RejectData struct { func redirectToHMN(h Handler) Handler {
templates.BaseData return func(c *RequestContext) ResponseData {
RejectReason string if !c.CurrentProject.IsHMN() {
} return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
func RejectRequest(c *RequestContext, reason string) ResponseData {
var res ResponseData
err := res.WriteTemplate("reject.html", RejectData{
BaseData: getBaseData(c, "Rejected", nil),
RejectReason: reason,
}, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
}
return res
}
func LoadCommonWebsiteData(c *RequestContext) (bool, ResponseData) {
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
defer c.Perf.EndBlock()
// get user
{
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
if err != nil {
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
}
c.CurrentUser = user
c.CurrentSession = session
} }
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
return h(c)
} }
}
// get official project func officialProjectMiddleware(h Handler) Handler {
{ return func(c *RequestContext) ResponseData {
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost()) // Check if the current project (matched by subdomain) is actually no longer official
slug := strings.TrimRight(hostPrefix, ".") // and therefore needs to be redirected to the personal project version of the route.
var owners []*models.User if c.CurrentProject.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
if len(slug) > 0 { return h(c)
dbProject, err := hmndata.FetchProjectBySlug(c.Context(), c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{ }
Lifecycles: models.AllProjectLifecycles, }
IncludeHidden: true,
}) func personalProjectMiddleware(h Handler) Handler {
if err == nil { return func(c *RequestContext) ResponseData {
c.CurrentProject = &dbProject.Project hmnProject := c.CurrentProject
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
owners = dbProject.Owners id := utils.Must1(strconv.Atoi(c.PathParams["projectid"]))
p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)
} else { } else {
if errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
// do nothing, this is fine
} else {
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
} }
} }
if c.CurrentProject == nil { c.CurrentProject = &p.Project
dbProject, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{ c.CurrentProject.Color1 = hmnProject.Color1
Lifecycles: models.AllProjectLifecycles, c.CurrentProject.Color2 = hmnProject.Color2
IncludeHidden: true,
})
if err != nil {
panic(oops.New(err, "failed to fetch HMN project"))
}
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
}
if c.CurrentProject == nil {
panic("failed to load project data")
}
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject) c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
} c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
c.Theme = "light" if !c.CurrentProject.Personal {
if c.CurrentUser != nil && c.CurrentUser.DarkTheme { return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
c.Theme = "dark"
}
return true, ResponseData{}
}
func AddCORSHeaders(c *RequestContext, res *ResponseData) {
parsed, err := url.Parse(config.Config.BaseUrl)
if err != nil {
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
return
}
origin := ""
origins, found := c.Req.Header["Origin"]
if found {
origin = origins[0]
}
if strings.HasSuffix(origin, parsed.Host) {
res.Header().Add("Access-Control-Allow-Origin", origin)
res.Header().Add("Access-Control-Allow-Credentials", "true")
res.Header().Add("Vary", "Origin")
}
}
// Given a session id, fetches user data from the database. Will return nil if
// the user cannot be found, and will only return an error if it's serious.
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
session, err := auth.GetSession(c.Context(), c.Conn, sessionId)
if err != nil {
if errors.Is(err, auth.ErrNoSession) {
return nil, nil, nil
} else {
return nil, nil, oops.New(err, "failed to get current session")
}
}
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, nil, session.Username, hmndata.UsersQuery{
AnyStatus: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
return nil, nil, nil // user was deleted or something
} else {
return nil, nil, oops.New(err, "failed to get user for session")
}
}
return user, session, nil
}
func TrackRequestPerf(c *RequestContext) (after func()) {
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
c.ctx = context.WithValue(c.Context(), perf.PerfContextKey, c.Perf)
return func() {
c.Perf.EndRequest()
log := logging.Info()
blockStack := make([]time.Time, 0)
for i, block := range c.Perf.Blocks {
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
blockStack = blockStack[:len(blockStack)-1]
}
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
blockStack = append(blockStack, block.End)
}
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
}
}
func LogContextErrors(c *RequestContext, errs ...error) {
for _, err := range errs {
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
}
}
func LogContextErrorsFromResponse(c *RequestContext, res *ResponseData) {
LogContextErrors(c, res.Errors...)
}
func MiddlewarePanicCatcher(c *RequestContext, res *ResponseData) {
if recovered := recover(); recovered != nil {
maybeError, ok := recovered.(*error)
var err error
if ok {
err = *maybeError
} else {
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
}
*res = c.ErrorResponse(http.StatusInternalServerError, err)
}
}
const NoticesCookieName = "hmn_notices"
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
cookie, err := c.Req.Cookie(NoticesCookieName)
if err != nil {
if !errors.Is(err, http.ErrNoCookie) {
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
}
return nil
}
return deserializeNoticesFromCookie(cookie.Value)
}
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
serialized := serializeNoticesForCookie(c, res.FutureNotices)
if serialized != "" {
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Value: serialized,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
Expires: time.Now().Add(time.Minute * 5),
Secure: config.Config.Auth.CookieSecure,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
res.SetCookie(&noticesCookie)
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
// NOTE(asaf): Don't clear on redirect
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
MaxAge: -1,
}
res.SetCookie(&noticesCookie)
}
}
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
var builder strings.Builder
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
size := 0
for i, notice := range notices {
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
if i != 0 {
sizeIncrease += 1
}
if size+sizeIncrease > maxSize {
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
break
} }
if i != 0 { if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(c.CurrentProject.Name) {
builder.WriteString("\t") return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
} }
builder.WriteString(notice.Class)
builder.WriteString("|")
builder.WriteString(string(notice.Content))
size += sizeIncrease return h(c)
} }
return builder.String()
}
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
var result []templates.Notice
notices := strings.Split(cookieVal, "\t")
for _, notice := range notices {
parts := strings.SplitN(notice, "|", 2)
if len(parts) == 2 {
result = append(result, templates.Notice{
Class: parts[0],
Content: template.HTML(parts[1]),
})
}
}
return result
} }

View File

@ -31,7 +31,7 @@ func TestLogContextErrors(t *testing.T) {
Middleware: func(h Handler) Handler { Middleware: func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) { return func(c *RequestContext) (res ResponseData) {
c.Logger = &logger c.Logger = &logger
defer LogContextErrorsFromResponse(c, &res) defer logContextErrorsMiddleware(c, &res)
return h(c) return h(c)
} }
}, },

View File

@ -16,7 +16,7 @@ type ShowcaseData struct {
} }
func Showcase(c *RequestContext) ResponseData { func Showcase(c *RequestContext) ResponseData {
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{}) snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{})
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippets")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippets"))
} }

View File

@ -30,7 +30,7 @@ func Snippet(c *RequestContext) ResponseData {
return FourOhFour(c) return FourOhFour(c)
} }
s, err := hmndata.FetchSnippet(c.Context(), c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{}) s, err := hmndata.FetchSnippet(c, c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{})
if err != nil { if err != nil {
if errors.Is(err, db.NotFound) { if errors.Is(err, db.NotFound) {
return FourOhFour(c) return FourOhFour(c)

View File

@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData {
} }
func TwitchDebugPage(c *RequestContext) ResponseData { func TwitchDebugPage(c *RequestContext) ResponseData {
streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn, streams, err := db.Query[models.TwitchStream](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM

View File

@ -52,7 +52,7 @@ func UserProfile(c *RequestContext) ResponseData {
if c.CurrentUser != nil && strings.ToLower(c.CurrentUser.Username) == username { if c.CurrentUser != nil && strings.ToLower(c.CurrentUser.Username) == username {
profileUser = c.CurrentUser profileUser = c.CurrentUser
} else { } else {
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, c.CurrentUser, username, hmndata.UsersQuery{}) user, err := hmndata.FetchUserByUsername(c, c.Conn, c.CurrentUser, username, hmndata.UsersQuery{})
if err != nil { if err != nil {
if errors.Is(err, db.NotFound) { if errors.Is(err, db.NotFound) {
return FourOhFour(c) return FourOhFour(c)
@ -72,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Fetch user links") c.Perf.StartBlock("SQL", "Fetch user links")
userLinks, err := db.Query[models.Link](c.Context(), c.Conn, userLinks, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -92,7 +92,7 @@ func UserProfile(c *RequestContext) ResponseData {
} }
c.Perf.EndBlock() c.Perf.EndBlock()
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{profileUser.ID}, OwnerIDs: []int{profileUser.ID},
Lifecycles: models.AllProjectLifecycles, Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true, IncludeHidden: true,
@ -111,13 +111,13 @@ func UserProfile(c *RequestContext) ResponseData {
c.Perf.EndBlock() c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch posts") c.Perf.StartBlock("SQL", "Fetch posts")
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{ posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
UserIDs: []int{profileUser.ID}, UserIDs: []int{profileUser.ID},
SortDescending: true, SortDescending: true,
}) })
c.Perf.EndBlock() c.Perf.EndBlock()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{ snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
OwnerIDs: []int{profileUser.ID}, OwnerIDs: []int{profileUser.ID},
}) })
if err != nil { if err != nil {
@ -125,7 +125,7 @@ func UserProfile(c *RequestContext) ResponseData {
} }
c.Perf.StartBlock("SQL", "Fetch subforum tree") c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock() c.Perf.EndBlock()
@ -213,7 +213,7 @@ func UserSettings(c *RequestContext) ResponseData {
DiscordShowcaseBacklogUrl string DiscordShowcaseBacklogUrl string
} }
links, err := db.Query[models.Link](c.Context(), c.Conn, links, err := db.Query[models.Link](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM link FROM link
@ -230,7 +230,7 @@ func UserSettings(c *RequestContext) ResponseData {
var tduser *templates.DiscordUser var tduser *templates.DiscordUser
var numUnsavedMessages int var numUnsavedMessages int
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
` `
SELECT $columns SELECT $columns
FROM discord_user FROM discord_user
@ -246,7 +246,7 @@ func UserSettings(c *RequestContext) ResponseData {
tmp := templates.DiscordUserToTemplate(duser) tmp := templates.DiscordUserToTemplate(duser)
tduser = &tmp tduser = &tmp
numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn, numUnsavedMessages, err = db.QueryOneScalar[int](c, c.Conn,
` `
SELECT COUNT(*) SELECT COUNT(*)
FROM FROM
@ -299,11 +299,11 @@ func UserSettingsSave(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse form")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse form"))
} }
tx, err := c.Conn.Begin(c.Context()) tx, err := c.Conn.Begin(c)
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer tx.Rollback(c.Context()) defer tx.Rollback(c)
form, err := c.GetFormValues() form, err := c.GetFormValues()
if err != nil { if err != nil {
@ -315,7 +315,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
email := form.Get("email") email := form.Get("email")
if !hmnemail.IsEmail(email) { if !hmnemail.IsEmail(email) {
return RejectRequest(c, "Your email was not valid.") return c.RejectRequest("Your email was not valid.")
} }
showEmail := form.Get("showemail") != "" showEmail := form.Get("showemail") != ""
@ -328,7 +328,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
discordShowcaseAuto := form.Get("discord-showcase-auto") != "" discordShowcaseAuto := form.Get("discord-showcase-auto") != ""
discordDeleteSnippetOnMessageDelete := form.Get("discord-snippet-keep") == "" discordDeleteSnippetOnMessageDelete := form.Get("discord-snippet-keep") == ""
_, err = tx.Exec(c.Context(), _, err = tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET SET
@ -360,15 +360,15 @@ func UserSettingsSave(c *RequestContext) ResponseData {
} }
// Process links // Process links
twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil) twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
linksText := form.Get("links") linksText := form.Get("links")
links := ParseLinks(linksText) links := ParseLinks(linksText)
_, err = tx.Exec(c.Context(), `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID) _, err = tx.Exec(c, `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID)
if err != nil { if err != nil {
c.Logger.Warn().Err(err).Msg("failed to delete old links") c.Logger.Warn().Err(err).Msg("failed to delete old links")
} else { } else {
for i, link := range links { for i, link := range links {
_, err := tx.Exec(c.Context(), _, err := tx.Exec(c,
` `
INSERT INTO link (name, url, ordering, user_id) INSERT INTO link (name, url, ordering, user_id)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
@ -384,7 +384,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
} }
} }
} }
twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil) twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
if preErr == nil && postErr == nil { if preErr == nil && postErr == nil {
twitch.UserOrProjectLinksUpdated(twitchLoginsPreChange, twitchLoginsPostChange) twitch.UserOrProjectLinksUpdated(twitchLoginsPreChange, twitchLoginsPostChange)
} }
@ -407,7 +407,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
} }
var avatarUUID *uuid.UUID var avatarUUID *uuid.UUID
if newAvatar.Exists { if newAvatar.Exists {
avatarAsset, err := assets.Create(c.Context(), tx, assets.CreateInput{ avatarAsset, err := assets.Create(c, tx, assets.CreateInput{
Content: newAvatar.Content, Content: newAvatar.Content,
Filename: newAvatar.Filename, Filename: newAvatar.Filename,
ContentType: newAvatar.Mime, ContentType: newAvatar.Mime,
@ -421,7 +421,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
avatarUUID = &avatarAsset.ID avatarUUID = &avatarAsset.ID
} }
if newAvatar.Exists || newAvatar.Remove { if newAvatar.Exists || newAvatar.Remove {
_, err := tx.Exec(c.Context(), _, err := tx.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET SET
@ -437,7 +437,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
} }
} }
err = tx.Commit(c.Context()) err = tx.Commit(c)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save user settings")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save user settings"))
} }
@ -454,7 +454,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id") userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr) userId, err := strconv.Atoi(userIdStr)
if err != nil { if err != nil {
return RejectRequest(c, "No user id provided") return c.RejectRequest("No user id provided")
} }
status := c.Req.Form.Get("status") status := c.Req.Form.Get("status")
@ -469,10 +469,10 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
case "banned": case "banned":
desiredStatus = models.UserStatusBanned desiredStatus = models.UserStatusBanned
default: default:
return RejectRequest(c, "No legal user status provided") return c.RejectRequest("No legal user status provided")
} }
_, err = c.Conn.Exec(c.Context(), _, err = c.Conn.Exec(c,
` `
UPDATE hmn_user UPDATE hmn_user
SET status = $1 SET status = $1
@ -485,7 +485,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
} }
if desiredStatus == models.UserStatusBanned { if desiredStatus == models.UserStatusBanned {
err = auth.DeleteSessionForUser(c.Context(), c.Conn, c.Req.Form.Get("username")) err = auth.DeleteSessionForUser(c, c.Conn, c.Req.Form.Get("username"))
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
} }
@ -500,10 +500,10 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id") userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr) userId, err := strconv.Atoi(userIdStr)
if err != nil { if err != nil {
return RejectRequest(c, "No user id provided") return c.RejectRequest("No user id provided")
} }
err = deleteAllPostsForUser(c.Context(), c.Conn, userId) err = deleteAllPostsForUser(c, c.Conn, userId)
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete user posts")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete user posts"))
} }
@ -514,7 +514,7 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *ResponseData { func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *ResponseData {
if new != confirm { if new != confirm {
res := RejectRequest(c, "Your password and password confirmation did not match.") res := c.RejectRequest("Your password and password confirmation did not match.")
return &res return &res
} }
@ -531,12 +531,12 @@ func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *Res
} }
if !ok { if !ok {
res := RejectRequest(c, "The old password you provided was not correct.") res := c.RejectRequest("The old password you provided was not correct.")
return &res return &res
} }
newHashedPassword := auth.HashPassword(new) newHashedPassword := auth.HashPassword(new)
err = auth.UpdatePassword(c.Context(), tx, c.CurrentUser.Username, newHashedPassword) err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword)
if err != nil { if err != nil {
res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password")) res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password"))
return &res return &res

View File

@ -33,14 +33,13 @@ var WebsiteCommand = &cobra.Command{
logging.Info().Msg("Hello, HMN!") logging.Info().Msg("Hello, HMN!")
backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background()) backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background())
longRequestContext, cancelLongRequests := context.WithCancel(context.Background())
conn := db.NewConnPool() conn := db.NewConnPool()
perfCollector := perf.RunPerfCollector(backgroundJobContext) perfCollector := perf.RunPerfCollector(backgroundJobContext)
server := http.Server{ server := http.Server{
Addr: config.Config.Addr, Addr: config.Config.Addr,
Handler: NewWebsiteRoutes(longRequestContext, conn), Handler: NewWebsiteRoutes(conn),
} }
backgroundJobsDone := jobs.Zip( backgroundJobsDone := jobs.Zip(
@ -59,8 +58,6 @@ var WebsiteCommand = &cobra.Command{
<-signals <-signals
logging.Info().Msg("Shutting down the website") logging.Info().Msg("Shutting down the website")
go func() { go func() {
logging.Info().Msg("cancelling long requests")
cancelLongRequests()
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
logging.Info().Msg("shutting down web server") logging.Info().Msg("shutting down web server")