Rework requests and middleware (#57)

o boy

Resolves #10 (hopefully!)

Co-authored-by: Ben Visness <bvisness@gmail.com>
Reviewed-on: hmn/hmn#57
This commit is contained in:
bvisness 2022-06-24 21:38:11 +00:00
parent 32db9b1843
commit e9d4300100
27 changed files with 851 additions and 782 deletions

View File

@ -84,6 +84,8 @@ func UrlWithFragment(path string, query []Q, fragment string) string {
return HMNProjectContext.UrlWithFragment(path, query, fragment)
}
// Takes a project URL and rewrites it using the current URL context. This can be used
// to convert a personal project URL to official and vice versa.
func (c *UrlContext) RewriteProjectUrl(u *url.URL) string {
// we need to strip anything matching the personal project regex to get the base path
match := RegexPersonalProject.FindString(u.Path)

View File

@ -69,7 +69,7 @@ func AdminAtomFeed(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
@ -134,7 +134,7 @@ type unapprovedUserData struct {
func AdminApprovalQueue(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID)
}
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
userLinks, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM
@ -253,13 +253,13 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
return RejectRequest(c, "User id can't be parsed")
return c.RejectRequest("User id can't be parsed")
}
user, err := hmndata.FetchUser(c.Context(), c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{})
user, err := hmndata.FetchUser(c, c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{})
if err != nil {
if errors.Is(err, db.NotFound) {
return RejectRequest(c, "User not found")
return c.RejectRequest("User not found")
} else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
}
@ -267,7 +267,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
whatHappened := ""
if action == ApprovalQueueActionApprove {
_, err := c.Conn.Exec(c.Context(),
_, err := c.Conn.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -281,7 +281,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
}
whatHappened = fmt.Sprintf("%s approved successfully", user.Username)
} else if action == ApprovalQueueActionSpammer {
_, err := c.Conn.Exec(c.Context(),
_, err := c.Conn.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -293,15 +293,15 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to set user to banned"))
}
err = auth.DeleteSessionForUser(c.Context(), c.Conn, user.Username)
err = auth.DeleteSessionForUser(c, c.Conn, user.Username)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
}
err = deleteAllPostsForUser(c.Context(), c.Conn, user.ID)
err = deleteAllPostsForUser(c, c.Conn, user.ID)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's posts"))
}
err = deleteAllProjectsForUser(c.Context(), c.Conn, user.ID)
err = deleteAllProjectsForUser(c, c.Conn, user.ID)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's projects"))
}
@ -324,7 +324,7 @@ type UnapprovedPost struct {
}
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
posts, err := db.Query[UnapprovedPost](c, c.Conn,
`
SELECT $columns
FROM
@ -355,7 +355,7 @@ type UnapprovedProject struct {
}
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn,
ownerIDs, err := db.QueryScalar[int](c, c.Conn,
`
SELECT id
FROM
@ -369,7 +369,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
return nil, oops.New(err, "failed to fetch unapproved users")
}
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: ownerIDs,
IncludeHidden: true,
})
@ -382,7 +382,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
projectIDs = append(projectIDs, p.Project.ID)
}
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
projectLinks, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM

View File

@ -19,7 +19,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
requestedUsername := usernameArgs[0]
found = true
c.Perf.StartBlock("SQL", "Fetch user")
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
user, err := db.QueryOne[models.User](c, c.Conn,
`
SELECT $columns
FROM
@ -45,7 +45,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
var res ResponseData
res.Header().Set("Content-Type", "application/json")
AddCORSHeaders(c, &res)
addCORSHeaders(c, &res)
if found {
res.Write([]byte(fmt.Sprintf(`{ "found": true, "canonical": "%s" }`, canonicalUsername)))
} else {

View File

@ -85,7 +85,7 @@ func AssetUpload(c *RequestContext) ResponseData {
}
}
asset, err := assets.Create(c.Context(), c.Conn, assets.CreateInput{
asset, err := assets.Create(c, c.Conn, assets.CreateInput{
Content: data,
Filename: originalFilename,
ContentType: mimeType,

View File

@ -28,7 +28,7 @@ type LoginPageData struct {
func LoginPage(c *RequestContext) ResponseData {
if c.CurrentUser != nil {
return RejectRequest(c, "You are already logged in.")
return c.RejectRequest("You are already logged in.")
}
var res ResponseData
@ -75,7 +75,7 @@ func Login(c *RequestContext) ResponseData {
return res
}
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
user, err := db.QueryOne[models.User](c, c.Conn,
`
SELECT $columns
FROM hmn_user
@ -102,7 +102,7 @@ func Login(c *RequestContext) ResponseData {
}
if user.Status == models.UserStatusInactive {
return RejectRequest(c, "You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.")
return c.RejectRequest("You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.")
}
res := c.Redirect(redirect, http.StatusSeeOther)
@ -136,7 +136,7 @@ func RegisterNewUser(c *RequestContext) ResponseData {
func RegisterNewUserSubmit(c *RequestContext) ResponseData {
if c.CurrentUser != nil {
return RejectRequest(c, "Can't register new user. You are already logged in")
return c.RejectRequest("Can't register new user. You are already logged in")
}
c.Req.ParseForm()
@ -146,16 +146,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
password := c.Req.Form.Get("password")
password2 := c.Req.Form.Get("password2")
if !UsernameRegex.Match([]byte(username)) {
return RejectRequest(c, "Invalid username")
return c.RejectRequest("Invalid username")
}
if !email.IsEmail(emailAddress) {
return RejectRequest(c, "Invalid email address")
return c.RejectRequest("Invalid email address")
}
if len(password) < 8 {
return RejectRequest(c, "Password too short")
return c.RejectRequest("Password too short")
}
if password != password2 {
return RejectRequest(c, "Password confirmation doesn't match password")
return c.RejectRequest("Password confirmation doesn't match password")
}
c.Perf.StartBlock("SQL", "Check blacklist")
@ -169,7 +169,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
userAlreadyExists := true
_, err := db.QueryOneScalar[int](c.Context(), c.Conn,
_, err := db.QueryOneScalar[int](c, c.Conn,
`
SELECT id
FROM hmn_user
@ -186,11 +186,11 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
}
if userAlreadyExists {
return RejectRequest(c, fmt.Sprintf("Username (%s) already exists.", username))
return c.RejectRequest(fmt.Sprintf("Username (%s) already exists.", username))
}
emailAlreadyExists := true
_, err = db.QueryOneScalar[int](c.Context(), c.Conn,
_, err = db.QueryOneScalar[int](c, c.Conn,
`
SELECT id
FROM hmn_user
@ -215,16 +215,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
hashed := auth.HashPassword(password)
c.Perf.StartBlock("SQL", "Create user and one time token")
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
now := time.Now()
var newUserId int
err = tx.QueryRow(c.Context(),
err = tx.QueryRow(c,
`
INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip)
VALUES ($1, $2, $3, $4, $5, $6)
@ -237,7 +237,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
}
ott := models.GenerateToken()
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
VALUES($1, $2, $3, $4, $5)
@ -263,7 +263,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Commit user")
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit user to the db"))
}
@ -302,7 +302,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
username, hasUsername := c.PathParams["username"]
if !hasUsername {
return RejectRequest(c, "Bad validation url")
return c.RejectRequest("Bad validation url")
}
token := ""
@ -319,7 +319,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
}
if !hasToken {
return RejectRequest(c, "Bad validation url")
return c.RejectRequest("Bad validation url")
}
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypeRegistration)
@ -366,13 +366,13 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Updating user status and deleting token")
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -385,7 +385,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
}
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM one_time_token WHERE id = $1
`,
@ -395,7 +395,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete one time token"))
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit transaction"))
}
@ -413,7 +413,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
// NOTE(asaf): Only call this when validationResult.Match is false.
func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, validationResult validateUserAndTokenResult) ResponseData {
if validationResult.User == nil {
return RejectRequest(c, "You haven't validated your email in time and your user was deleted. You may try registering again with the same username.")
return c.RejectRequest("You haven't validated your email in time and your user was deleted. You may try registering again with the same username.")
}
if validationResult.OneTimeToken == nil {
@ -422,7 +422,7 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali
return c.Redirect(hmnurl.BuildLoginPage(""), http.StatusSeeOther)
}
return RejectRequest(c, "Bad token. If you are having problems registering or logging in, please contact the staff.")
return c.RejectRequest("Bad token. If you are having problems registering or logging in, please contact the staff.")
}
// NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email,
@ -446,14 +446,14 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
emailAddress := strings.TrimSpace(c.Req.Form.Get("email"))
if username == "" && emailAddress == "" {
return RejectRequest(c, "You must provide a username and an email address.")
return c.RejectRequest("You must provide a username and an email address.")
}
c.Perf.StartBlock("SQL", "Fetching user")
type userQuery struct {
User models.User `db:"hmn_user"`
}
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
user, err := db.QueryOne[models.User](c, c.Conn,
`
SELECT $columns
FROM hmn_user
@ -473,7 +473,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token")
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
resetToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
`
SELECT $columns
FROM one_time_token
@ -495,7 +495,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken != nil {
if resetToken.Expires.Before(now.Add(time.Minute * 30)) { // NOTE(asaf): Expired or about to expire
c.Perf.StartBlock("SQL", "Deleting expired token")
_, err = c.Conn.Exec(c.Context(),
_, err = c.Conn.Exec(c,
`
DELETE FROM one_time_token
WHERE id = $1
@ -512,7 +512,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token")
newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
newToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
`
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5)
@ -567,12 +567,12 @@ func DoPasswordReset(c *RequestContext) ResponseData {
token, hasToken := c.PathParams["token"]
if !hasToken || !hasUsername {
return RejectRequest(c, "Bad validation url.")
return c.RejectRequest("Bad validation url.")
}
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
if !validationResult.Match {
return RejectRequest(c, "Bad validation url.")
return c.RejectRequest("Bad validation url.")
}
var res ResponseData
@ -601,30 +601,30 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
if !validationResult.Match {
return RejectRequest(c, "Bad validation url.")
return c.RejectRequest("Bad validation url.")
}
if c.CurrentUser != nil && c.CurrentUser.ID != validationResult.User.ID {
return RejectRequest(c, fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username))
return c.RejectRequest(fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username))
}
if len(password) < 8 {
return RejectRequest(c, "Password too short")
return c.RejectRequest("Password too short")
}
if password != password2 {
return RejectRequest(c, "Password confirmation doesn't match password")
return c.RejectRequest("Password confirmation doesn't match password")
}
hashed := auth.HashPassword(password)
c.Perf.StartBlock("SQL", "Update user's password and delete reset token")
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
tag, err := tx.Exec(c.Context(),
tag, err := tx.Exec(c,
`
UPDATE hmn_user
SET password = $1
@ -638,7 +638,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
}
if validationResult.User.Status == models.UserStatusInactive {
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -652,7 +652,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
}
}
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM one_time_token
WHERE id = $1
@ -663,7 +663,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete onetimetoken"))
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit password reset to the db"))
}
@ -698,7 +698,7 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
// re-hash and save the user's password if necessary
if hashed.IsOutdated() {
newHashed := auth.HashPassword(password)
err := auth.UpdatePassword(c.Context(), c.Conn, user.Username, newHashed)
err := auth.UpdatePassword(c, c.Conn, user.Username, newHashed)
if err != nil {
c.Logger.Error().Err(err).Msg("failed to update user's password")
}
@ -711,15 +711,15 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) error {
c.Perf.StartBlock("SQL", "Setting last login and creating session")
defer c.Perf.EndBlock()
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return oops.New(err, "failed to start db transaction")
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
now := time.Now()
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE hmn_user
SET last_login = $1
@ -732,12 +732,12 @@ func loginUser(c *RequestContext, user *models.User, responseData *ResponseData)
return oops.New(err, "failed to update last_login for user")
}
session, err := auth.CreateSession(c.Context(), c.Conn, user.Username)
session, err := auth.CreateSession(c, c.Conn, user.Username)
if err != nil {
return oops.New(err, "failed to create session")
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return oops.New(err, "failed to commit transaction")
}
@ -749,7 +749,7 @@ func logoutUser(c *RequestContext, res *ResponseData) {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
// clear the session from the db immediately, no expiration
err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value)
err := auth.DeleteSession(c, c.Conn, sessionCookie.Value)
if err != nil {
logging.Error().Err(err).Msg("failed to delete session on logout")
}
@ -772,7 +772,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"hmn_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
}
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
data, err := db.QueryOne[userAndTokenQuery](c, c.Conn,
`
SELECT $columns
FROM hmn_user

View File

@ -37,7 +37,7 @@ func BlogIndex(c *RequestContext) ResponseData {
const postsPerPage = 20
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -51,7 +51,7 @@ func BlogIndex(c *RequestContext) ResponseData {
return c.Redirect(c.UrlContext.BuildBlog(page), http.StatusSeeOther)
}
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
Limit: postsPerPage,
@ -78,7 +78,7 @@ func BlogIndex(c *RequestContext) ResponseData {
canCreate := false
if c.CurrentProject.HasBlog() && c.CurrentUser != nil {
isProjectOwner := false
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID)
owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project owners"))
}
@ -128,7 +128,7 @@ func BlogThread(c *RequestContext) ResponseData {
return FourOhFour(c)
}
thread, posts, err := hmndata.FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{
thread, posts, err := hmndata.FetchThreadPosts(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -155,7 +155,7 @@ func BlogThread(c *RequestContext) ResponseData {
// Update thread last read info
if c.CurrentUser != nil {
c.Perf.StartBlock("SQL", "Update TLRI")
_, err := c.Conn.Exec(c.Context(),
_, err := c.Conn.Exec(c,
`
INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
VALUES ($1, $2, $3)
@ -196,7 +196,7 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData {
return FourOhFour(c)
}
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -227,11 +227,11 @@ func BlogNewThread(c *RequestContext) ResponseData {
}
func BlogNewThreadSubmit(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
err = c.Req.ParseForm()
if err != nil {
@ -240,15 +240,15 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
title := c.Req.Form.Get("title")
unparsed := c.Req.Form.Get("body")
if title == "" {
return RejectRequest(c, "You must provide a title for your post.")
return c.RejectRequest("You must provide a title for your post.")
}
if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.")
return c.RejectRequest("You must provide a body for your post.")
}
// Create thread
var threadId int
err = tx.QueryRow(c.Context(),
err = tx.QueryRow(c,
`
INSERT INTO thread (title, type, project_id, first_id, last_id)
VALUES ($1, $2, $3, $4, $5)
@ -265,9 +265,9 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
}
// Create everything else
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new blog post"))
}
@ -282,11 +282,11 @@ func BlogPostEdit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -326,17 +326,17 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -351,16 +351,16 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
unparsed := c.Req.Form.Get("body")
editReason := c.Req.Form.Get("editreason")
if title != "" && post.Thread.FirstID != post.Post.ID {
return RejectRequest(c, "You can only edit the title by editing the first post.")
return c.RejectRequest("You can only edit the title by editing the first post.")
}
if unparsed == "" {
return RejectRequest(c, "You must provide a post body.")
return c.RejectRequest("You must provide a post body.")
}
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
if title != "" {
_, err := tx.Exec(c.Context(),
_, err := tx.Exec(c,
`
UPDATE thread SET title = $1 WHERE id = $2
`,
@ -372,7 +372,7 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
}
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post"))
}
@ -387,7 +387,7 @@ func BlogPostReply(c *RequestContext) ResponseData {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -421,11 +421,11 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
err = c.Req.ParseForm()
if err != nil {
@ -433,12 +433,12 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
}
unparsed := c.Req.Form.Get("body")
if unparsed == "" {
return RejectRequest(c, "Your reply cannot be empty.")
return c.RejectRequest("Your reply cannot be empty.")
}
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host)
newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to blog post"))
}
@ -453,11 +453,11 @@ func BlogPostDelete(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -503,19 +503,19 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID)
threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
}
@ -523,7 +523,7 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
if threadDeleted {
return c.Redirect(c.UrlContext.BuildHomepage(), http.StatusSeeOther)
} else {
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
})
@ -560,7 +560,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
res.ThreadID = threadId
c.Perf.StartBlock("SQL", "Verify that the thread exists")
threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
threadExists, err := db.QueryOneScalar[bool](c, c.Conn,
`
SELECT COUNT(*) > 0
FROM thread
@ -588,7 +588,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
res.PostID = postId
c.Perf.StartBlock("SQL", "Verify that the post exists")
postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
postExists, err := db.QueryOneScalar[bool](c, c.Conn,
`
SELECT COUNT(*) > 0
FROM post

138
src/website/common.go Normal file
View File

@ -0,0 +1,138 @@
package website
import (
"errors"
"net/http"
"net/url"
"strings"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/hmndata"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates"
)
func loadCommonData(h Handler) Handler {
return func(c *RequestContext) ResponseData {
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
{
// get user
{
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
}
c.CurrentUser = user
c.CurrentSession = session
}
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
}
// get current official project (HMN or otherwise, by subdomain)
{
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost())
slug := strings.TrimRight(hostPrefix, ".")
var owners []*models.User
if len(slug) > 0 {
dbProject, err := hmndata.FetchProjectBySlug(c, c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err == nil {
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
owners = dbProject.Owners
} else {
if errors.Is(err, db.NotFound) {
// do nothing, this is fine
} else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
}
}
if c.CurrentProject == nil {
dbProject, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
panic(oops.New(err, "failed to fetch HMN project"))
}
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
}
if c.CurrentProject == nil {
panic("failed to load project data")
}
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
}
c.Theme = "light"
if c.CurrentUser != nil && c.CurrentUser.DarkTheme {
c.Theme = "dark"
}
}
c.Perf.EndBlock()
return h(c)
}
}
// Given a session id, fetches user data from the database. Will return nil if
// the user cannot be found, and will only return an error if it's serious.
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
session, err := auth.GetSession(c, c.Conn, sessionId)
if err != nil {
if errors.Is(err, auth.ErrNoSession) {
return nil, nil, nil
} else {
return nil, nil, oops.New(err, "failed to get current session")
}
}
user, err := hmndata.FetchUserByUsername(c, c.Conn, nil, session.Username, hmndata.UsersQuery{
AnyStatus: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
return nil, nil, nil // user was deleted or something
} else {
return nil, nil, oops.New(err, "failed to get user for session")
}
}
return user, session, nil
}
func addCORSHeaders(c *RequestContext, res *ResponseData) {
parsed, err := url.Parse(config.Config.BaseUrl)
if err != nil {
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
return
}
origin := ""
origins, found := c.Req.Header["Origin"]
if found {
origin = origins[0]
}
if strings.HasSuffix(origin, parsed.Host) {
res.Header().Add("Access-Control-Allow-Origin", origin)
res.Header().Add("Access-Control-Allow-Credentials", "true")
res.Header().Add("Vary", "Origin")
}
}

View File

@ -35,31 +35,31 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
// This occurs when the user cancels. Just go back to the profile page.
return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther)
} else {
return RejectRequest(c, "Failed to authenticate with Discord.")
return c.RejectRequest("Failed to authenticate with Discord.")
}
}
// Do the actual token exchange
code := query.Get("code")
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback())
res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback())
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
}
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken)
user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
}
// Add the role on Discord
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID)
err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
}
// Add the user to our database
_, err = c.Conn.Exec(c.Context(),
_, err = c.Conn.Exec(c,
`
INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
@ -79,7 +79,7 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
}
if c.CurrentUser.Status == models.UserStatusConfirmed {
_, err = c.Conn.Exec(c.Context(),
_, err = c.Conn.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -98,13 +98,13 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
}
func DiscordUnlink(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx,
discordUser, err := db.QueryOne[models.DiscordUser](c, tx,
`
SELECT $columns
FROM discord_user
@ -120,7 +120,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
}
}
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM discord_user
WHERE id = $1
@ -131,12 +131,12 @@ func DiscordUnlink(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete Discord user"))
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit Discord user delete"))
}
err = discord.RemoveGuildMemberRole(c.Context(), discordUser.UserID, config.Config.Discord.MemberRoleID)
err = discord.RemoveGuildMemberRole(c, discordUser.UserID, config.Config.Discord.MemberRoleID)
if err != nil {
c.Logger.Warn().Err(err).Msg("failed to remove member role on unlink")
}
@ -145,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
}
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
`SELECT $columns FROM discord_user WHERE hmn_user_id = $1`,
c.CurrentUser.ID,
)
@ -157,7 +157,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
}
msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn,
msgIDs, err := db.QueryScalar[string](c, c.Conn,
`
SELECT msg.id
FROM
@ -174,12 +174,12 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
}
for _, msgID := range msgIDs {
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
interned, err := discord.FetchInternedMessage(c, c.Conn, msgID)
if err != nil && !errors.Is(err, db.NotFound) {
return c.ErrorResponse(http.StatusInternalServerError, err)
} else if err == nil {
// NOTE(asaf): Creating snippet even if the checkbox is off because the user asked us to.
err = discord.HandleSnippetForInternedMessage(c.Context(), c.Conn, interned, true)
err = discord.HandleSnippetForInternedMessage(c, c.Conn, interned, true)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err)
}

View File

@ -1,6 +1,31 @@
package website
import "fmt"
import (
"fmt"
"net/http"
"strings"
"git.handmade.network/hmn/hmn/src/templates"
)
func FourOhFour(c *RequestContext) ResponseData {
var res ResponseData
res.StatusCode = http.StatusNotFound
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
templateData := struct {
templates.BaseData
Wanted string
}{
BaseData: getBaseData(c, "Page not found", nil),
Wanted: c.FullUrl(),
}
res.MustWriteTemplate("404.html", templateData, c.Perf)
} else {
res.Write([]byte("Not Found"))
}
return res
}
// A SafeError can be used to wrap another error and explicitly provide
// an error message that is safe to show to a user. This allows the original

View File

@ -33,7 +33,7 @@ var feedThreadTypes = []models.ThreadType{
}
func Feed(c *RequestContext) ResponseData {
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes,
})
if err != nil {
@ -156,7 +156,7 @@ func AtomFeed(c *RequestContext) ResponseData {
if hasAll {
itemsPerFeed = 100000
}
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, nil, hmndata.ProjectsQuery{
projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, nil, hmndata.ProjectsQuery{
Limit: itemsPerFeed,
Types: hmndata.OfficialProjects,
OrderBy: "date_approved DESC",
@ -188,7 +188,7 @@ func AtomFeed(c *RequestContext) ResponseData {
feedData.AtomFeedUrl = hmnurl.BuildAtomFeedForShowcase()
feedData.FeedUrl = hmnurl.BuildShowcase()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Limit: itemsPerFeed,
})
if err != nil {
@ -215,7 +215,7 @@ func AtomFeed(c *RequestContext) ResponseData {
}
func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostListItem, error) {
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes,
Limit: limit,
Offset: offset,
@ -225,7 +225,7 @@ func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostLi
return nil, err
}
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()

View File

@ -206,7 +206,7 @@ func Fishbowl(c *RequestContext) ResponseData {
func FishbowlFiles(c *RequestContext) ResponseData {
var res ResponseData
fishbowlHTTPFS.ServeHTTP(&res, c.Req)
AddCORSHeaders(c, &res)
addCORSHeaders(c, &res)
return res
}
@ -224,7 +224,7 @@ func linkifyDiscordContent(c *RequestContext, dbConn db.ConnOrTx, content string
discordUserIds = append(discordUserIds, id)
}
hmnUsers, err := hmndata.FetchUsers(c.Context(), dbConn, c.CurrentUser, hmndata.UsersQuery{
hmnUsers, err := hmndata.FetchUsers(c, dbConn, c.CurrentUser, hmndata.UsersQuery{
DiscordUserIDs: discordUserIds,
})
if err != nil {

View File

@ -91,7 +91,7 @@ func Forum(c *RequestContext) ResponseData {
currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID)
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{cd.SubforumID},
@ -107,7 +107,7 @@ func Forum(c *RequestContext) ResponseData {
}
howManyThreadsToSkip := (page - 1) * threadsPerPage
mainThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
mainThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{cd.SubforumID},
@ -141,7 +141,7 @@ func Forum(c *RequestContext) ResponseData {
subforumNodes := cd.SubforumTree[cd.SubforumID].Children
for _, sfNode := range subforumNodes {
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{sfNode.ID},
@ -150,7 +150,7 @@ func Forum(c *RequestContext) ResponseData {
panic(oops.New(err, "failed to get count of threads"))
}
subforumThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
subforumThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
SubforumIDs: []int{sfNode.ID},
@ -203,7 +203,7 @@ func Forum(c *RequestContext) ResponseData {
func ForumMarkRead(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
@ -212,16 +212,16 @@ func ForumMarkRead(c *RequestContext) ResponseData {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
sfIds := []int{sfId}
if sfId == 0 {
// Mark literally everything as read
_, err := tx.Exec(c.Context(),
_, err := tx.Exec(c,
`
UPDATE hmn_user
SET marked_all_read_at = NOW()
@ -234,7 +234,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
}
// Delete thread unread info
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM thread_last_read_info
WHERE user_id = $1;
@ -246,7 +246,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
}
// Delete subforum unread info
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM subforum_last_read_info
WHERE user_id = $1;
@ -258,7 +258,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
}
} else {
c.Perf.StartBlock("SQL", "Update SLRIs")
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
INSERT INTO subforum_last_read_info (subforum_id, user_id, lastread)
SELECT id, $2, $3
@ -277,7 +277,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Delete TLRIs")
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
DELETE FROM thread_last_read_info
WHERE
@ -298,7 +298,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
}
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit SLRI/TLRI updates"))
}
@ -332,7 +332,7 @@ func ForumThread(c *RequestContext) ResponseData {
return FourOhFour(c)
}
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadIDs: []int{cd.ThreadID},
})
@ -351,7 +351,7 @@ func ForumThread(c *RequestContext) ResponseData {
return c.Redirect(correctThreadUrl, http.StatusSeeOther)
}
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
ThreadIDs: []int{cd.ThreadID},
@ -374,7 +374,7 @@ func ForumThread(c *RequestContext) ResponseData {
PreviousUrl: c.UrlContext.BuildForumThread(currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)),
}
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadIDs: []int{thread.ID},
Limit: threadPostsPerPage,
@ -396,7 +396,7 @@ func ForumThread(c *RequestContext) ResponseData {
post.ReplyPost = &reply
}
addAuthorCountsToPost(c.Context(), c.Conn, &post)
addAuthorCountsToPost(c, c.Conn, &post)
posts = append(posts, post)
}
@ -404,7 +404,7 @@ func ForumThread(c *RequestContext) ResponseData {
// Update thread last read info
if c.CurrentUser != nil {
c.Perf.StartBlock("SQL", "Update TLRI")
_, err = c.Conn.Exec(c.Context(),
_, err = c.Conn.Exec(c,
`
INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
VALUES ($1, $2, $3)
@ -445,7 +445,7 @@ func ForumPostRedirect(c *RequestContext) ResponseData {
return FourOhFour(c)
}
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
ThreadIDs: []int{cd.ThreadID},
@ -495,11 +495,11 @@ func ForumNewThread(c *RequestContext) ResponseData {
}
func ForumNewThreadSubmit(c *RequestContext) ResponseData {
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
cd, ok := getCommonForumData(c)
if !ok {
@ -517,15 +517,15 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
sticky = true
}
if title == "" {
return RejectRequest(c, "You must provide a title for your post.")
return c.RejectRequest("You must provide a title for your post.")
}
if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.")
return c.RejectRequest("You must provide a body for your post.")
}
// Create thread
var threadId int
err = tx.QueryRow(c.Context(),
err = tx.QueryRow(c,
`
INSERT INTO thread (title, sticky, type, project_id, subforum_id, first_id, last_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
@ -544,9 +544,9 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
}
// Create everything else
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new forum thread"))
}
@ -561,7 +561,7 @@ func ForumPostReply(c *RequestContext) ResponseData {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
})
@ -600,11 +600,11 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
err = c.Req.ParseForm()
if err != nil {
@ -612,10 +612,10 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
}
unparsed := c.Req.Form.Get("body")
if unparsed == "" {
return RejectRequest(c, "Your reply cannot be empty.")
return c.RejectRequest("Your reply cannot be empty.")
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
})
@ -629,9 +629,9 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
replyPostId = &post.Post.ID
}
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host)
newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to forum post"))
}
@ -646,11 +646,11 @@ func ForumPostEdit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
})
@ -688,17 +688,17 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
})
@ -713,16 +713,16 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
unparsed := c.Req.Form.Get("body")
editReason := c.Req.Form.Get("editreason")
if title != "" && post.Thread.FirstID != post.Post.ID {
return RejectRequest(c, "You can only edit the title by editing the first post.")
return c.RejectRequest("You can only edit the title by editing the first post.")
}
if unparsed == "" {
return RejectRequest(c, "You must provide a body for your post.")
return c.RejectRequest("You must provide a body for your post.")
}
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
if title != "" {
_, err := tx.Exec(c.Context(),
_, err := tx.Exec(c,
`
UPDATE thread SET title = $1 WHERE id = $2
`,
@ -734,7 +734,7 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
}
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit forum post"))
}
@ -749,11 +749,11 @@ func ForumPostDelete(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
})
@ -798,19 +798,19 @@ func ForumPostDeleteSubmit(c *RequestContext) ResponseData {
return FourOhFour(c)
}
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
return FourOhFour(c)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID)
threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
}
@ -831,7 +831,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
panic(err)
}
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
// This is the rare query where we want all thread types!
})
@ -842,7 +842,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
@ -874,7 +874,7 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) {
defer c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()

View File

@ -89,7 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
img.Seek(0, io.SeekStart)
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
sha1sum := hasher.Sum(nil)
imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn,
imageFile, err := db.QueryOne[models.ImageFile](c, dbConn,
`
INSERT INTO image_file (file, size, sha1sum, protected, width, height)
VALUES ($1, $2, $3, $4, $5, $6)

View File

@ -51,7 +51,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
}
tagId := -1
jamTag, err := hmndata.FetchTag(c.Context(), c.Conn, hmndata.TagQuery{
jamTag, err := hmndata.FetchTag(c, c.Conn, hmndata.TagQuery{
Text: []string{"wheeljam"},
})
if err == nil {
@ -60,7 +60,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
c.Logger.Warn().Err(err).Msg("failed to fetch jam tag; will fetch all snippets as a result")
}
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Tags: []int{tagId},
})
if err != nil {

View File

@ -34,13 +34,13 @@ type LandingTemplateData struct {
func Index(c *RequestContext) ResponseData {
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
var timelineItems []templates.TimelineItem
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes,
})
if err != nil {
@ -65,7 +65,7 @@ func Index(c *RequestContext) ResponseData {
}
// This is essentially an alternate for feed page 1.
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ThreadTypes: feedThreadTypes,
Limit: feedPostsPerPage,
SortDescending: true,
@ -84,7 +84,7 @@ func Index(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Get news")
newsThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
newsThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
ProjectIDs: []int{models.HMNProjectID},
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
Limit: 1,
@ -106,7 +106,7 @@ func Index(c *RequestContext) ResponseData {
}
c.Perf.EndBlock()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Limit: 40,
})
if err != nil {

123
src/website/middlewares.go Normal file
View File

@ -0,0 +1,123 @@
package website
import (
"fmt"
"math/rand"
"net/http"
"time"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/perf"
"git.handmade.network/hmn/hmn/src/utils"
)
func panicCatcherMiddleware(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
defer func() {
if recovered := recover(); recovered != nil {
maybeError, ok := recovered.(*error)
var err error
if ok {
err = *maybeError
} else {
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
}
res = c.ErrorResponse(http.StatusInternalServerError, err)
}
}()
return h(c)
}
}
func trackRequestPerf(h Handler) Handler {
return func(c *RequestContext) ResponseData {
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
defer func() {
c.Perf.EndRequest()
log := logging.Info()
blockStack := make([]time.Time, 0)
for i, block := range c.Perf.Blocks {
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
blockStack = blockStack[:len(blockStack)-1]
}
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
blockStack = append(blockStack, block.End)
}
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
}()
return h(c)
}
}
func needsAuth(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil {
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
}
return h(c)
}
}
func adminsOnly(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
return FourOhFour(c)
}
return h(c)
}
}
func csrfMiddleware(h Handler) Handler {
// CSRF mitigation actions per the OWASP cheat sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
return func(c *RequestContext) ResponseData {
c.Req.ParseMultipartForm(100 * 1024 * 1024)
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
if csrfToken != c.CurrentSession.CSRFToken {
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
res := c.Redirect("/", http.StatusSeeOther)
logoutUser(c, &res)
return res
}
return h(c)
}
}
func securityTimerMiddleware(duration time.Duration, h Handler) Handler {
// NOTE(asaf): Will make sure that the request takes at least `duration` to finish. Adds a 10% random duration.
return func(c *RequestContext) ResponseData {
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
timer := time.NewTimer(duration + additionalDuration)
res := h(c)
select {
case <-c.Done():
case <-timer.C:
}
return res
}
}
func logContextErrors(c *RequestContext, errs ...error) {
for _, err := range errs {
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
}
}
func logContextErrorsMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
res := h(c)
logContextErrors(c, res.Errors...)
return res
}
}

100
src/website/notices.go Normal file
View File

@ -0,0 +1,100 @@
package website
import (
"errors"
"html/template"
"net/http"
"strings"
"time"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/templates"
)
const NoticesCookieName = "hmn_notices"
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
cookie, err := c.Req.Cookie(NoticesCookieName)
if err != nil {
if !errors.Is(err, http.ErrNoCookie) {
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
}
return nil
}
return deserializeNoticesFromCookie(cookie.Value)
}
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
serialized := serializeNoticesForCookie(c, res.FutureNotices)
if serialized != "" {
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Value: serialized,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
Expires: time.Now().Add(time.Minute * 5),
Secure: config.Config.Auth.CookieSecure,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
res.SetCookie(&noticesCookie)
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
// NOTE(asaf): Don't clear on redirect
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
MaxAge: -1,
}
res.SetCookie(&noticesCookie)
}
}
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
var builder strings.Builder
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
size := 0
for i, notice := range notices {
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
if i != 0 {
sizeIncrease += 1
}
if size+sizeIncrease > maxSize {
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
break
}
if i != 0 {
builder.WriteString("\t")
}
builder.WriteString(notice.Class)
builder.WriteString("|")
builder.WriteString(string(notice.Content))
size += sizeIncrease
}
return builder.String()
}
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
var result []templates.Notice
notices := strings.Split(cookieVal, "\t")
for _, notice := range notices {
parts := strings.SplitN(notice, "|", 2)
if len(parts) == 2 {
result = append(result, templates.Notice{
Class: parts[0],
Content: template.HTML(parts[1]),
})
}
}
return result
}
func storeNoticesInCookieMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
res := h(c)
storeNoticesInCookie(c, &res)
return res
}
}

View File

@ -126,29 +126,29 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
title := c.Req.Form.Get("title")
if len(strings.TrimSpace(title)) == 0 {
return RejectRequest(c, "Podcast title is empty")
return c.RejectRequest("Podcast title is empty")
}
description := c.Req.Form.Get("description")
if len(strings.TrimSpace(description)) == 0 {
return RejectRequest(c, "Podcast description is empty")
return c.RejectRequest("Podcast description is empty")
}
c.Perf.StartBlock("SQL", "Updating podcast")
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
imageSaveResult := SaveImageFile(c, tx, "podcast_image", maxFileSize, fmt.Sprintf("podcast/%s/logo%d", c.CurrentProject.Slug, time.Now().UTC().Unix()))
if imageSaveResult.ValidationError != "" {
return RejectRequest(c, imageSaveResult.ValidationError)
return c.RejectRequest(imageSaveResult.ValidationError)
} else if imageSaveResult.FatalError != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(imageSaveResult.FatalError, "Failed to save podcast image"))
}
if imageSaveResult.ImageFile != nil {
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE podcast
SET
@ -166,7 +166,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to update podcast"))
}
} else {
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE podcast
SET
@ -179,7 +179,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
podcastResult.Podcast.ID,
)
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
c.Perf.EndBlock()
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to commit db transaction"))
@ -357,16 +357,16 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
c.Req.ParseForm()
title := c.Req.Form.Get("title")
if len(strings.TrimSpace(title)) == 0 {
return RejectRequest(c, "Episode title is empty")
return c.RejectRequest("Episode title is empty")
}
description := c.Req.Form.Get("description")
if len(strings.TrimSpace(description)) == 0 {
return RejectRequest(c, "Episode description is empty")
return c.RejectRequest("Episode description is empty")
}
episodeNumberStr := c.Req.Form.Get("episode_number")
episodeNumber, err := strconv.Atoi(episodeNumberStr)
if err != nil {
return RejectRequest(c, "Episode number can't be parsed")
return c.RejectRequest("Episode number can't be parsed")
}
episodeFile := c.Req.Form.Get("episode_file")
found = false
@ -378,7 +378,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
}
if !found {
return RejectRequest(c, "Requested episode file not found")
return c.RejectRequest("Requested episode file not found")
}
c.Perf.StartBlock("MP3", "Parsing mp3 file for duration")
@ -417,7 +417,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
if isEdit {
guidStr = podcastResult.Episodes[0].GUID.String()
c.Perf.StartBlock("SQL", "Updating podcast episode")
_, err := c.Conn.Exec(c.Context(),
_, err := c.Conn.Exec(c,
`
UPDATE podcast_episode
SET
@ -446,7 +446,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
guid := uuid.New()
guidStr = guid.String()
c.Perf.StartBlock("SQL", "Creating new podcast episode")
_, err := c.Conn.Exec(c.Context(),
_, err := c.Conn.Exec(c,
`
INSERT INTO podcast_episode
(guid, title, description, description_rendered, audio_filename, duration, pub_date, episode_number, podcast_id)
@ -532,7 +532,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
Podcast models.Podcast `db:"podcast"`
ImageFilename string `db:"imagefile.file"`
}
podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn,
podcastQueryResult, err := db.QueryOne[podcastQuery](c, c.Conn,
`
SELECT $columns
FROM
@ -558,7 +558,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
if fetchEpisodes {
if episodeGUID == "" {
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn,
episodes, err := db.Query[models.PodcastEpisode](c, c.Conn,
`
SELECT $columns
FROM podcast_episode AS episode
@ -578,7 +578,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
return result, err
}
c.Perf.StartBlock("SQL", "Fetch podcast episode")
episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn,
episode, err := db.QueryOne[models.PodcastEpisode](c, c.Conn,
`
SELECT $columns
FROM podcast_episode AS episode

View File

@ -26,11 +26,52 @@ import (
"git.handmade.network/hmn/hmn/src/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v4"
"github.com/teacat/noire"
)
const maxPersonalProjects = 5
const maxProjectOwners = 5
func ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color")
if color == "" {
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
}
baseData := getBaseData(c, "", nil)
bgColor := noire.NewHex(color)
h, s, l := bgColor.HSL()
if baseData.Theme == "dark" {
l = 15
} else {
l = 95
}
if s > 20 {
s = 20
}
bgColor = noire.NewHSL(h, s, l)
templateData := struct {
templates.BaseData
Color string
PostBgColor string
}{
BaseData: baseData,
Color: color,
PostBgColor: bgColor.HTML(),
}
var res ResponseData
res.Header().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
return res
}
type ProjectTemplateData struct {
templates.BaseData
@ -48,7 +89,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
const maxCarouselProjects = 10
const maxPersonalProjects = 10
officialProjects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
officialProjects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
Types: hmndata.OfficialProjects,
})
if err != nil {
@ -123,7 +164,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
// Fetch and highlight a random selection of personal projects
var personalProjects []templates.Project
{
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
Types: hmndata.PersonalProjects,
})
if err != nil {
@ -181,13 +222,13 @@ func ProjectHomepage(c *RequestContext) ResponseData {
// There are no further permission checks to do, because permissions are
// checked whatever way we fetch the project.
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID)
owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err)
}
c.Perf.StartBlock("SQL", "Fetching screenshots")
screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn,
screenshotFilenames, err := db.QueryScalar[string](c, c.Conn,
`
SELECT screenshot.file
FROM
@ -204,7 +245,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetching project links")
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
projectLinks, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM
@ -221,12 +262,12 @@ func ProjectHomepage(c *RequestContext) ResponseData {
c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetching project timeline")
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
ProjectIDs: []int{c.CurrentProject.ID},
Limit: maxRecentActivity,
SortDescending: true,
@ -241,7 +282,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
Value: c.CurrentProject.Blurb,
})
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{
p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
@ -317,7 +358,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
tagId = *c.CurrentProject.TagID
}
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
Tags: []int{tagId},
})
if err != nil {
@ -364,7 +405,7 @@ type ProjectEditData struct {
}
func ProjectNew(c *RequestContext) ResponseData {
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{c.CurrentUser.ID},
Types: hmndata.PersonalProjects,
})
@ -372,7 +413,7 @@ func ProjectNew(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
}
if numProjects >= maxPersonalProjects {
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
}
var project templates.ProjectSettings
@ -397,16 +438,16 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
}
if len(formResult.RejectionReason) != 0 {
return RejectRequest(c, formResult.RejectionReason)
return c.RejectRequest(formResult.RejectionReason)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{c.CurrentUser.ID},
Types: hmndata.PersonalProjects,
})
@ -414,11 +455,11 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
}
if numProjects >= maxPersonalProjects {
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
}
var projectId int
err = tx.QueryRow(c.Context(),
err = tx.QueryRow(c,
`
INSERT INTO project
(name, blurb, description, descparsed, lifecycle, date_created, all_last_updated)
@ -439,12 +480,12 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
formResult.Payload.ProjectID = projectId
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload)
err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err)
}
tx.Commit(c.Context())
tx.Commit(c)
urlContext := &hmnurl.UrlContext{
PersonalProject: true,
@ -461,7 +502,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
}
p, err := hmndata.FetchProject(
c.Context(), c.Conn,
c, c.Conn,
c.CurrentUser, c.CurrentProject.ID,
hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
@ -473,7 +514,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Fetching project links")
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
projectLinks, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM
@ -524,23 +565,23 @@ func ProjectEditSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
}
if len(formResult.RejectionReason) != 0 {
return RejectRequest(c, formResult.RejectionReason)
return c.RejectRequest(formResult.RejectionReason)
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
formResult.Payload.ProjectID = c.CurrentProject.ID
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload)
err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, err)
}
tx.Commit(c.Context())
tx.Commit(c)
urlContext := &hmnurl.UrlContext{
PersonalProject: formResult.Payload.Personal,

View File

@ -13,6 +13,7 @@ import (
"path"
"regexp"
"strings"
"time"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
@ -43,15 +44,22 @@ func (r *Route) String() string {
}
type RouteBuilder struct {
Router *Router
Prefixes []*regexp.Regexp
Middleware Middleware
Router *Router
Prefixes []*regexp.Regexp
Middlewares []Middleware
}
type Handler func(c *RequestContext) ResponseData
type Middleware func(h Handler) Handler
func applyMiddlewares(h Handler, ms []Middleware) Handler {
result := h
for i := len(ms) - 1; i >= 0; i-- {
result = ms[i](result)
}
return result
}
func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler) {
// Ensure that this regex matches the start of the string
regexStr := regex.String()
@ -59,7 +67,7 @@ func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler
panic("All routing regexes must begin with '^'")
}
h = rb.Middleware(h)
h = applyMiddlewares(h, rb.Middlewares)
for _, method := range methods {
rb.Router.Routes = append(rb.Router.Routes, Route{
Method: method,
@ -81,10 +89,19 @@ func (rb *RouteBuilder) POST(regex *regexp.Regexp, h Handler) {
rb.Handle([]string{http.MethodPost}, regex, h)
}
func (rb *RouteBuilder) Group(regex *regexp.Regexp, addRoutes func(rb *RouteBuilder)) {
func (rb *RouteBuilder) WithMiddleware(ms ...Middleware) RouteBuilder {
newRb := *rb
newRb.Middlewares = append(rb.Middlewares, ms...)
return newRb
}
func (rb *RouteBuilder) Group(regex *regexp.Regexp, ms ...Middleware) RouteBuilder {
newRb := *rb
newRb.Prefixes = append(newRb.Prefixes, regex)
addRoutes(&newRb)
newRb.Middlewares = append(rb.Middlewares, ms...)
return newRb
}
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
@ -138,6 +155,8 @@ nextroute:
Req: req,
Res: rw,
PathParams: params,
ctx: req.Context(),
}
c.PathParams = params
@ -174,13 +193,33 @@ type RequestContext struct {
ctx context.Context
}
func (c *RequestContext) Context() context.Context {
if c.ctx == nil {
c.ctx = c.Req.Context()
}
return c.ctx
// Our RequestContext is a context.Context
var _ context.Context = &RequestContext{}
func (c *RequestContext) Deadline() (time.Time, bool) {
return c.ctx.Deadline()
}
func (c *RequestContext) Done() <-chan struct{} {
return c.ctx.Done()
}
func (c *RequestContext) Err() error {
return c.ctx.Err()
}
func (c *RequestContext) Value(key any) any {
switch key {
case perf.PerfContextKey:
return c.Perf
default:
return c.ctx.Value(key)
}
}
// Plus it does many other things specific to us
func (c *RequestContext) URL() *url.URL {
return c.Req.URL
}
@ -325,7 +364,7 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData {
func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
defer func() {
if r := recover(); r != nil {
LogContextErrors(c, errs...)
logContextErrors(c, errs...)
panic(r)
}
}()
@ -338,6 +377,23 @@ func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
return res
}
func (c *RequestContext) RejectRequest(reason string) ResponseData {
type RejectData struct {
templates.BaseData
RejectReason string
}
var res ResponseData
err := res.WriteTemplate("reject.html", RejectData{
BaseData: getBaseData(c, "Rejected", nil),
RejectReason: reason,
}, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
}
return res
}
type ResponseData struct {
StatusCode int
Body *bytes.Buffer

View File

@ -1,159 +1,46 @@
package website
import (
"context"
"errors"
"fmt"
"html/template"
"math/rand"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/email"
"git.handmade.network/hmn/hmn/src/hmndata"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/perf"
"git.handmade.network/hmn/hmn/src/templates"
"git.handmade.network/hmn/hmn/src/utils"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/teacat/noire"
)
func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) http.Handler {
func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
router := &Router{}
routes := RouteBuilder{
Router: router,
Middleware: func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
c.Conn = conn
logPerf := TrackRequestPerf(c)
defer logPerf()
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
return h(c)
}
Middlewares: []Middleware{
setDBConn(conn),
trackRequestPerf,
logContextErrorsMiddleware,
panicCatcherMiddleware,
},
}
anyProject := routes
anyProject.Middleware = func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
c.Conn = conn
logPerf := TrackRequestPerf(c)
defer logPerf()
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
defer storeNoticesInCookie(c, &res)
ok, errRes := LoadCommonWebsiteData(c)
if !ok {
return errRes
}
return h(c)
}
}
hmnOnly := routes
hmnOnly.Middleware = func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
c.Conn = conn
logPerf := TrackRequestPerf(c)
defer logPerf()
defer LogContextErrorsFromResponse(c, &res)
defer MiddlewarePanicCatcher(c, &res)
defer storeNoticesInCookie(c, &res)
ok, errRes := LoadCommonWebsiteData(c)
if !ok {
return errRes
}
if !c.CurrentProject.IsHMN() {
return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
}
return h(c)
}
}
authMiddleware := func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil {
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
}
return h(c)
}
}
adminMiddleware := func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
return FourOhFour(c)
}
return h(c)
}
}
csrfMiddleware := func(h Handler) Handler {
// CSRF mitigation actions per the OWASP cheat sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
return func(c *RequestContext) ResponseData {
c.Req.ParseMultipartForm(100 * 1024 * 1024)
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
if csrfToken != c.CurrentSession.CSRFToken {
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
res := c.Redirect("/", http.StatusSeeOther)
logoutUser(c, &res)
return res
}
return h(c)
}
}
securityTimerMiddleware := func(duration time.Duration, h Handler) Handler {
// NOTE(asaf): Will make sure that the request takes at least `delayMs` to finish. Adds a 10% random duration.
return func(c *RequestContext) ResponseData {
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
timer := time.NewTimer(duration + additionalDuration)
res := h(c)
select {
case <-longRequestContext.Done():
case <-c.Context().Done():
case <-timer.C:
}
return res
}
}
anyProject := routes.WithMiddleware(
storeNoticesInCookieMiddleware,
loadCommonData,
)
hmnOnly := anyProject.WithMiddleware(
redirectToHMN,
)
routes.GET(hmnurl.RegexPublic, func(c *RequestContext) ResponseData {
var res ResponseData
http.StripPrefix("/public/", http.FileServer(http.Dir("public"))).ServeHTTP(&res, c.Req)
AddCORSHeaders(c, &res)
addCORSHeaders(c, &res)
return res
})
routes.GET(hmnurl.RegexFishbowlFiles, FishbowlFiles)
@ -189,10 +76,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.POST(hmnurl.RegexDoPasswordReset, DoPasswordResetSubmit)
hmnOnly.GET(hmnurl.RegexAdminAtomFeed, AdminAtomFeed)
hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminMiddleware(AdminApprovalQueue))
hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminMiddleware(csrfMiddleware(AdminApprovalQueueSubmit)))
hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminMiddleware(csrfMiddleware(UserProfileAdminSetStatus)))
hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminMiddleware(csrfMiddleware(UserProfileAdminNuke)))
hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminsOnly(AdminApprovalQueue))
hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminsOnly(csrfMiddleware(AdminApprovalQueueSubmit)))
hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminsOnly(csrfMiddleware(UserProfileAdminSetStatus)))
hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminsOnly(csrfMiddleware(UserProfileAdminNuke)))
hmnOnly.GET(hmnurl.RegexFeed, Feed)
hmnOnly.GET(hmnurl.RegexAtomFeed, AtomFeed)
@ -200,19 +87,19 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.GET(hmnurl.RegexSnippet, Snippet)
hmnOnly.GET(hmnurl.RegexProjectIndex, ProjectIndex)
hmnOnly.GET(hmnurl.RegexProjectNew, authMiddleware(ProjectNew))
hmnOnly.POST(hmnurl.RegexProjectNew, authMiddleware(csrfMiddleware(ProjectNewSubmit)))
hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew))
hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit)))
hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback))
hmnOnly.POST(hmnurl.RegexDiscordUnlink, authMiddleware(csrfMiddleware(DiscordUnlink)))
hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, authMiddleware(csrfMiddleware(DiscordShowcaseBacklog)))
hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback))
hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink)))
hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog)))
hmnOnly.POST(hmnurl.RegexTwitchEventSubCallback, TwitchEventSubCallback)
hmnOnly.GET(hmnurl.RegexTwitchDebugPage, TwitchDebugPage)
hmnOnly.GET(hmnurl.RegexUserProfile, UserProfile)
hmnOnly.GET(hmnurl.RegexUserSettings, authMiddleware(UserSettings))
hmnOnly.POST(hmnurl.RegexUserSettings, authMiddleware(csrfMiddleware(UserSettingsSave)))
hmnOnly.GET(hmnurl.RegexUserSettings, needsAuth(UserSettings))
hmnOnly.POST(hmnurl.RegexUserSettings, needsAuth(csrfMiddleware(UserSettingsSave)))
hmnOnly.GET(hmnurl.RegexPodcast, PodcastIndex)
hmnOnly.GET(hmnurl.RegexPodcastEdit, PodcastEdit)
@ -231,6 +118,9 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
hmnOnly.GET(hmnurl.RegexLibraryAny, LibraryNotPortedYet)
// Project routes can appear either at the root (e.g. hero.handmade.network/edit)
// or on a personal project path (e.g. handmade.network/p/123/hero/edit). So, we
// have pulled all those routes into this function.
attachProjectRoutes := func(rb *RouteBuilder) {
rb.GET(hmnurl.RegexHomepage, func(c *RequestContext) ResponseData {
if c.CurrentProject.IsHMN() {
@ -240,8 +130,8 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
}
})
rb.GET(hmnurl.RegexProjectEdit, authMiddleware(ProjectEdit))
rb.POST(hmnurl.RegexProjectEdit, authMiddleware(csrfMiddleware(ProjectEditSubmit)))
rb.GET(hmnurl.RegexProjectEdit, needsAuth(ProjectEdit))
rb.POST(hmnurl.RegexProjectEdit, needsAuth(csrfMiddleware(ProjectEditSubmit)))
// Middleware used for forum action routes - anything related to actually creating or editing forum content
needsForums := func(h Handler) Handler {
@ -251,14 +141,14 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return FourOhFour(c)
}
// Require auth if forums are enabled
return authMiddleware(h)(c)
return needsAuth(h)(c)
}
}
rb.POST(hmnurl.RegexForumNewThreadSubmit, needsForums(csrfMiddleware(ForumNewThreadSubmit)))
rb.GET(hmnurl.RegexForumNewThread, needsForums(ForumNewThread))
rb.GET(hmnurl.RegexForumThread, ForumThread)
rb.GET(hmnurl.RegexForum, Forum)
rb.POST(hmnurl.RegexForumMarkRead, authMiddleware(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled
rb.POST(hmnurl.RegexForumMarkRead, needsAuth(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled
rb.GET(hmnurl.RegexForumPost, ForumPostRedirect)
rb.GET(hmnurl.RegexForumPostReply, needsForums(ForumPostReply))
rb.POST(hmnurl.RegexForumPostReply, needsForums(csrfMiddleware(ForumPostReplySubmit)))
@ -276,7 +166,7 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return FourOhFour(c)
}
// Require auth if blogs are enabled
return authMiddleware(h)(c)
return needsAuth(h)(c)
}
}
rb.GET(hmnurl.RegexBlog, BlogIndex)
@ -296,64 +186,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
), http.StatusMovedPermanently)
})
}
hmnOnly.Group(hmnurl.RegexPersonalProject, func(rb *RouteBuilder) {
// TODO(ben): Perhaps someday we can make this middleware modification feel better? It seems
// pretty common to run the outermost middleware first before doing other stuff, but having
// to nest functions this way feels real bad.
rb.Middleware = func(h Handler) Handler {
return hmnOnly.Middleware(func(c *RequestContext) ResponseData {
// At this point we are definitely on the plain old HMN subdomain.
// Fetch personal project and do whatever
id, err := strconv.Atoi(c.PathParams["projectid"])
if err != nil {
panic(oops.New(err, "project id was not numeric (bad regex in routing)"))
}
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)
} else {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
}
}
c.CurrentProject = &p.Project
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
if !p.Project.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(p.Project.Name) {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
return h(c)
})
}
attachProjectRoutes(rb)
})
anyProject.Group(regexp.MustCompile("^"), func(rb *RouteBuilder) {
rb.Middleware = func(h Handler) Handler {
return anyProject.Middleware(func(c *RequestContext) ResponseData {
// We could be on any project's subdomain.
// Check if the current project (matched by subdomain) is actually no longer official
// and therefore needs to be redirected to the personal project version of the route.
if c.CurrentProject.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
return h(c)
})
}
attachProjectRoutes(rb)
})
officialProjectRoutes := anyProject.WithMiddleware(officialProjectMiddleware)
personalProjectRoutes := hmnOnly.Group(hmnurl.RegexPersonalProject, personalProjectMiddleware)
attachProjectRoutes(&officialProjectRoutes)
attachProjectRoutes(&personalProjectRoutes)
anyProject.POST(hmnurl.RegexAssetUpload, AssetUpload)
@ -375,318 +211,69 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
return router
}
func ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color")
if color == "" {
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
}
baseData := getBaseData(c, "", nil)
bgColor := noire.NewHex(color)
h, s, l := bgColor.HSL()
if baseData.Theme == "dark" {
l = 15
} else {
l = 95
}
if s > 20 {
s = 20
}
bgColor = noire.NewHSL(h, s, l)
templateData := struct {
templates.BaseData
Color string
PostBgColor string
}{
BaseData: baseData,
Color: color,
PostBgColor: bgColor.HTML(),
}
var res ResponseData
res.Header().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
return res
}
func FourOhFour(c *RequestContext) ResponseData {
var res ResponseData
res.StatusCode = http.StatusNotFound
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
templateData := struct {
templates.BaseData
Wanted string
}{
BaseData: getBaseData(c, "Page not found", nil),
Wanted: c.FullUrl(),
func setDBConn(conn *pgxpool.Pool) Middleware {
return func(h Handler) Handler {
return func(c *RequestContext) ResponseData {
c.Conn = conn
return h(c)
}
res.MustWriteTemplate("404.html", templateData, c.Perf)
} else {
res.Write([]byte("Not Found"))
}
return res
}
type RejectData struct {
templates.BaseData
RejectReason string
}
func RejectRequest(c *RequestContext, reason string) ResponseData {
var res ResponseData
err := res.WriteTemplate("reject.html", RejectData{
BaseData: getBaseData(c, "Rejected", nil),
RejectReason: reason,
}, c.Perf)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
}
return res
}
func LoadCommonWebsiteData(c *RequestContext) (bool, ResponseData) {
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
defer c.Perf.EndBlock()
// get user
{
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
if err != nil {
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
}
c.CurrentUser = user
c.CurrentSession = session
func redirectToHMN(h Handler) Handler {
return func(c *RequestContext) ResponseData {
if !c.CurrentProject.IsHMN() {
return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
}
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
return h(c)
}
}
// get official project
{
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost())
slug := strings.TrimRight(hostPrefix, ".")
var owners []*models.User
func officialProjectMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
// Check if the current project (matched by subdomain) is actually no longer official
// and therefore needs to be redirected to the personal project version of the route.
if c.CurrentProject.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
if len(slug) > 0 {
dbProject, err := hmndata.FetchProjectBySlug(c.Context(), c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err == nil {
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
owners = dbProject.Owners
return h(c)
}
}
func personalProjectMiddleware(h Handler) Handler {
return func(c *RequestContext) ResponseData {
hmnProject := c.CurrentProject
id := utils.Must1(strconv.Atoi(c.PathParams["projectid"]))
p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)
} else {
if errors.Is(err, db.NotFound) {
// do nothing, this is fine
} else {
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
}
}
if c.CurrentProject == nil {
dbProject, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
})
if err != nil {
panic(oops.New(err, "failed to fetch HMN project"))
}
c.CurrentProject = &dbProject.Project
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
}
if c.CurrentProject == nil {
panic("failed to load project data")
}
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
c.CurrentProject = &p.Project
c.CurrentProject.Color1 = hmnProject.Color1
c.CurrentProject.Color2 = hmnProject.Color2
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
}
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
c.Theme = "light"
if c.CurrentUser != nil && c.CurrentUser.DarkTheme {
c.Theme = "dark"
}
return true, ResponseData{}
}
func AddCORSHeaders(c *RequestContext, res *ResponseData) {
parsed, err := url.Parse(config.Config.BaseUrl)
if err != nil {
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
return
}
origin := ""
origins, found := c.Req.Header["Origin"]
if found {
origin = origins[0]
}
if strings.HasSuffix(origin, parsed.Host) {
res.Header().Add("Access-Control-Allow-Origin", origin)
res.Header().Add("Access-Control-Allow-Credentials", "true")
res.Header().Add("Vary", "Origin")
}
}
// Given a session id, fetches user data from the database. Will return nil if
// the user cannot be found, and will only return an error if it's serious.
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
session, err := auth.GetSession(c.Context(), c.Conn, sessionId)
if err != nil {
if errors.Is(err, auth.ErrNoSession) {
return nil, nil, nil
} else {
return nil, nil, oops.New(err, "failed to get current session")
}
}
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, nil, session.Username, hmndata.UsersQuery{
AnyStatus: true,
})
if err != nil {
if errors.Is(err, db.NotFound) {
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
return nil, nil, nil // user was deleted or something
} else {
return nil, nil, oops.New(err, "failed to get user for session")
}
}
return user, session, nil
}
func TrackRequestPerf(c *RequestContext) (after func()) {
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
c.ctx = context.WithValue(c.Context(), perf.PerfContextKey, c.Perf)
return func() {
c.Perf.EndRequest()
log := logging.Info()
blockStack := make([]time.Time, 0)
for i, block := range c.Perf.Blocks {
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
blockStack = blockStack[:len(blockStack)-1]
}
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
blockStack = append(blockStack, block.End)
}
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
}
}
func LogContextErrors(c *RequestContext, errs ...error) {
for _, err := range errs {
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
}
}
func LogContextErrorsFromResponse(c *RequestContext, res *ResponseData) {
LogContextErrors(c, res.Errors...)
}
func MiddlewarePanicCatcher(c *RequestContext, res *ResponseData) {
if recovered := recover(); recovered != nil {
maybeError, ok := recovered.(*error)
var err error
if ok {
err = *maybeError
} else {
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
}
*res = c.ErrorResponse(http.StatusInternalServerError, err)
}
}
const NoticesCookieName = "hmn_notices"
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
cookie, err := c.Req.Cookie(NoticesCookieName)
if err != nil {
if !errors.Is(err, http.ErrNoCookie) {
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
}
return nil
}
return deserializeNoticesFromCookie(cookie.Value)
}
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
serialized := serializeNoticesForCookie(c, res.FutureNotices)
if serialized != "" {
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Value: serialized,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
Expires: time.Now().Add(time.Minute * 5),
Secure: config.Config.Auth.CookieSecure,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
res.SetCookie(&noticesCookie)
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
// NOTE(asaf): Don't clear on redirect
noticesCookie := http.Cookie{
Name: NoticesCookieName,
Path: "/",
Domain: config.Config.Auth.CookieDomain,
MaxAge: -1,
}
res.SetCookie(&noticesCookie)
}
}
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
var builder strings.Builder
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
size := 0
for i, notice := range notices {
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
if i != 0 {
sizeIncrease += 1
}
if size+sizeIncrease > maxSize {
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
break
if !c.CurrentProject.Personal {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
if i != 0 {
builder.WriteString("\t")
if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(c.CurrentProject.Name) {
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
}
builder.WriteString(notice.Class)
builder.WriteString("|")
builder.WriteString(string(notice.Content))
size += sizeIncrease
return h(c)
}
return builder.String()
}
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
var result []templates.Notice
notices := strings.Split(cookieVal, "\t")
for _, notice := range notices {
parts := strings.SplitN(notice, "|", 2)
if len(parts) == 2 {
result = append(result, templates.Notice{
Class: parts[0],
Content: template.HTML(parts[1]),
})
}
}
return result
}

View File

@ -31,7 +31,7 @@ func TestLogContextErrors(t *testing.T) {
Middleware: func(h Handler) Handler {
return func(c *RequestContext) (res ResponseData) {
c.Logger = &logger
defer LogContextErrorsFromResponse(c, &res)
defer logContextErrorsMiddleware(c, &res)
return h(c)
}
},

View File

@ -16,7 +16,7 @@ type ShowcaseData struct {
}
func Showcase(c *RequestContext) ResponseData {
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{})
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{})
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippets"))
}

View File

@ -30,7 +30,7 @@ func Snippet(c *RequestContext) ResponseData {
return FourOhFour(c)
}
s, err := hmndata.FetchSnippet(c.Context(), c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{})
s, err := hmndata.FetchSnippet(c, c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)

View File

@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData {
}
func TwitchDebugPage(c *RequestContext) ResponseData {
streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn,
streams, err := db.Query[models.TwitchStream](c, c.Conn,
`
SELECT $columns
FROM

View File

@ -52,7 +52,7 @@ func UserProfile(c *RequestContext) ResponseData {
if c.CurrentUser != nil && strings.ToLower(c.CurrentUser.Username) == username {
profileUser = c.CurrentUser
} else {
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, c.CurrentUser, username, hmndata.UsersQuery{})
user, err := hmndata.FetchUserByUsername(c, c.Conn, c.CurrentUser, username, hmndata.UsersQuery{})
if err != nil {
if errors.Is(err, db.NotFound) {
return FourOhFour(c)
@ -72,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Fetch user links")
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
userLinks, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM
@ -92,7 +92,7 @@ func UserProfile(c *RequestContext) ResponseData {
}
c.Perf.EndBlock()
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
OwnerIDs: []int{profileUser.ID},
Lifecycles: models.AllProjectLifecycles,
IncludeHidden: true,
@ -111,13 +111,13 @@ func UserProfile(c *RequestContext) ResponseData {
c.Perf.EndBlock()
c.Perf.StartBlock("SQL", "Fetch posts")
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
UserIDs: []int{profileUser.ID},
SortDescending: true,
})
c.Perf.EndBlock()
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
OwnerIDs: []int{profileUser.ID},
})
if err != nil {
@ -125,7 +125,7 @@ func UserProfile(c *RequestContext) ResponseData {
}
c.Perf.StartBlock("SQL", "Fetch subforum tree")
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
subforumTree := models.GetFullSubforumTree(c, c.Conn)
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
c.Perf.EndBlock()
@ -213,7 +213,7 @@ func UserSettings(c *RequestContext) ResponseData {
DiscordShowcaseBacklogUrl string
}
links, err := db.Query[models.Link](c.Context(), c.Conn,
links, err := db.Query[models.Link](c, c.Conn,
`
SELECT $columns
FROM link
@ -230,7 +230,7 @@ func UserSettings(c *RequestContext) ResponseData {
var tduser *templates.DiscordUser
var numUnsavedMessages int
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
`
SELECT $columns
FROM discord_user
@ -246,7 +246,7 @@ func UserSettings(c *RequestContext) ResponseData {
tmp := templates.DiscordUserToTemplate(duser)
tduser = &tmp
numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn,
numUnsavedMessages, err = db.QueryOneScalar[int](c, c.Conn,
`
SELECT COUNT(*)
FROM
@ -299,11 +299,11 @@ func UserSettingsSave(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse form"))
}
tx, err := c.Conn.Begin(c.Context())
tx, err := c.Conn.Begin(c)
if err != nil {
panic(err)
}
defer tx.Rollback(c.Context())
defer tx.Rollback(c)
form, err := c.GetFormValues()
if err != nil {
@ -315,7 +315,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
email := form.Get("email")
if !hmnemail.IsEmail(email) {
return RejectRequest(c, "Your email was not valid.")
return c.RejectRequest("Your email was not valid.")
}
showEmail := form.Get("showemail") != ""
@ -328,7 +328,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
discordShowcaseAuto := form.Get("discord-showcase-auto") != ""
discordDeleteSnippetOnMessageDelete := form.Get("discord-snippet-keep") == ""
_, err = tx.Exec(c.Context(),
_, err = tx.Exec(c,
`
UPDATE hmn_user
SET
@ -360,15 +360,15 @@ func UserSettingsSave(c *RequestContext) ResponseData {
}
// Process links
twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil)
twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
linksText := form.Get("links")
links := ParseLinks(linksText)
_, err = tx.Exec(c.Context(), `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID)
_, err = tx.Exec(c, `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID)
if err != nil {
c.Logger.Warn().Err(err).Msg("failed to delete old links")
} else {
for i, link := range links {
_, err := tx.Exec(c.Context(),
_, err := tx.Exec(c,
`
INSERT INTO link (name, url, ordering, user_id)
VALUES ($1, $2, $3, $4)
@ -384,7 +384,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
}
}
}
twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil)
twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
if preErr == nil && postErr == nil {
twitch.UserOrProjectLinksUpdated(twitchLoginsPreChange, twitchLoginsPostChange)
}
@ -407,7 +407,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
}
var avatarUUID *uuid.UUID
if newAvatar.Exists {
avatarAsset, err := assets.Create(c.Context(), tx, assets.CreateInput{
avatarAsset, err := assets.Create(c, tx, assets.CreateInput{
Content: newAvatar.Content,
Filename: newAvatar.Filename,
ContentType: newAvatar.Mime,
@ -421,7 +421,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
avatarUUID = &avatarAsset.ID
}
if newAvatar.Exists || newAvatar.Remove {
_, err := tx.Exec(c.Context(),
_, err := tx.Exec(c,
`
UPDATE hmn_user
SET
@ -437,7 +437,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
}
}
err = tx.Commit(c.Context())
err = tx.Commit(c)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save user settings"))
}
@ -454,7 +454,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
return RejectRequest(c, "No user id provided")
return c.RejectRequest("No user id provided")
}
status := c.Req.Form.Get("status")
@ -469,10 +469,10 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
case "banned":
desiredStatus = models.UserStatusBanned
default:
return RejectRequest(c, "No legal user status provided")
return c.RejectRequest("No legal user status provided")
}
_, err = c.Conn.Exec(c.Context(),
_, err = c.Conn.Exec(c,
`
UPDATE hmn_user
SET status = $1
@ -485,7 +485,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
}
if desiredStatus == models.UserStatusBanned {
err = auth.DeleteSessionForUser(c.Context(), c.Conn, c.Req.Form.Get("username"))
err = auth.DeleteSessionForUser(c, c.Conn, c.Req.Form.Get("username"))
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
}
@ -500,10 +500,10 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
userIdStr := c.Req.Form.Get("user_id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
return RejectRequest(c, "No user id provided")
return c.RejectRequest("No user id provided")
}
err = deleteAllPostsForUser(c.Context(), c.Conn, userId)
err = deleteAllPostsForUser(c, c.Conn, userId)
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete user posts"))
}
@ -514,7 +514,7 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *ResponseData {
if new != confirm {
res := RejectRequest(c, "Your password and password confirmation did not match.")
res := c.RejectRequest("Your password and password confirmation did not match.")
return &res
}
@ -531,12 +531,12 @@ func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *Res
}
if !ok {
res := RejectRequest(c, "The old password you provided was not correct.")
res := c.RejectRequest("The old password you provided was not correct.")
return &res
}
newHashedPassword := auth.HashPassword(new)
err = auth.UpdatePassword(c.Context(), tx, c.CurrentUser.Username, newHashedPassword)
err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword)
if err != nil {
res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password"))
return &res

View File

@ -33,14 +33,13 @@ var WebsiteCommand = &cobra.Command{
logging.Info().Msg("Hello, HMN!")
backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background())
longRequestContext, cancelLongRequests := context.WithCancel(context.Background())
conn := db.NewConnPool()
perfCollector := perf.RunPerfCollector(backgroundJobContext)
server := http.Server{
Addr: config.Config.Addr,
Handler: NewWebsiteRoutes(longRequestContext, conn),
Handler: NewWebsiteRoutes(conn),
}
backgroundJobsDone := jobs.Zip(
@ -59,8 +58,6 @@ var WebsiteCommand = &cobra.Command{
<-signals
logging.Info().Msg("Shutting down the website")
go func() {
logging.Info().Msg("cancelling long requests")
cancelLongRequests()
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
logging.Info().Msg("shutting down web server")