Rework requests and middleware (#57)
o boy Resolves #10 (hopefully!) Co-authored-by: Ben Visness <bvisness@gmail.com> Reviewed-on: hmn/hmn#57
This commit is contained in:
parent
32db9b1843
commit
e9d4300100
|
@ -84,6 +84,8 @@ func UrlWithFragment(path string, query []Q, fragment string) string {
|
|||
return HMNProjectContext.UrlWithFragment(path, query, fragment)
|
||||
}
|
||||
|
||||
// Takes a project URL and rewrites it using the current URL context. This can be used
|
||||
// to convert a personal project URL to official and vice versa.
|
||||
func (c *UrlContext) RewriteProjectUrl(u *url.URL) string {
|
||||
// we need to strip anything matching the personal project regex to get the base path
|
||||
match := RegexPersonalProject.FindString(u.Path)
|
||||
|
|
|
@ -69,7 +69,7 @@ func AdminAtomFeed(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
@ -134,7 +134,7 @@ type unapprovedUserData struct {
|
|||
|
||||
func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
|||
userIds = append(userIds, u.User.ID)
|
||||
}
|
||||
|
||||
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
userLinks, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -253,13 +253,13 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
userIdStr := c.Req.Form.Get("user_id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
return RejectRequest(c, "User id can't be parsed")
|
||||
return c.RejectRequest("User id can't be parsed")
|
||||
}
|
||||
|
||||
user, err := hmndata.FetchUser(c.Context(), c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{})
|
||||
user, err := hmndata.FetchUser(c, c.Conn, c.CurrentUser, userId, hmndata.UsersQuery{})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
return RejectRequest(c, "User not found")
|
||||
return c.RejectRequest("User not found")
|
||||
} else {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
||||
}
|
||||
|
@ -267,7 +267,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
whatHappened := ""
|
||||
if action == ApprovalQueueActionApprove {
|
||||
_, err := c.Conn.Exec(c.Context(),
|
||||
_, err := c.Conn.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -281,7 +281,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
whatHappened = fmt.Sprintf("%s approved successfully", user.Username)
|
||||
} else if action == ApprovalQueueActionSpammer {
|
||||
_, err := c.Conn.Exec(c.Context(),
|
||||
_, err := c.Conn.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -293,15 +293,15 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to set user to banned"))
|
||||
}
|
||||
err = auth.DeleteSessionForUser(c.Context(), c.Conn, user.Username)
|
||||
err = auth.DeleteSessionForUser(c, c.Conn, user.Username)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
|
||||
}
|
||||
err = deleteAllPostsForUser(c.Context(), c.Conn, user.ID)
|
||||
err = deleteAllPostsForUser(c, c.Conn, user.ID)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's posts"))
|
||||
}
|
||||
err = deleteAllProjectsForUser(c.Context(), c.Conn, user.ID)
|
||||
err = deleteAllProjectsForUser(c, c.Conn, user.ID)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete spammer's projects"))
|
||||
}
|
||||
|
@ -324,7 +324,7 @@ type UnapprovedPost struct {
|
|||
}
|
||||
|
||||
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||
posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
|
||||
posts, err := db.Query[UnapprovedPost](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -355,7 +355,7 @@ type UnapprovedProject struct {
|
|||
}
|
||||
|
||||
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||
ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn,
|
||||
ownerIDs, err := db.QueryScalar[int](c, c.Conn,
|
||||
`
|
||||
SELECT id
|
||||
FROM
|
||||
|
@ -369,7 +369,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
|||
return nil, oops.New(err, "failed to fetch unapproved users")
|
||||
}
|
||||
|
||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
OwnerIDs: ownerIDs,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
|
@ -382,7 +382,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
|||
projectIDs = append(projectIDs, p.Project.ID)
|
||||
}
|
||||
|
||||
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
projectLinks, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
|
|
@ -19,7 +19,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
|||
requestedUsername := usernameArgs[0]
|
||||
found = true
|
||||
c.Perf.StartBlock("SQL", "Fetch user")
|
||||
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||
user, err := db.QueryOne[models.User](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -45,7 +45,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
|||
|
||||
var res ResponseData
|
||||
res.Header().Set("Content-Type", "application/json")
|
||||
AddCORSHeaders(c, &res)
|
||||
addCORSHeaders(c, &res)
|
||||
if found {
|
||||
res.Write([]byte(fmt.Sprintf(`{ "found": true, "canonical": "%s" }`, canonicalUsername)))
|
||||
} else {
|
||||
|
|
|
@ -85,7 +85,7 @@ func AssetUpload(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
asset, err := assets.Create(c.Context(), c.Conn, assets.CreateInput{
|
||||
asset, err := assets.Create(c, c.Conn, assets.CreateInput{
|
||||
Content: data,
|
||||
Filename: originalFilename,
|
||||
ContentType: mimeType,
|
||||
|
|
|
@ -28,7 +28,7 @@ type LoginPageData struct {
|
|||
|
||||
func LoginPage(c *RequestContext) ResponseData {
|
||||
if c.CurrentUser != nil {
|
||||
return RejectRequest(c, "You are already logged in.")
|
||||
return c.RejectRequest("You are already logged in.")
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
|
@ -75,7 +75,7 @@ func Login(c *RequestContext) ResponseData {
|
|||
return res
|
||||
}
|
||||
|
||||
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||
user, err := db.QueryOne[models.User](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM hmn_user
|
||||
|
@ -102,7 +102,7 @@ func Login(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if user.Status == models.UserStatusInactive {
|
||||
return RejectRequest(c, "You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.")
|
||||
return c.RejectRequest("You must validate your email address before logging in. You should've received an email shortly after registration. If you did not receive the email, please contact the staff.")
|
||||
}
|
||||
|
||||
res := c.Redirect(redirect, http.StatusSeeOther)
|
||||
|
@ -136,7 +136,7 @@ func RegisterNewUser(c *RequestContext) ResponseData {
|
|||
|
||||
func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
||||
if c.CurrentUser != nil {
|
||||
return RejectRequest(c, "Can't register new user. You are already logged in")
|
||||
return c.RejectRequest("Can't register new user. You are already logged in")
|
||||
}
|
||||
c.Req.ParseForm()
|
||||
|
||||
|
@ -146,16 +146,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
password := c.Req.Form.Get("password")
|
||||
password2 := c.Req.Form.Get("password2")
|
||||
if !UsernameRegex.Match([]byte(username)) {
|
||||
return RejectRequest(c, "Invalid username")
|
||||
return c.RejectRequest("Invalid username")
|
||||
}
|
||||
if !email.IsEmail(emailAddress) {
|
||||
return RejectRequest(c, "Invalid email address")
|
||||
return c.RejectRequest("Invalid email address")
|
||||
}
|
||||
if len(password) < 8 {
|
||||
return RejectRequest(c, "Password too short")
|
||||
return c.RejectRequest("Password too short")
|
||||
}
|
||||
if password != password2 {
|
||||
return RejectRequest(c, "Password confirmation doesn't match password")
|
||||
return c.RejectRequest("Password confirmation doesn't match password")
|
||||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Check blacklist")
|
||||
|
@ -169,7 +169,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
|
||||
userAlreadyExists := true
|
||||
_, err := db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||
_, err := db.QueryOneScalar[int](c, c.Conn,
|
||||
`
|
||||
SELECT id
|
||||
FROM hmn_user
|
||||
|
@ -186,11 +186,11 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if userAlreadyExists {
|
||||
return RejectRequest(c, fmt.Sprintf("Username (%s) already exists.", username))
|
||||
return c.RejectRequest(fmt.Sprintf("Username (%s) already exists.", username))
|
||||
}
|
||||
|
||||
emailAlreadyExists := true
|
||||
_, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||
_, err = db.QueryOneScalar[int](c, c.Conn,
|
||||
`
|
||||
SELECT id
|
||||
FROM hmn_user
|
||||
|
@ -215,16 +215,16 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
hashed := auth.HashPassword(password)
|
||||
|
||||
c.Perf.StartBlock("SQL", "Create user and one time token")
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
var newUserId int
|
||||
err = tx.QueryRow(c.Context(),
|
||||
err = tx.QueryRow(c,
|
||||
`
|
||||
INSERT INTO hmn_user (username, email, password, date_joined, name, registration_ip)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
|
@ -237,7 +237,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
ott := models.GenerateToken()
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
|
||||
VALUES($1, $2, $3, $4, $5)
|
||||
|
@ -263,7 +263,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Commit user")
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit user to the db"))
|
||||
}
|
||||
|
@ -302,7 +302,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
|
|||
|
||||
username, hasUsername := c.PathParams["username"]
|
||||
if !hasUsername {
|
||||
return RejectRequest(c, "Bad validation url")
|
||||
return c.RejectRequest("Bad validation url")
|
||||
}
|
||||
|
||||
token := ""
|
||||
|
@ -319,7 +319,7 @@ func EmailConfirmation(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if !hasToken {
|
||||
return RejectRequest(c, "Bad validation url")
|
||||
return c.RejectRequest("Bad validation url")
|
||||
}
|
||||
|
||||
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypeRegistration)
|
||||
|
@ -366,13 +366,13 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Updating user status and deleting token")
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -385,7 +385,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
|
||||
}
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM one_time_token WHERE id = $1
|
||||
`,
|
||||
|
@ -395,7 +395,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete one time token"))
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit transaction"))
|
||||
}
|
||||
|
@ -413,7 +413,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData {
|
|||
// NOTE(asaf): Only call this when validationResult.Match is false.
|
||||
func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, validationResult validateUserAndTokenResult) ResponseData {
|
||||
if validationResult.User == nil {
|
||||
return RejectRequest(c, "You haven't validated your email in time and your user was deleted. You may try registering again with the same username.")
|
||||
return c.RejectRequest("You haven't validated your email in time and your user was deleted. You may try registering again with the same username.")
|
||||
}
|
||||
|
||||
if validationResult.OneTimeToken == nil {
|
||||
|
@ -422,7 +422,7 @@ func makeResponseForBadRegistrationTokenValidationResult(c *RequestContext, vali
|
|||
return c.Redirect(hmnurl.BuildLoginPage(""), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
return RejectRequest(c, "Bad token. If you are having problems registering or logging in, please contact the staff.")
|
||||
return c.RejectRequest("Bad token. If you are having problems registering or logging in, please contact the staff.")
|
||||
}
|
||||
|
||||
// NOTE(asaf): PasswordReset refers specifically to "forgot your password" flow over email,
|
||||
|
@ -446,14 +446,14 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
emailAddress := strings.TrimSpace(c.Req.Form.Get("email"))
|
||||
|
||||
if username == "" && emailAddress == "" {
|
||||
return RejectRequest(c, "You must provide a username and an email address.")
|
||||
return c.RejectRequest("You must provide a username and an email address.")
|
||||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetching user")
|
||||
type userQuery struct {
|
||||
User models.User `db:"hmn_user"`
|
||||
}
|
||||
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||
user, err := db.QueryOne[models.User](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM hmn_user
|
||||
|
@ -473,7 +473,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
if user != nil {
|
||||
c.Perf.StartBlock("SQL", "Fetching existing token")
|
||||
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||
resetToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM one_time_token
|
||||
|
@ -495,7 +495,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
if resetToken != nil {
|
||||
if resetToken.Expires.Before(now.Add(time.Minute * 30)) { // NOTE(asaf): Expired or about to expire
|
||||
c.Perf.StartBlock("SQL", "Deleting expired token")
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
_, err = c.Conn.Exec(c,
|
||||
`
|
||||
DELETE FROM one_time_token
|
||||
WHERE id = $1
|
||||
|
@ -512,7 +512,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
if resetToken == nil {
|
||||
c.Perf.StartBlock("SQL", "Creating new token")
|
||||
newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||
newToken, err := db.QueryOne[models.OneTimeToken](c, c.Conn,
|
||||
`
|
||||
INSERT INTO one_time_token (token_type, created, expires, token_content, owner_id)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
|
@ -567,12 +567,12 @@ func DoPasswordReset(c *RequestContext) ResponseData {
|
|||
token, hasToken := c.PathParams["token"]
|
||||
|
||||
if !hasToken || !hasUsername {
|
||||
return RejectRequest(c, "Bad validation url.")
|
||||
return c.RejectRequest("Bad validation url.")
|
||||
}
|
||||
|
||||
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
|
||||
if !validationResult.Match {
|
||||
return RejectRequest(c, "Bad validation url.")
|
||||
return c.RejectRequest("Bad validation url.")
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
|
@ -601,30 +601,30 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
validationResult := validateUsernameAndToken(c, username, token, models.TokenTypePasswordReset)
|
||||
if !validationResult.Match {
|
||||
return RejectRequest(c, "Bad validation url.")
|
||||
return c.RejectRequest("Bad validation url.")
|
||||
}
|
||||
|
||||
if c.CurrentUser != nil && c.CurrentUser.ID != validationResult.User.ID {
|
||||
return RejectRequest(c, fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username))
|
||||
return c.RejectRequest(fmt.Sprintf("Can't change password for %s. You are logged in as %s.", username, c.CurrentUser.Username))
|
||||
}
|
||||
|
||||
if len(password) < 8 {
|
||||
return RejectRequest(c, "Password too short")
|
||||
return c.RejectRequest("Password too short")
|
||||
}
|
||||
if password != password2 {
|
||||
return RejectRequest(c, "Password confirmation doesn't match password")
|
||||
return c.RejectRequest("Password confirmation doesn't match password")
|
||||
}
|
||||
|
||||
hashed := auth.HashPassword(password)
|
||||
|
||||
c.Perf.StartBlock("SQL", "Update user's password and delete reset token")
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
tag, err := tx.Exec(c.Context(),
|
||||
tag, err := tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET password = $1
|
||||
|
@ -638,7 +638,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if validationResult.User.Status == models.UserStatusInactive {
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -652,7 +652,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM one_time_token
|
||||
WHERE id = $1
|
||||
|
@ -663,7 +663,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete onetimetoken"))
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit password reset to the db"))
|
||||
}
|
||||
|
@ -698,7 +698,7 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
|
|||
// re-hash and save the user's password if necessary
|
||||
if hashed.IsOutdated() {
|
||||
newHashed := auth.HashPassword(password)
|
||||
err := auth.UpdatePassword(c.Context(), c.Conn, user.Username, newHashed)
|
||||
err := auth.UpdatePassword(c, c.Conn, user.Username, newHashed)
|
||||
if err != nil {
|
||||
c.Logger.Error().Err(err).Msg("failed to update user's password")
|
||||
}
|
||||
|
@ -711,15 +711,15 @@ func tryLogin(c *RequestContext, user *models.User, password string) (bool, erro
|
|||
func loginUser(c *RequestContext, user *models.User, responseData *ResponseData) error {
|
||||
c.Perf.StartBlock("SQL", "Setting last login and creating session")
|
||||
defer c.Perf.EndBlock()
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return oops.New(err, "failed to start db transaction")
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET last_login = $1
|
||||
|
@ -732,12 +732,12 @@ func loginUser(c *RequestContext, user *models.User, responseData *ResponseData)
|
|||
return oops.New(err, "failed to update last_login for user")
|
||||
}
|
||||
|
||||
session, err := auth.CreateSession(c.Context(), c.Conn, user.Username)
|
||||
session, err := auth.CreateSession(c, c.Conn, user.Username)
|
||||
if err != nil {
|
||||
return oops.New(err, "failed to create session")
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return oops.New(err, "failed to commit transaction")
|
||||
}
|
||||
|
@ -749,7 +749,7 @@ func logoutUser(c *RequestContext, res *ResponseData) {
|
|||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||
if err == nil {
|
||||
// clear the session from the db immediately, no expiration
|
||||
err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value)
|
||||
err := auth.DeleteSession(c, c.Conn, sessionCookie.Value)
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Msg("failed to delete session on logout")
|
||||
}
|
||||
|
@ -772,7 +772,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
|||
User models.User `db:"hmn_user"`
|
||||
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
||||
}
|
||||
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
|
||||
data, err := db.QueryOne[userAndTokenQuery](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM hmn_user
|
||||
|
|
|
@ -37,7 +37,7 @@ func BlogIndex(c *RequestContext) ResponseData {
|
|||
|
||||
const postsPerPage = 20
|
||||
|
||||
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -51,7 +51,7 @@ func BlogIndex(c *RequestContext) ResponseData {
|
|||
return c.Redirect(c.UrlContext.BuildBlog(page), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
Limit: postsPerPage,
|
||||
|
@ -78,7 +78,7 @@ func BlogIndex(c *RequestContext) ResponseData {
|
|||
canCreate := false
|
||||
if c.CurrentProject.HasBlog() && c.CurrentUser != nil {
|
||||
isProjectOwner := false
|
||||
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID)
|
||||
owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project owners"))
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func BlogThread(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
thread, posts, err := hmndata.FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{
|
||||
thread, posts, err := hmndata.FetchThreadPosts(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -155,7 +155,7 @@ func BlogThread(c *RequestContext) ResponseData {
|
|||
// Update thread last read info
|
||||
if c.CurrentUser != nil {
|
||||
c.Perf.StartBlock("SQL", "Update TLRI")
|
||||
_, err := c.Conn.Exec(c.Context(),
|
||||
_, err := c.Conn.Exec(c,
|
||||
`
|
||||
INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
|
||||
VALUES ($1, $2, $3)
|
||||
|
@ -196,7 +196,7 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
|
||||
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -227,11 +227,11 @@ func BlogNewThread(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func BlogNewThreadSubmit(c *RequestContext) ResponseData {
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
err = c.Req.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -240,15 +240,15 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
|
|||
title := c.Req.Form.Get("title")
|
||||
unparsed := c.Req.Form.Get("body")
|
||||
if title == "" {
|
||||
return RejectRequest(c, "You must provide a title for your post.")
|
||||
return c.RejectRequest("You must provide a title for your post.")
|
||||
}
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "You must provide a body for your post.")
|
||||
return c.RejectRequest("You must provide a body for your post.")
|
||||
}
|
||||
|
||||
// Create thread
|
||||
var threadId int
|
||||
err = tx.QueryRow(c.Context(),
|
||||
err = tx.QueryRow(c,
|
||||
`
|
||||
INSERT INTO thread (title, type, project_id, first_id, last_id)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
|
@ -265,9 +265,9 @@ func BlogNewThreadSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// Create everything else
|
||||
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
|
||||
hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new blog post"))
|
||||
}
|
||||
|
@ -282,11 +282,11 @@ func BlogPostEdit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -326,17 +326,17 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -351,16 +351,16 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
|
|||
unparsed := c.Req.Form.Get("body")
|
||||
editReason := c.Req.Form.Get("editreason")
|
||||
if title != "" && post.Thread.FirstID != post.Post.ID {
|
||||
return RejectRequest(c, "You can only edit the title by editing the first post.")
|
||||
return c.RejectRequest("You can only edit the title by editing the first post.")
|
||||
}
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "You must provide a post body.")
|
||||
return c.RejectRequest("You must provide a post body.")
|
||||
}
|
||||
|
||||
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
|
||||
hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
|
||||
|
||||
if title != "" {
|
||||
_, err := tx.Exec(c.Context(),
|
||||
_, err := tx.Exec(c,
|
||||
`
|
||||
UPDATE thread SET title = $1 WHERE id = $2
|
||||
`,
|
||||
|
@ -372,7 +372,7 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post"))
|
||||
}
|
||||
|
@ -387,7 +387,7 @@ func BlogPostReply(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -421,11 +421,11 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
err = c.Req.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -433,12 +433,12 @@ func BlogPostReplySubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
unparsed := c.Req.Form.Get("body")
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "Your reply cannot be empty.")
|
||||
return c.RejectRequest("Your reply cannot be empty.")
|
||||
}
|
||||
|
||||
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host)
|
||||
newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, cd.ThreadID, models.ThreadTypeProjectBlogPost, c.CurrentUser.ID, &cd.PostID, unparsed, c.Req.Host)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to blog post"))
|
||||
}
|
||||
|
@ -453,11 +453,11 @@ func BlogPostDelete(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -503,19 +503,19 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID)
|
||||
threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
|
||||
}
|
||||
|
@ -523,7 +523,7 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData {
|
|||
if threadDeleted {
|
||||
return c.Redirect(c.UrlContext.BuildHomepage(), http.StatusSeeOther)
|
||||
} else {
|
||||
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
|
||||
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, cd.ThreadID, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
})
|
||||
|
@ -560,7 +560,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
|||
res.ThreadID = threadId
|
||||
|
||||
c.Perf.StartBlock("SQL", "Verify that the thread exists")
|
||||
threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||
threadExists, err := db.QueryOneScalar[bool](c, c.Conn,
|
||||
`
|
||||
SELECT COUNT(*) > 0
|
||||
FROM thread
|
||||
|
@ -588,7 +588,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
|||
res.PostID = postId
|
||||
|
||||
c.Perf.StartBlock("SQL", "Verify that the post exists")
|
||||
postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||
postExists, err := db.QueryOneScalar[bool](c, c.Conn,
|
||||
`
|
||||
SELECT COUNT(*) > 0
|
||||
FROM post
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/auth"
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"git.handmade.network/hmn/hmn/src/hmndata"
|
||||
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
"git.handmade.network/hmn/hmn/src/models"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
)
|
||||
|
||||
func loadCommonData(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
|
||||
{
|
||||
// get user
|
||||
{
|
||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||
if err == nil {
|
||||
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
|
||||
}
|
||||
|
||||
c.CurrentUser = user
|
||||
c.CurrentSession = session
|
||||
}
|
||||
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
|
||||
}
|
||||
|
||||
// get current official project (HMN or otherwise, by subdomain)
|
||||
{
|
||||
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost())
|
||||
slug := strings.TrimRight(hostPrefix, ".")
|
||||
var owners []*models.User
|
||||
|
||||
if len(slug) > 0 {
|
||||
dbProject, err := hmndata.FetchProjectBySlug(c, c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err == nil {
|
||||
c.CurrentProject = &dbProject.Project
|
||||
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
|
||||
owners = dbProject.Owners
|
||||
} else {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
// do nothing, this is fine
|
||||
} else {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.CurrentProject == nil {
|
||||
dbProject, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(oops.New(err, "failed to fetch HMN project"))
|
||||
}
|
||||
c.CurrentProject = &dbProject.Project
|
||||
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
|
||||
}
|
||||
|
||||
if c.CurrentProject == nil {
|
||||
panic("failed to load project data")
|
||||
}
|
||||
|
||||
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
|
||||
|
||||
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
|
||||
}
|
||||
|
||||
c.Theme = "light"
|
||||
if c.CurrentUser != nil && c.CurrentUser.DarkTheme {
|
||||
c.Theme = "dark"
|
||||
}
|
||||
}
|
||||
c.Perf.EndBlock()
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Given a session id, fetches user data from the database. Will return nil if
|
||||
// the user cannot be found, and will only return an error if it's serious.
|
||||
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
|
||||
session, err := auth.GetSession(c, c.Conn, sessionId)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrNoSession) {
|
||||
return nil, nil, nil
|
||||
} else {
|
||||
return nil, nil, oops.New(err, "failed to get current session")
|
||||
}
|
||||
}
|
||||
|
||||
user, err := hmndata.FetchUserByUsername(c, c.Conn, nil, session.Username, hmndata.UsersQuery{
|
||||
AnyStatus: true,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
|
||||
return nil, nil, nil // user was deleted or something
|
||||
} else {
|
||||
return nil, nil, oops.New(err, "failed to get user for session")
|
||||
}
|
||||
}
|
||||
|
||||
return user, session, nil
|
||||
}
|
||||
|
||||
func addCORSHeaders(c *RequestContext, res *ResponseData) {
|
||||
parsed, err := url.Parse(config.Config.BaseUrl)
|
||||
if err != nil {
|
||||
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
|
||||
return
|
||||
}
|
||||
origin := ""
|
||||
origins, found := c.Req.Header["Origin"]
|
||||
if found {
|
||||
origin = origins[0]
|
||||
}
|
||||
if strings.HasSuffix(origin, parsed.Host) {
|
||||
res.Header().Add("Access-Control-Allow-Origin", origin)
|
||||
res.Header().Add("Access-Control-Allow-Credentials", "true")
|
||||
res.Header().Add("Vary", "Origin")
|
||||
}
|
||||
}
|
|
@ -35,31 +35,31 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
|
|||
// This occurs when the user cancels. Just go back to the profile page.
|
||||
return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther)
|
||||
} else {
|
||||
return RejectRequest(c, "Failed to authenticate with Discord.")
|
||||
return c.RejectRequest("Failed to authenticate with Discord.")
|
||||
}
|
||||
}
|
||||
|
||||
// Do the actual token exchange
|
||||
code := query.Get("code")
|
||||
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback())
|
||||
res, err := discord.ExchangeOAuthCode(c, code, hmnurl.BuildDiscordOAuthCallback())
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
|
||||
}
|
||||
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
|
||||
|
||||
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken)
|
||||
user, err := discord.GetCurrentUserAsOAuth(c, res.AccessToken)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
|
||||
}
|
||||
|
||||
// Add the role on Discord
|
||||
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID)
|
||||
err = discord.AddGuildMemberRole(c, user.ID, config.Config.Discord.MemberRoleID)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
|
||||
}
|
||||
|
||||
// Add the user to our database
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
_, err = c.Conn.Exec(c,
|
||||
`
|
||||
INSERT INTO discord_user (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
|
@ -79,7 +79,7 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if c.CurrentUser.Status == models.UserStatusConfirmed {
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
_, err = c.Conn.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -98,13 +98,13 @@ func DiscordOAuthCallback(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func DiscordUnlink(c *RequestContext) ResponseData {
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx,
|
||||
discordUser, err := db.QueryOne[models.DiscordUser](c, tx,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM discord_user
|
||||
|
@ -120,7 +120,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM discord_user
|
||||
WHERE id = $1
|
||||
|
@ -131,12 +131,12 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete Discord user"))
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit Discord user delete"))
|
||||
}
|
||||
|
||||
err = discord.RemoveGuildMemberRole(c.Context(), discordUser.UserID, config.Config.Discord.MemberRoleID)
|
||||
err = discord.RemoveGuildMemberRole(c, discordUser.UserID, config.Config.Discord.MemberRoleID)
|
||||
if err != nil {
|
||||
c.Logger.Warn().Err(err).Msg("failed to remove member role on unlink")
|
||||
}
|
||||
|
@ -145,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||
duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
|
||||
`SELECT $columns FROM discord_user WHERE hmn_user_id = $1`,
|
||||
c.CurrentUser.ID,
|
||||
)
|
||||
|
@ -157,7 +157,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
|
||||
}
|
||||
|
||||
msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||
msgIDs, err := db.QueryScalar[string](c, c.Conn,
|
||||
`
|
||||
SELECT msg.id
|
||||
FROM
|
||||
|
@ -174,12 +174,12 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
for _, msgID := range msgIDs {
|
||||
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
|
||||
interned, err := discord.FetchInternedMessage(c, c.Conn, msgID)
|
||||
if err != nil && !errors.Is(err, db.NotFound) {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
} else if err == nil {
|
||||
// NOTE(asaf): Creating snippet even if the checkbox is off because the user asked us to.
|
||||
err = discord.HandleSnippetForInternedMessage(c.Context(), c.Conn, interned, true)
|
||||
err = discord.HandleSnippetForInternedMessage(c, c.Conn, interned, true)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,31 @@
|
|||
package website
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
)
|
||||
|
||||
func FourOhFour(c *RequestContext) ResponseData {
|
||||
var res ResponseData
|
||||
res.StatusCode = http.StatusNotFound
|
||||
|
||||
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
|
||||
templateData := struct {
|
||||
templates.BaseData
|
||||
Wanted string
|
||||
}{
|
||||
BaseData: getBaseData(c, "Page not found", nil),
|
||||
Wanted: c.FullUrl(),
|
||||
}
|
||||
res.MustWriteTemplate("404.html", templateData, c.Perf)
|
||||
} else {
|
||||
res.Write([]byte("Not Found"))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// A SafeError can be used to wrap another error and explicitly provide
|
||||
// an error message that is safe to show to a user. This allows the original
|
||||
|
|
|
@ -33,7 +33,7 @@ var feedThreadTypes = []models.ThreadType{
|
|||
}
|
||||
|
||||
func Feed(c *RequestContext) ResponseData {
|
||||
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ThreadTypes: feedThreadTypes,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -156,7 +156,7 @@ func AtomFeed(c *RequestContext) ResponseData {
|
|||
if hasAll {
|
||||
itemsPerFeed = 100000
|
||||
}
|
||||
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, nil, hmndata.ProjectsQuery{
|
||||
projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, nil, hmndata.ProjectsQuery{
|
||||
Limit: itemsPerFeed,
|
||||
Types: hmndata.OfficialProjects,
|
||||
OrderBy: "date_approved DESC",
|
||||
|
@ -188,7 +188,7 @@ func AtomFeed(c *RequestContext) ResponseData {
|
|||
feedData.AtomFeedUrl = hmnurl.BuildAtomFeedForShowcase()
|
||||
feedData.FeedUrl = hmnurl.BuildShowcase()
|
||||
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
Limit: itemsPerFeed,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -215,7 +215,7 @@ func AtomFeed(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostListItem, error) {
|
||||
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ThreadTypes: feedThreadTypes,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
|
@ -225,7 +225,7 @@ func fetchAllPosts(c *RequestContext, offset int, limit int) ([]templates.PostLi
|
|||
return nil, err
|
||||
}
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
|
|
@ -206,7 +206,7 @@ func Fishbowl(c *RequestContext) ResponseData {
|
|||
func FishbowlFiles(c *RequestContext) ResponseData {
|
||||
var res ResponseData
|
||||
fishbowlHTTPFS.ServeHTTP(&res, c.Req)
|
||||
AddCORSHeaders(c, &res)
|
||||
addCORSHeaders(c, &res)
|
||||
return res
|
||||
}
|
||||
|
||||
|
@ -224,7 +224,7 @@ func linkifyDiscordContent(c *RequestContext, dbConn db.ConnOrTx, content string
|
|||
discordUserIds = append(discordUserIds, id)
|
||||
}
|
||||
|
||||
hmnUsers, err := hmndata.FetchUsers(c.Context(), dbConn, c.CurrentUser, hmndata.UsersQuery{
|
||||
hmnUsers, err := hmndata.FetchUsers(c, dbConn, c.CurrentUser, hmndata.UsersQuery{
|
||||
DiscordUserIDs: discordUserIds,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -91,7 +91,7 @@ func Forum(c *RequestContext) ResponseData {
|
|||
|
||||
currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID)
|
||||
|
||||
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
SubforumIDs: []int{cd.SubforumID},
|
||||
|
@ -107,7 +107,7 @@ func Forum(c *RequestContext) ResponseData {
|
|||
}
|
||||
howManyThreadsToSkip := (page - 1) * threadsPerPage
|
||||
|
||||
mainThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
mainThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
SubforumIDs: []int{cd.SubforumID},
|
||||
|
@ -141,7 +141,7 @@ func Forum(c *RequestContext) ResponseData {
|
|||
subforumNodes := cd.SubforumTree[cd.SubforumID].Children
|
||||
|
||||
for _, sfNode := range subforumNodes {
|
||||
numThreads, err := hmndata.CountThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
numThreads, err := hmndata.CountThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
SubforumIDs: []int{sfNode.ID},
|
||||
|
@ -150,7 +150,7 @@ func Forum(c *RequestContext) ResponseData {
|
|||
panic(oops.New(err, "failed to get count of threads"))
|
||||
}
|
||||
|
||||
subforumThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
subforumThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
SubforumIDs: []int{sfNode.ID},
|
||||
|
@ -203,7 +203,7 @@ func Forum(c *RequestContext) ResponseData {
|
|||
|
||||
func ForumMarkRead(c *RequestContext) ResponseData {
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
@ -212,16 +212,16 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
sfIds := []int{sfId}
|
||||
if sfId == 0 {
|
||||
// Mark literally everything as read
|
||||
_, err := tx.Exec(c.Context(),
|
||||
_, err := tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET marked_all_read_at = NOW()
|
||||
|
@ -234,7 +234,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// Delete thread unread info
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM thread_last_read_info
|
||||
WHERE user_id = $1;
|
||||
|
@ -246,7 +246,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// Delete subforum unread info
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM subforum_last_read_info
|
||||
WHERE user_id = $1;
|
||||
|
@ -258,7 +258,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
}
|
||||
} else {
|
||||
c.Perf.StartBlock("SQL", "Update SLRIs")
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
INSERT INTO subforum_last_read_info (subforum_id, user_id, lastread)
|
||||
SELECT id, $2, $3
|
||||
|
@ -277,7 +277,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Delete TLRIs")
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
DELETE FROM thread_last_read_info
|
||||
WHERE
|
||||
|
@ -298,7 +298,7 @@ func ForumMarkRead(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to commit SLRI/TLRI updates"))
|
||||
}
|
||||
|
@ -332,7 +332,7 @@ func ForumThread(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
threads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
threads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadIDs: []int{cd.ThreadID},
|
||||
})
|
||||
|
@ -351,7 +351,7 @@ func ForumThread(c *RequestContext) ResponseData {
|
|||
return c.Redirect(correctThreadUrl, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
ThreadIDs: []int{cd.ThreadID},
|
||||
|
@ -374,7 +374,7 @@ func ForumThread(c *RequestContext) ResponseData {
|
|||
PreviousUrl: c.UrlContext.BuildForumThread(currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)),
|
||||
}
|
||||
|
||||
postsAndStuff, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
postsAndStuff, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadIDs: []int{thread.ID},
|
||||
Limit: threadPostsPerPage,
|
||||
|
@ -396,7 +396,7 @@ func ForumThread(c *RequestContext) ResponseData {
|
|||
post.ReplyPost = &reply
|
||||
}
|
||||
|
||||
addAuthorCountsToPost(c.Context(), c.Conn, &post)
|
||||
addAuthorCountsToPost(c, c.Conn, &post)
|
||||
|
||||
posts = append(posts, post)
|
||||
}
|
||||
|
@ -404,7 +404,7 @@ func ForumThread(c *RequestContext) ResponseData {
|
|||
// Update thread last read info
|
||||
if c.CurrentUser != nil {
|
||||
c.Perf.StartBlock("SQL", "Update TLRI")
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
_, err = c.Conn.Exec(c,
|
||||
`
|
||||
INSERT INTO thread_last_read_info (thread_id, user_id, lastread)
|
||||
VALUES ($1, $2, $3)
|
||||
|
@ -445,7 +445,7 @@ func ForumPostRedirect(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
ThreadIDs: []int{cd.ThreadID},
|
||||
|
@ -495,11 +495,11 @@ func ForumNewThread(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func ForumNewThreadSubmit(c *RequestContext) ResponseData {
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
cd, ok := getCommonForumData(c)
|
||||
if !ok {
|
||||
|
@ -517,15 +517,15 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
|
|||
sticky = true
|
||||
}
|
||||
if title == "" {
|
||||
return RejectRequest(c, "You must provide a title for your post.")
|
||||
return c.RejectRequest("You must provide a title for your post.")
|
||||
}
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "You must provide a body for your post.")
|
||||
return c.RejectRequest("You must provide a body for your post.")
|
||||
}
|
||||
|
||||
// Create thread
|
||||
var threadId int
|
||||
err = tx.QueryRow(c.Context(),
|
||||
err = tx.QueryRow(c,
|
||||
`
|
||||
INSERT INTO thread (title, sticky, type, project_id, subforum_id, first_id, last_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
|
@ -544,9 +544,9 @@ func ForumNewThreadSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// Create everything else
|
||||
hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
|
||||
hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, threadId, models.ThreadTypeForumPost, c.CurrentUser.ID, nil, unparsed, c.Req.Host)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create new forum thread"))
|
||||
}
|
||||
|
@ -561,7 +561,7 @@ func ForumPostReply(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
})
|
||||
|
@ -600,11 +600,11 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
err = c.Req.ParseForm()
|
||||
if err != nil {
|
||||
|
@ -612,10 +612,10 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
unparsed := c.Req.Form.Get("body")
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "Your reply cannot be empty.")
|
||||
return c.RejectRequest("Your reply cannot be empty.")
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
})
|
||||
|
@ -629,9 +629,9 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData {
|
|||
replyPostId = &post.Post.ID
|
||||
}
|
||||
|
||||
newPostId, _ := hmndata.CreateNewPost(c.Context(), tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host)
|
||||
newPostId, _ := hmndata.CreateNewPost(c, tx, c.CurrentProject.ID, post.Thread.ID, models.ThreadTypeForumPost, c.CurrentUser.ID, replyPostId, unparsed, c.Req.Host)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to reply to forum post"))
|
||||
}
|
||||
|
@ -646,11 +646,11 @@ func ForumPostEdit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
})
|
||||
|
@ -688,17 +688,17 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, tx, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
})
|
||||
|
@ -713,16 +713,16 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
|
|||
unparsed := c.Req.Form.Get("body")
|
||||
editReason := c.Req.Form.Get("editreason")
|
||||
if title != "" && post.Thread.FirstID != post.Post.ID {
|
||||
return RejectRequest(c, "You can only edit the title by editing the first post.")
|
||||
return c.RejectRequest("You can only edit the title by editing the first post.")
|
||||
}
|
||||
if unparsed == "" {
|
||||
return RejectRequest(c, "You must provide a body for your post.")
|
||||
return c.RejectRequest("You must provide a body for your post.")
|
||||
}
|
||||
|
||||
hmndata.CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
|
||||
hmndata.CreatePostVersion(c, tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID)
|
||||
|
||||
if title != "" {
|
||||
_, err := tx.Exec(c.Context(),
|
||||
_, err := tx.Exec(c,
|
||||
`
|
||||
UPDATE thread SET title = $1 WHERE id = $2
|
||||
`,
|
||||
|
@ -734,7 +734,7 @@ func ForumPostEditSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit forum post"))
|
||||
}
|
||||
|
@ -749,11 +749,11 @@ func ForumPostDelete(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
post, err := hmndata.FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
post, err := hmndata.FetchThreadPost(c, c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost},
|
||||
})
|
||||
|
@ -798,19 +798,19 @@ func ForumPostDeleteSubmit(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
if !hmndata.UserCanEditPost(c.Context(), c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
if !hmndata.UserCanEditPost(c, c.Conn, *c.CurrentUser, cd.PostID) {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
threadDeleted := hmndata.DeletePost(c.Context(), tx, cd.ThreadID, cd.PostID)
|
||||
threadDeleted := hmndata.DeletePost(c, tx, cd.ThreadID, cd.PostID)
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete post"))
|
||||
}
|
||||
|
@ -831,7 +831,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
thread, err := hmndata.FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{
|
||||
thread, err := hmndata.FetchThread(c, c.Conn, c.CurrentUser, threadId, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
// This is the rare query where we want all thread types!
|
||||
})
|
||||
|
@ -842,7 +842,7 @@ func WikiArticleRedirect(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
@ -874,7 +874,7 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) {
|
|||
defer c.Perf.EndBlock()
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
|
|||
img.Seek(0, io.SeekStart)
|
||||
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
|
||||
sha1sum := hasher.Sum(nil)
|
||||
imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn,
|
||||
imageFile, err := db.QueryOne[models.ImageFile](c, dbConn,
|
||||
`
|
||||
INSERT INTO image_file (file, size, sha1sum, protected, width, height)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
|
|
|
@ -51,7 +51,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
tagId := -1
|
||||
jamTag, err := hmndata.FetchTag(c.Context(), c.Conn, hmndata.TagQuery{
|
||||
jamTag, err := hmndata.FetchTag(c, c.Conn, hmndata.TagQuery{
|
||||
Text: []string{"wheeljam"},
|
||||
})
|
||||
if err == nil {
|
||||
|
@ -60,7 +60,7 @@ func JamIndex2021(c *RequestContext) ResponseData {
|
|||
c.Logger.Warn().Err(err).Msg("failed to fetch jam tag; will fetch all snippets as a result")
|
||||
}
|
||||
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
Tags: []int{tagId},
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -34,13 +34,13 @@ type LandingTemplateData struct {
|
|||
|
||||
func Index(c *RequestContext) ResponseData {
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
var timelineItems []templates.TimelineItem
|
||||
|
||||
numPosts, err := hmndata.CountPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
numPosts, err := hmndata.CountPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ThreadTypes: feedThreadTypes,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -65,7 +65,7 @@ func Index(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// This is essentially an alternate for feed page 1.
|
||||
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ThreadTypes: feedThreadTypes,
|
||||
Limit: feedPostsPerPage,
|
||||
SortDescending: true,
|
||||
|
@ -84,7 +84,7 @@ func Index(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Get news")
|
||||
newsThreads, err := hmndata.FetchThreads(c.Context(), c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
newsThreads, err := hmndata.FetchThreads(c, c.Conn, c.CurrentUser, hmndata.ThreadsQuery{
|
||||
ProjectIDs: []int{models.HMNProjectID},
|
||||
ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost},
|
||||
Limit: 1,
|
||||
|
@ -106,7 +106,7 @@ func Index(c *RequestContext) ResponseData {
|
|||
}
|
||||
c.Perf.EndBlock()
|
||||
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
Limit: 40,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/auth"
|
||||
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"git.handmade.network/hmn/hmn/src/perf"
|
||||
"git.handmade.network/hmn/hmn/src/utils"
|
||||
)
|
||||
|
||||
func panicCatcherMiddleware(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
maybeError, ok := recovered.(*error)
|
||||
var err error
|
||||
if ok {
|
||||
err = *maybeError
|
||||
} else {
|
||||
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
|
||||
}
|
||||
res = c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
}()
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func trackRequestPerf(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
|
||||
defer func() {
|
||||
c.Perf.EndRequest()
|
||||
log := logging.Info()
|
||||
blockStack := make([]time.Time, 0)
|
||||
for i, block := range c.Perf.Blocks {
|
||||
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
|
||||
blockStack = blockStack[:len(blockStack)-1]
|
||||
}
|
||||
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
|
||||
blockStack = append(blockStack, block.End)
|
||||
}
|
||||
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
|
||||
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
|
||||
}()
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func needsAuth(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
if c.CurrentUser == nil {
|
||||
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func adminsOnly(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func csrfMiddleware(h Handler) Handler {
|
||||
// CSRF mitigation actions per the OWASP cheat sheet:
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
|
||||
return func(c *RequestContext) ResponseData {
|
||||
c.Req.ParseMultipartForm(100 * 1024 * 1024)
|
||||
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
|
||||
if csrfToken != c.CurrentSession.CSRFToken {
|
||||
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
|
||||
|
||||
res := c.Redirect("/", http.StatusSeeOther)
|
||||
logoutUser(c, &res)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func securityTimerMiddleware(duration time.Duration, h Handler) Handler {
|
||||
// NOTE(asaf): Will make sure that the request takes at least `duration` to finish. Adds a 10% random duration.
|
||||
return func(c *RequestContext) ResponseData {
|
||||
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
|
||||
timer := time.NewTimer(duration + additionalDuration)
|
||||
res := h(c)
|
||||
select {
|
||||
case <-c.Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
|
||||
func logContextErrors(c *RequestContext, errs ...error) {
|
||||
for _, err := range errs {
|
||||
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
|
||||
}
|
||||
}
|
||||
|
||||
func logContextErrorsMiddleware(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
res := h(c)
|
||||
logContextErrors(c, res.Errors...)
|
||||
return res
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
)
|
||||
|
||||
const NoticesCookieName = "hmn_notices"
|
||||
|
||||
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
|
||||
cookie, err := c.Req.Cookie(NoticesCookieName)
|
||||
if err != nil {
|
||||
if !errors.Is(err, http.ErrNoCookie) {
|
||||
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return deserializeNoticesFromCookie(cookie.Value)
|
||||
}
|
||||
|
||||
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
|
||||
serialized := serializeNoticesForCookie(c, res.FutureNotices)
|
||||
if serialized != "" {
|
||||
noticesCookie := http.Cookie{
|
||||
Name: NoticesCookieName,
|
||||
Value: serialized,
|
||||
Path: "/",
|
||||
Domain: config.Config.Auth.CookieDomain,
|
||||
Expires: time.Now().Add(time.Minute * 5),
|
||||
Secure: config.Config.Auth.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
res.SetCookie(¬icesCookie)
|
||||
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
|
||||
// NOTE(asaf): Don't clear on redirect
|
||||
noticesCookie := http.Cookie{
|
||||
Name: NoticesCookieName,
|
||||
Path: "/",
|
||||
Domain: config.Config.Auth.CookieDomain,
|
||||
MaxAge: -1,
|
||||
}
|
||||
res.SetCookie(¬icesCookie)
|
||||
}
|
||||
}
|
||||
|
||||
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
|
||||
var builder strings.Builder
|
||||
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
|
||||
size := 0
|
||||
for i, notice := range notices {
|
||||
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
|
||||
if i != 0 {
|
||||
sizeIncrease += 1
|
||||
}
|
||||
if size+sizeIncrease > maxSize {
|
||||
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
|
||||
break
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
builder.WriteString("\t")
|
||||
}
|
||||
builder.WriteString(notice.Class)
|
||||
builder.WriteString("|")
|
||||
builder.WriteString(string(notice.Content))
|
||||
|
||||
size += sizeIncrease
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
|
||||
var result []templates.Notice
|
||||
notices := strings.Split(cookieVal, "\t")
|
||||
for _, notice := range notices {
|
||||
parts := strings.SplitN(notice, "|", 2)
|
||||
if len(parts) == 2 {
|
||||
result = append(result, templates.Notice{
|
||||
Class: parts[0],
|
||||
Content: template.HTML(parts[1]),
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func storeNoticesInCookieMiddleware(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
res := h(c)
|
||||
storeNoticesInCookie(c, &res)
|
||||
return res
|
||||
}
|
||||
}
|
|
@ -126,29 +126,29 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
title := c.Req.Form.Get("title")
|
||||
if len(strings.TrimSpace(title)) == 0 {
|
||||
return RejectRequest(c, "Podcast title is empty")
|
||||
return c.RejectRequest("Podcast title is empty")
|
||||
}
|
||||
description := c.Req.Form.Get("description")
|
||||
if len(strings.TrimSpace(description)) == 0 {
|
||||
return RejectRequest(c, "Podcast description is empty")
|
||||
return c.RejectRequest("Podcast description is empty")
|
||||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Updating podcast")
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
imageSaveResult := SaveImageFile(c, tx, "podcast_image", maxFileSize, fmt.Sprintf("podcast/%s/logo%d", c.CurrentProject.Slug, time.Now().UTC().Unix()))
|
||||
if imageSaveResult.ValidationError != "" {
|
||||
return RejectRequest(c, imageSaveResult.ValidationError)
|
||||
return c.RejectRequest(imageSaveResult.ValidationError)
|
||||
} else if imageSaveResult.FatalError != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(imageSaveResult.FatalError, "Failed to save podcast image"))
|
||||
}
|
||||
|
||||
if imageSaveResult.ImageFile != nil {
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE podcast
|
||||
SET
|
||||
|
@ -166,7 +166,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to update podcast"))
|
||||
}
|
||||
} else {
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE podcast
|
||||
SET
|
||||
|
@ -179,7 +179,7 @@ func PodcastEditSubmit(c *RequestContext) ResponseData {
|
|||
podcastResult.Podcast.ID,
|
||||
)
|
||||
}
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
c.Perf.EndBlock()
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to commit db transaction"))
|
||||
|
@ -357,16 +357,16 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
|
|||
c.Req.ParseForm()
|
||||
title := c.Req.Form.Get("title")
|
||||
if len(strings.TrimSpace(title)) == 0 {
|
||||
return RejectRequest(c, "Episode title is empty")
|
||||
return c.RejectRequest("Episode title is empty")
|
||||
}
|
||||
description := c.Req.Form.Get("description")
|
||||
if len(strings.TrimSpace(description)) == 0 {
|
||||
return RejectRequest(c, "Episode description is empty")
|
||||
return c.RejectRequest("Episode description is empty")
|
||||
}
|
||||
episodeNumberStr := c.Req.Form.Get("episode_number")
|
||||
episodeNumber, err := strconv.Atoi(episodeNumberStr)
|
||||
if err != nil {
|
||||
return RejectRequest(c, "Episode number can't be parsed")
|
||||
return c.RejectRequest("Episode number can't be parsed")
|
||||
}
|
||||
episodeFile := c.Req.Form.Get("episode_file")
|
||||
found = false
|
||||
|
@ -378,7 +378,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
if !found {
|
||||
return RejectRequest(c, "Requested episode file not found")
|
||||
return c.RejectRequest("Requested episode file not found")
|
||||
}
|
||||
|
||||
c.Perf.StartBlock("MP3", "Parsing mp3 file for duration")
|
||||
|
@ -417,7 +417,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
|
|||
if isEdit {
|
||||
guidStr = podcastResult.Episodes[0].GUID.String()
|
||||
c.Perf.StartBlock("SQL", "Updating podcast episode")
|
||||
_, err := c.Conn.Exec(c.Context(),
|
||||
_, err := c.Conn.Exec(c,
|
||||
`
|
||||
UPDATE podcast_episode
|
||||
SET
|
||||
|
@ -446,7 +446,7 @@ func PodcastEpisodeSubmit(c *RequestContext) ResponseData {
|
|||
guid := uuid.New()
|
||||
guidStr = guid.String()
|
||||
c.Perf.StartBlock("SQL", "Creating new podcast episode")
|
||||
_, err := c.Conn.Exec(c.Context(),
|
||||
_, err := c.Conn.Exec(c,
|
||||
`
|
||||
INSERT INTO podcast_episode
|
||||
(guid, title, description, description_rendered, audio_filename, duration, pub_date, episode_number, podcast_id)
|
||||
|
@ -532,7 +532,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
|||
Podcast models.Podcast `db:"podcast"`
|
||||
ImageFilename string `db:"imagefile.file"`
|
||||
}
|
||||
podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn,
|
||||
podcastQueryResult, err := db.QueryOne[podcastQuery](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -558,7 +558,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
|||
if fetchEpisodes {
|
||||
if episodeGUID == "" {
|
||||
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
|
||||
episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn,
|
||||
episodes, err := db.Query[models.PodcastEpisode](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM podcast_episode AS episode
|
||||
|
@ -578,7 +578,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
|||
return result, err
|
||||
}
|
||||
c.Perf.StartBlock("SQL", "Fetch podcast episode")
|
||||
episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn,
|
||||
episode, err := db.QueryOne[models.PodcastEpisode](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM podcast_episode AS episode
|
||||
|
|
|
@ -26,11 +26,52 @@ import (
|
|||
"git.handmade.network/hmn/hmn/src/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/teacat/noire"
|
||||
)
|
||||
|
||||
const maxPersonalProjects = 5
|
||||
const maxProjectOwners = 5
|
||||
|
||||
func ProjectCSS(c *RequestContext) ResponseData {
|
||||
color := c.URL().Query().Get("color")
|
||||
if color == "" {
|
||||
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
|
||||
}
|
||||
|
||||
baseData := getBaseData(c, "", nil)
|
||||
|
||||
bgColor := noire.NewHex(color)
|
||||
h, s, l := bgColor.HSL()
|
||||
if baseData.Theme == "dark" {
|
||||
l = 15
|
||||
} else {
|
||||
l = 95
|
||||
}
|
||||
if s > 20 {
|
||||
s = 20
|
||||
}
|
||||
bgColor = noire.NewHSL(h, s, l)
|
||||
|
||||
templateData := struct {
|
||||
templates.BaseData
|
||||
Color string
|
||||
PostBgColor string
|
||||
}{
|
||||
BaseData: baseData,
|
||||
Color: color,
|
||||
PostBgColor: bgColor.HTML(),
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
res.Header().Add("Content-Type", "text/css")
|
||||
err := res.WriteTemplate("project.css", templateData, c.Perf)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type ProjectTemplateData struct {
|
||||
templates.BaseData
|
||||
|
||||
|
@ -48,7 +89,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
|
|||
const maxCarouselProjects = 10
|
||||
const maxPersonalProjects = 10
|
||||
|
||||
officialProjects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
officialProjects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
Types: hmndata.OfficialProjects,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -123,7 +164,7 @@ func ProjectIndex(c *RequestContext) ResponseData {
|
|||
// Fetch and highlight a random selection of personal projects
|
||||
var personalProjects []templates.Project
|
||||
{
|
||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
projects, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
Types: hmndata.PersonalProjects,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -181,13 +222,13 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
|||
// There are no further permission checks to do, because permissions are
|
||||
// checked whatever way we fetch the project.
|
||||
|
||||
owners, err := hmndata.FetchProjectOwners(c.Context(), c.Conn, c.CurrentProject.ID)
|
||||
owners, err := hmndata.FetchProjectOwners(c, c.Conn, c.CurrentProject.ID)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetching screenshots")
|
||||
screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||
screenshotFilenames, err := db.QueryScalar[string](c, c.Conn,
|
||||
`
|
||||
SELECT screenshot.file
|
||||
FROM
|
||||
|
@ -204,7 +245,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
|||
c.Perf.EndBlock()
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
projectLinks, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -221,12 +262,12 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
|||
c.Perf.EndBlock()
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetching project timeline")
|
||||
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
ProjectIDs: []int{c.CurrentProject.ID},
|
||||
Limit: maxRecentActivity,
|
||||
SortDescending: true,
|
||||
|
@ -241,7 +282,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
|||
Value: c.CurrentProject.Blurb,
|
||||
})
|
||||
|
||||
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{
|
||||
p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, c.CurrentProject.ID, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
|
@ -317,7 +358,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
|||
tagId = *c.CurrentProject.TagID
|
||||
}
|
||||
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
Tags: []int{tagId},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -364,7 +405,7 @@ type ProjectEditData struct {
|
|||
}
|
||||
|
||||
func ProjectNew(c *RequestContext) ResponseData {
|
||||
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
OwnerIDs: []int{c.CurrentUser.ID},
|
||||
Types: hmndata.PersonalProjects,
|
||||
})
|
||||
|
@ -372,7 +413,7 @@ func ProjectNew(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
|
||||
}
|
||||
if numProjects >= maxPersonalProjects {
|
||||
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
|
||||
return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
|
||||
}
|
||||
|
||||
var project templates.ProjectSettings
|
||||
|
@ -397,16 +438,16 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
|
||||
}
|
||||
if len(formResult.RejectionReason) != 0 {
|
||||
return RejectRequest(c, formResult.RejectionReason)
|
||||
return c.RejectRequest(formResult.RejectionReason)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
numProjects, err := hmndata.CountProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
numProjects, err := hmndata.CountProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
OwnerIDs: []int{c.CurrentUser.ID},
|
||||
Types: hmndata.PersonalProjects,
|
||||
})
|
||||
|
@ -414,11 +455,11 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check number of personal projects"))
|
||||
}
|
||||
if numProjects >= maxPersonalProjects {
|
||||
return RejectRequest(c, fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
|
||||
return c.RejectRequest(fmt.Sprintf("You have already reached the maximum of %d personal projects.", maxPersonalProjects))
|
||||
}
|
||||
|
||||
var projectId int
|
||||
err = tx.QueryRow(c.Context(),
|
||||
err = tx.QueryRow(c,
|
||||
`
|
||||
INSERT INTO project
|
||||
(name, blurb, description, descparsed, lifecycle, date_created, all_last_updated)
|
||||
|
@ -439,12 +480,12 @@ func ProjectNewSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
formResult.Payload.ProjectID = projectId
|
||||
|
||||
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload)
|
||||
err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
tx.Commit(c.Context())
|
||||
tx.Commit(c)
|
||||
|
||||
urlContext := &hmnurl.UrlContext{
|
||||
PersonalProject: true,
|
||||
|
@ -461,7 +502,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
p, err := hmndata.FetchProject(
|
||||
c.Context(), c.Conn,
|
||||
c, c.Conn,
|
||||
c.CurrentUser, c.CurrentProject.ID,
|
||||
hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
|
@ -473,7 +514,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
projectLinks, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -524,23 +565,23 @@ func ProjectEditSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, formResult.Error)
|
||||
}
|
||||
if len(formResult.RejectionReason) != 0 {
|
||||
return RejectRequest(c, formResult.RejectionReason)
|
||||
return c.RejectRequest(formResult.RejectionReason)
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to start db transaction"))
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
formResult.Payload.ProjectID = c.CurrentProject.ID
|
||||
|
||||
err = updateProject(c.Context(), tx, c.CurrentUser, &formResult.Payload)
|
||||
err = updateProject(c, tx, c.CurrentUser, &formResult.Payload)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
tx.Commit(c.Context())
|
||||
tx.Commit(c)
|
||||
|
||||
urlContext := &hmnurl.UrlContext{
|
||||
PersonalProject: formResult.Payload.Personal,
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
|
@ -43,15 +44,22 @@ func (r *Route) String() string {
|
|||
}
|
||||
|
||||
type RouteBuilder struct {
|
||||
Router *Router
|
||||
Prefixes []*regexp.Regexp
|
||||
Middleware Middleware
|
||||
Router *Router
|
||||
Prefixes []*regexp.Regexp
|
||||
Middlewares []Middleware
|
||||
}
|
||||
|
||||
type Handler func(c *RequestContext) ResponseData
|
||||
|
||||
type Middleware func(h Handler) Handler
|
||||
|
||||
func applyMiddlewares(h Handler, ms []Middleware) Handler {
|
||||
result := h
|
||||
for i := len(ms) - 1; i >= 0; i-- {
|
||||
result = ms[i](result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler) {
|
||||
// Ensure that this regex matches the start of the string
|
||||
regexStr := regex.String()
|
||||
|
@ -59,7 +67,7 @@ func (rb *RouteBuilder) Handle(methods []string, regex *regexp.Regexp, h Handler
|
|||
panic("All routing regexes must begin with '^'")
|
||||
}
|
||||
|
||||
h = rb.Middleware(h)
|
||||
h = applyMiddlewares(h, rb.Middlewares)
|
||||
for _, method := range methods {
|
||||
rb.Router.Routes = append(rb.Router.Routes, Route{
|
||||
Method: method,
|
||||
|
@ -81,10 +89,19 @@ func (rb *RouteBuilder) POST(regex *regexp.Regexp, h Handler) {
|
|||
rb.Handle([]string{http.MethodPost}, regex, h)
|
||||
}
|
||||
|
||||
func (rb *RouteBuilder) Group(regex *regexp.Regexp, addRoutes func(rb *RouteBuilder)) {
|
||||
func (rb *RouteBuilder) WithMiddleware(ms ...Middleware) RouteBuilder {
|
||||
newRb := *rb
|
||||
newRb.Middlewares = append(rb.Middlewares, ms...)
|
||||
|
||||
return newRb
|
||||
}
|
||||
|
||||
func (rb *RouteBuilder) Group(regex *regexp.Regexp, ms ...Middleware) RouteBuilder {
|
||||
newRb := *rb
|
||||
newRb.Prefixes = append(newRb.Prefixes, regex)
|
||||
addRoutes(&newRb)
|
||||
newRb.Middlewares = append(rb.Middlewares, ms...)
|
||||
|
||||
return newRb
|
||||
}
|
||||
|
||||
func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -138,6 +155,8 @@ nextroute:
|
|||
Req: req,
|
||||
Res: rw,
|
||||
PathParams: params,
|
||||
|
||||
ctx: req.Context(),
|
||||
}
|
||||
c.PathParams = params
|
||||
|
||||
|
@ -174,13 +193,33 @@ type RequestContext struct {
|
|||
ctx context.Context
|
||||
}
|
||||
|
||||
func (c *RequestContext) Context() context.Context {
|
||||
if c.ctx == nil {
|
||||
c.ctx = c.Req.Context()
|
||||
}
|
||||
return c.ctx
|
||||
// Our RequestContext is a context.Context
|
||||
|
||||
var _ context.Context = &RequestContext{}
|
||||
|
||||
func (c *RequestContext) Deadline() (time.Time, bool) {
|
||||
return c.ctx.Deadline()
|
||||
}
|
||||
|
||||
func (c *RequestContext) Done() <-chan struct{} {
|
||||
return c.ctx.Done()
|
||||
}
|
||||
|
||||
func (c *RequestContext) Err() error {
|
||||
return c.ctx.Err()
|
||||
}
|
||||
|
||||
func (c *RequestContext) Value(key any) any {
|
||||
switch key {
|
||||
case perf.PerfContextKey:
|
||||
return c.Perf
|
||||
default:
|
||||
return c.ctx.Value(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Plus it does many other things specific to us
|
||||
|
||||
func (c *RequestContext) URL() *url.URL {
|
||||
return c.Req.URL
|
||||
}
|
||||
|
@ -325,7 +364,7 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData {
|
|||
func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
LogContextErrors(c, errs...)
|
||||
logContextErrors(c, errs...)
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
@ -338,6 +377,23 @@ func (c *RequestContext) ErrorResponse(status int, errs ...error) ResponseData {
|
|||
return res
|
||||
}
|
||||
|
||||
func (c *RequestContext) RejectRequest(reason string) ResponseData {
|
||||
type RejectData struct {
|
||||
templates.BaseData
|
||||
RejectReason string
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
err := res.WriteTemplate("reject.html", RejectData{
|
||||
BaseData: getBaseData(c, "Rejected", nil),
|
||||
RejectReason: reason,
|
||||
}, c.Perf)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
type ResponseData struct {
|
||||
StatusCode int
|
||||
Body *bytes.Buffer
|
||||
|
|
|
@ -1,159 +1,46 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/auth"
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"git.handmade.network/hmn/hmn/src/email"
|
||||
"git.handmade.network/hmn/hmn/src/hmndata"
|
||||
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
"git.handmade.network/hmn/hmn/src/models"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"git.handmade.network/hmn/hmn/src/perf"
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
"git.handmade.network/hmn/hmn/src/utils"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/teacat/noire"
|
||||
)
|
||||
|
||||
func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) http.Handler {
|
||||
func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
|
||||
router := &Router{}
|
||||
routes := RouteBuilder{
|
||||
Router: router,
|
||||
Middleware: func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
c.Conn = conn
|
||||
|
||||
logPerf := TrackRequestPerf(c)
|
||||
defer logPerf()
|
||||
|
||||
defer LogContextErrorsFromResponse(c, &res)
|
||||
defer MiddlewarePanicCatcher(c, &res)
|
||||
|
||||
return h(c)
|
||||
}
|
||||
Middlewares: []Middleware{
|
||||
setDBConn(conn),
|
||||
trackRequestPerf,
|
||||
logContextErrorsMiddleware,
|
||||
panicCatcherMiddleware,
|
||||
},
|
||||
}
|
||||
|
||||
anyProject := routes
|
||||
anyProject.Middleware = func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
c.Conn = conn
|
||||
|
||||
logPerf := TrackRequestPerf(c)
|
||||
defer logPerf()
|
||||
|
||||
defer LogContextErrorsFromResponse(c, &res)
|
||||
defer MiddlewarePanicCatcher(c, &res)
|
||||
|
||||
defer storeNoticesInCookie(c, &res)
|
||||
|
||||
ok, errRes := LoadCommonWebsiteData(c)
|
||||
if !ok {
|
||||
return errRes
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
hmnOnly := routes
|
||||
hmnOnly.Middleware = func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
c.Conn = conn
|
||||
|
||||
logPerf := TrackRequestPerf(c)
|
||||
defer logPerf()
|
||||
|
||||
defer LogContextErrorsFromResponse(c, &res)
|
||||
defer MiddlewarePanicCatcher(c, &res)
|
||||
|
||||
defer storeNoticesInCookie(c, &res)
|
||||
|
||||
ok, errRes := LoadCommonWebsiteData(c)
|
||||
if !ok {
|
||||
return errRes
|
||||
}
|
||||
|
||||
if !c.CurrentProject.IsHMN() {
|
||||
return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
authMiddleware := func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
if c.CurrentUser == nil {
|
||||
return c.Redirect(hmnurl.BuildLoginPage(c.FullUrl()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
adminMiddleware := func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
if c.CurrentUser == nil || !c.CurrentUser.IsStaff {
|
||||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
csrfMiddleware := func(h Handler) Handler {
|
||||
// CSRF mitigation actions per the OWASP cheat sheet:
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
|
||||
return func(c *RequestContext) ResponseData {
|
||||
c.Req.ParseMultipartForm(100 * 1024 * 1024)
|
||||
csrfToken := c.Req.Form.Get(auth.CSRFFieldName)
|
||||
if csrfToken != c.CurrentSession.CSRFToken {
|
||||
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed CSRF validation - potential attack?")
|
||||
|
||||
res := c.Redirect("/", http.StatusSeeOther)
|
||||
logoutUser(c, &res)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
securityTimerMiddleware := func(duration time.Duration, h Handler) Handler {
|
||||
// NOTE(asaf): Will make sure that the request takes at least `delayMs` to finish. Adds a 10% random duration.
|
||||
return func(c *RequestContext) ResponseData {
|
||||
additionalDuration := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
|
||||
timer := time.NewTimer(duration + additionalDuration)
|
||||
res := h(c)
|
||||
select {
|
||||
case <-longRequestContext.Done():
|
||||
case <-c.Context().Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
return res
|
||||
}
|
||||
}
|
||||
anyProject := routes.WithMiddleware(
|
||||
storeNoticesInCookieMiddleware,
|
||||
loadCommonData,
|
||||
)
|
||||
hmnOnly := anyProject.WithMiddleware(
|
||||
redirectToHMN,
|
||||
)
|
||||
|
||||
routes.GET(hmnurl.RegexPublic, func(c *RequestContext) ResponseData {
|
||||
var res ResponseData
|
||||
http.StripPrefix("/public/", http.FileServer(http.Dir("public"))).ServeHTTP(&res, c.Req)
|
||||
AddCORSHeaders(c, &res)
|
||||
addCORSHeaders(c, &res)
|
||||
return res
|
||||
})
|
||||
routes.GET(hmnurl.RegexFishbowlFiles, FishbowlFiles)
|
||||
|
@ -189,10 +76,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
hmnOnly.POST(hmnurl.RegexDoPasswordReset, DoPasswordResetSubmit)
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexAdminAtomFeed, AdminAtomFeed)
|
||||
hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminMiddleware(AdminApprovalQueue))
|
||||
hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminMiddleware(csrfMiddleware(AdminApprovalQueueSubmit)))
|
||||
hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminMiddleware(csrfMiddleware(UserProfileAdminSetStatus)))
|
||||
hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminMiddleware(csrfMiddleware(UserProfileAdminNuke)))
|
||||
hmnOnly.GET(hmnurl.RegexAdminApprovalQueue, adminsOnly(AdminApprovalQueue))
|
||||
hmnOnly.POST(hmnurl.RegexAdminApprovalQueue, adminsOnly(csrfMiddleware(AdminApprovalQueueSubmit)))
|
||||
hmnOnly.POST(hmnurl.RegexAdminSetUserStatus, adminsOnly(csrfMiddleware(UserProfileAdminSetStatus)))
|
||||
hmnOnly.POST(hmnurl.RegexAdminNukeUser, adminsOnly(csrfMiddleware(UserProfileAdminNuke)))
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexFeed, Feed)
|
||||
hmnOnly.GET(hmnurl.RegexAtomFeed, AtomFeed)
|
||||
|
@ -200,19 +87,19 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
hmnOnly.GET(hmnurl.RegexSnippet, Snippet)
|
||||
hmnOnly.GET(hmnurl.RegexProjectIndex, ProjectIndex)
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexProjectNew, authMiddleware(ProjectNew))
|
||||
hmnOnly.POST(hmnurl.RegexProjectNew, authMiddleware(csrfMiddleware(ProjectNewSubmit)))
|
||||
hmnOnly.GET(hmnurl.RegexProjectNew, needsAuth(ProjectNew))
|
||||
hmnOnly.POST(hmnurl.RegexProjectNew, needsAuth(csrfMiddleware(ProjectNewSubmit)))
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback))
|
||||
hmnOnly.POST(hmnurl.RegexDiscordUnlink, authMiddleware(csrfMiddleware(DiscordUnlink)))
|
||||
hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, authMiddleware(csrfMiddleware(DiscordShowcaseBacklog)))
|
||||
hmnOnly.GET(hmnurl.RegexDiscordOAuthCallback, needsAuth(DiscordOAuthCallback))
|
||||
hmnOnly.POST(hmnurl.RegexDiscordUnlink, needsAuth(csrfMiddleware(DiscordUnlink)))
|
||||
hmnOnly.POST(hmnurl.RegexDiscordShowcaseBacklog, needsAuth(csrfMiddleware(DiscordShowcaseBacklog)))
|
||||
|
||||
hmnOnly.POST(hmnurl.RegexTwitchEventSubCallback, TwitchEventSubCallback)
|
||||
hmnOnly.GET(hmnurl.RegexTwitchDebugPage, TwitchDebugPage)
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexUserProfile, UserProfile)
|
||||
hmnOnly.GET(hmnurl.RegexUserSettings, authMiddleware(UserSettings))
|
||||
hmnOnly.POST(hmnurl.RegexUserSettings, authMiddleware(csrfMiddleware(UserSettingsSave)))
|
||||
hmnOnly.GET(hmnurl.RegexUserSettings, needsAuth(UserSettings))
|
||||
hmnOnly.POST(hmnurl.RegexUserSettings, needsAuth(csrfMiddleware(UserSettingsSave)))
|
||||
|
||||
hmnOnly.GET(hmnurl.RegexPodcast, PodcastIndex)
|
||||
hmnOnly.GET(hmnurl.RegexPodcastEdit, PodcastEdit)
|
||||
|
@ -231,6 +118,9 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
|
||||
hmnOnly.GET(hmnurl.RegexLibraryAny, LibraryNotPortedYet)
|
||||
|
||||
// Project routes can appear either at the root (e.g. hero.handmade.network/edit)
|
||||
// or on a personal project path (e.g. handmade.network/p/123/hero/edit). So, we
|
||||
// have pulled all those routes into this function.
|
||||
attachProjectRoutes := func(rb *RouteBuilder) {
|
||||
rb.GET(hmnurl.RegexHomepage, func(c *RequestContext) ResponseData {
|
||||
if c.CurrentProject.IsHMN() {
|
||||
|
@ -240,8 +130,8 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
}
|
||||
})
|
||||
|
||||
rb.GET(hmnurl.RegexProjectEdit, authMiddleware(ProjectEdit))
|
||||
rb.POST(hmnurl.RegexProjectEdit, authMiddleware(csrfMiddleware(ProjectEditSubmit)))
|
||||
rb.GET(hmnurl.RegexProjectEdit, needsAuth(ProjectEdit))
|
||||
rb.POST(hmnurl.RegexProjectEdit, needsAuth(csrfMiddleware(ProjectEditSubmit)))
|
||||
|
||||
// Middleware used for forum action routes - anything related to actually creating or editing forum content
|
||||
needsForums := func(h Handler) Handler {
|
||||
|
@ -251,14 +141,14 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
return FourOhFour(c)
|
||||
}
|
||||
// Require auth if forums are enabled
|
||||
return authMiddleware(h)(c)
|
||||
return needsAuth(h)(c)
|
||||
}
|
||||
}
|
||||
rb.POST(hmnurl.RegexForumNewThreadSubmit, needsForums(csrfMiddleware(ForumNewThreadSubmit)))
|
||||
rb.GET(hmnurl.RegexForumNewThread, needsForums(ForumNewThread))
|
||||
rb.GET(hmnurl.RegexForumThread, ForumThread)
|
||||
rb.GET(hmnurl.RegexForum, Forum)
|
||||
rb.POST(hmnurl.RegexForumMarkRead, authMiddleware(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled
|
||||
rb.POST(hmnurl.RegexForumMarkRead, needsAuth(csrfMiddleware(ForumMarkRead))) // needs auth but doesn't need forums enabled
|
||||
rb.GET(hmnurl.RegexForumPost, ForumPostRedirect)
|
||||
rb.GET(hmnurl.RegexForumPostReply, needsForums(ForumPostReply))
|
||||
rb.POST(hmnurl.RegexForumPostReply, needsForums(csrfMiddleware(ForumPostReplySubmit)))
|
||||
|
@ -276,7 +166,7 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
return FourOhFour(c)
|
||||
}
|
||||
// Require auth if blogs are enabled
|
||||
return authMiddleware(h)(c)
|
||||
return needsAuth(h)(c)
|
||||
}
|
||||
}
|
||||
rb.GET(hmnurl.RegexBlog, BlogIndex)
|
||||
|
@ -296,64 +186,10 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
), http.StatusMovedPermanently)
|
||||
})
|
||||
}
|
||||
hmnOnly.Group(hmnurl.RegexPersonalProject, func(rb *RouteBuilder) {
|
||||
// TODO(ben): Perhaps someday we can make this middleware modification feel better? It seems
|
||||
// pretty common to run the outermost middleware first before doing other stuff, but having
|
||||
// to nest functions this way feels real bad.
|
||||
rb.Middleware = func(h Handler) Handler {
|
||||
return hmnOnly.Middleware(func(c *RequestContext) ResponseData {
|
||||
// At this point we are definitely on the plain old HMN subdomain.
|
||||
|
||||
// Fetch personal project and do whatever
|
||||
id, err := strconv.Atoi(c.PathParams["projectid"])
|
||||
if err != nil {
|
||||
panic(oops.New(err, "project id was not numeric (bad regex in routing)"))
|
||||
}
|
||||
p, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
return FourOhFour(c)
|
||||
} else {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
|
||||
}
|
||||
}
|
||||
|
||||
c.CurrentProject = &p.Project
|
||||
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
|
||||
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
|
||||
|
||||
if !p.Project.Personal {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(p.Project.Name) {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
})
|
||||
}
|
||||
attachProjectRoutes(rb)
|
||||
})
|
||||
anyProject.Group(regexp.MustCompile("^"), func(rb *RouteBuilder) {
|
||||
rb.Middleware = func(h Handler) Handler {
|
||||
return anyProject.Middleware(func(c *RequestContext) ResponseData {
|
||||
// We could be on any project's subdomain.
|
||||
|
||||
// Check if the current project (matched by subdomain) is actually no longer official
|
||||
// and therefore needs to be redirected to the personal project version of the route.
|
||||
if c.CurrentProject.Personal {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
return h(c)
|
||||
})
|
||||
}
|
||||
attachProjectRoutes(rb)
|
||||
})
|
||||
officialProjectRoutes := anyProject.WithMiddleware(officialProjectMiddleware)
|
||||
personalProjectRoutes := hmnOnly.Group(hmnurl.RegexPersonalProject, personalProjectMiddleware)
|
||||
attachProjectRoutes(&officialProjectRoutes)
|
||||
attachProjectRoutes(&personalProjectRoutes)
|
||||
|
||||
anyProject.POST(hmnurl.RegexAssetUpload, AssetUpload)
|
||||
|
||||
|
@ -375,318 +211,69 @@ func NewWebsiteRoutes(longRequestContext context.Context, conn *pgxpool.Pool) ht
|
|||
return router
|
||||
}
|
||||
|
||||
func ProjectCSS(c *RequestContext) ResponseData {
|
||||
color := c.URL().Query().Get("color")
|
||||
if color == "" {
|
||||
return c.ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
|
||||
}
|
||||
|
||||
baseData := getBaseData(c, "", nil)
|
||||
|
||||
bgColor := noire.NewHex(color)
|
||||
h, s, l := bgColor.HSL()
|
||||
if baseData.Theme == "dark" {
|
||||
l = 15
|
||||
} else {
|
||||
l = 95
|
||||
}
|
||||
if s > 20 {
|
||||
s = 20
|
||||
}
|
||||
bgColor = noire.NewHSL(h, s, l)
|
||||
|
||||
templateData := struct {
|
||||
templates.BaseData
|
||||
Color string
|
||||
PostBgColor string
|
||||
}{
|
||||
BaseData: baseData,
|
||||
Color: color,
|
||||
PostBgColor: bgColor.HTML(),
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
res.Header().Add("Content-Type", "text/css")
|
||||
err := res.WriteTemplate("project.css", templateData, c.Perf)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func FourOhFour(c *RequestContext) ResponseData {
|
||||
var res ResponseData
|
||||
res.StatusCode = http.StatusNotFound
|
||||
|
||||
if c.Req.Header["Accept"] != nil && strings.Contains(c.Req.Header["Accept"][0], "text/html") {
|
||||
templateData := struct {
|
||||
templates.BaseData
|
||||
Wanted string
|
||||
}{
|
||||
BaseData: getBaseData(c, "Page not found", nil),
|
||||
Wanted: c.FullUrl(),
|
||||
func setDBConn(conn *pgxpool.Pool) Middleware {
|
||||
return func(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
c.Conn = conn
|
||||
return h(c)
|
||||
}
|
||||
res.MustWriteTemplate("404.html", templateData, c.Perf)
|
||||
} else {
|
||||
res.Write([]byte("Not Found"))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
type RejectData struct {
|
||||
templates.BaseData
|
||||
RejectReason string
|
||||
}
|
||||
|
||||
func RejectRequest(c *RequestContext, reason string) ResponseData {
|
||||
var res ResponseData
|
||||
err := res.WriteTemplate("reject.html", RejectData{
|
||||
BaseData: getBaseData(c, "Rejected", nil),
|
||||
RejectReason: reason,
|
||||
}, c.Perf)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "Failed to render reject template"))
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func LoadCommonWebsiteData(c *RequestContext) (bool, ResponseData) {
|
||||
c.Perf.StartBlock("MIDDLEWARE", "Load common website data")
|
||||
defer c.Perf.EndBlock()
|
||||
|
||||
// get user
|
||||
{
|
||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||
if err == nil {
|
||||
user, session, err := getCurrentUserAndSession(c, sessionCookie.Value)
|
||||
if err != nil {
|
||||
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user"))
|
||||
}
|
||||
|
||||
c.CurrentUser = user
|
||||
c.CurrentSession = session
|
||||
func redirectToHMN(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
if !c.CurrentProject.IsHMN() {
|
||||
return c.Redirect(hmnurl.Url(c.URL().Path, hmnurl.QFromURL(c.URL())), http.StatusMovedPermanently)
|
||||
}
|
||||
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
|
||||
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
// get official project
|
||||
{
|
||||
hostPrefix := strings.TrimSuffix(c.Req.Host, hmnurl.GetBaseHost())
|
||||
slug := strings.TrimRight(hostPrefix, ".")
|
||||
var owners []*models.User
|
||||
func officialProjectMiddleware(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
// Check if the current project (matched by subdomain) is actually no longer official
|
||||
// and therefore needs to be redirected to the personal project version of the route.
|
||||
if c.CurrentProject.Personal {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
if len(slug) > 0 {
|
||||
dbProject, err := hmndata.FetchProjectBySlug(c.Context(), c.Conn, c.CurrentUser, slug, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err == nil {
|
||||
c.CurrentProject = &dbProject.Project
|
||||
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
|
||||
owners = dbProject.Owners
|
||||
return h(c)
|
||||
}
|
||||
}
|
||||
|
||||
func personalProjectMiddleware(h Handler) Handler {
|
||||
return func(c *RequestContext) ResponseData {
|
||||
hmnProject := c.CurrentProject
|
||||
|
||||
id := utils.Must1(strconv.Atoi(c.PathParams["projectid"]))
|
||||
p, err := hmndata.FetchProject(c, c.Conn, c.CurrentUser, id, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
return FourOhFour(c)
|
||||
} else {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
// do nothing, this is fine
|
||||
} else {
|
||||
return false, c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
|
||||
}
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch personal project"))
|
||||
}
|
||||
}
|
||||
|
||||
if c.CurrentProject == nil {
|
||||
dbProject, err := hmndata.FetchProject(c.Context(), c.Conn, c.CurrentUser, models.HMNProjectID, hmndata.ProjectsQuery{
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(oops.New(err, "failed to fetch HMN project"))
|
||||
}
|
||||
c.CurrentProject = &dbProject.Project
|
||||
c.CurrentProjectLogoUrl = templates.ProjectLogoUrl(&dbProject.Project, dbProject.LogoLightAsset, dbProject.LogoDarkAsset, c.Theme)
|
||||
}
|
||||
|
||||
if c.CurrentProject == nil {
|
||||
panic("failed to load project data")
|
||||
}
|
||||
|
||||
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, owners)
|
||||
c.CurrentProject = &p.Project
|
||||
c.CurrentProject.Color1 = hmnProject.Color1
|
||||
c.CurrentProject.Color2 = hmnProject.Color2
|
||||
|
||||
c.UrlContext = hmndata.UrlContextForProject(c.CurrentProject)
|
||||
}
|
||||
c.CurrentUserCanEditCurrentProject = CanEditProject(c.CurrentUser, p.Owners)
|
||||
|
||||
c.Theme = "light"
|
||||
if c.CurrentUser != nil && c.CurrentUser.DarkTheme {
|
||||
c.Theme = "dark"
|
||||
}
|
||||
|
||||
return true, ResponseData{}
|
||||
}
|
||||
|
||||
func AddCORSHeaders(c *RequestContext, res *ResponseData) {
|
||||
parsed, err := url.Parse(config.Config.BaseUrl)
|
||||
if err != nil {
|
||||
c.Logger.Error().Str("Config.BaseUrl", config.Config.BaseUrl).Msg("Config.BaseUrl cannot be parsed. Skipping CORS headers")
|
||||
return
|
||||
}
|
||||
origin := ""
|
||||
origins, found := c.Req.Header["Origin"]
|
||||
if found {
|
||||
origin = origins[0]
|
||||
}
|
||||
if strings.HasSuffix(origin, parsed.Host) {
|
||||
res.Header().Add("Access-Control-Allow-Origin", origin)
|
||||
res.Header().Add("Access-Control-Allow-Credentials", "true")
|
||||
res.Header().Add("Vary", "Origin")
|
||||
}
|
||||
}
|
||||
|
||||
// Given a session id, fetches user data from the database. Will return nil if
|
||||
// the user cannot be found, and will only return an error if it's serious.
|
||||
func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User, *models.Session, error) {
|
||||
session, err := auth.GetSession(c.Context(), c.Conn, sessionId)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrNoSession) {
|
||||
return nil, nil, nil
|
||||
} else {
|
||||
return nil, nil, oops.New(err, "failed to get current session")
|
||||
}
|
||||
}
|
||||
|
||||
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, nil, session.Username, hmndata.UsersQuery{
|
||||
AnyStatus: true,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
|
||||
return nil, nil, nil // user was deleted or something
|
||||
} else {
|
||||
return nil, nil, oops.New(err, "failed to get user for session")
|
||||
}
|
||||
}
|
||||
|
||||
return user, session, nil
|
||||
}
|
||||
|
||||
func TrackRequestPerf(c *RequestContext) (after func()) {
|
||||
c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
|
||||
c.ctx = context.WithValue(c.Context(), perf.PerfContextKey, c.Perf)
|
||||
|
||||
return func() {
|
||||
c.Perf.EndRequest()
|
||||
log := logging.Info()
|
||||
blockStack := make([]time.Time, 0)
|
||||
for i, block := range c.Perf.Blocks {
|
||||
for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
|
||||
blockStack = blockStack[:len(blockStack)-1]
|
||||
}
|
||||
log.Str(fmt.Sprintf("[%4.d] At %9.2fms", i, c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()))
|
||||
blockStack = append(blockStack, block.End)
|
||||
}
|
||||
log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
|
||||
// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
|
||||
}
|
||||
}
|
||||
|
||||
func LogContextErrors(c *RequestContext, errs ...error) {
|
||||
for _, err := range errs {
|
||||
c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request")
|
||||
}
|
||||
}
|
||||
|
||||
func LogContextErrorsFromResponse(c *RequestContext, res *ResponseData) {
|
||||
LogContextErrors(c, res.Errors...)
|
||||
}
|
||||
|
||||
func MiddlewarePanicCatcher(c *RequestContext, res *ResponseData) {
|
||||
if recovered := recover(); recovered != nil {
|
||||
maybeError, ok := recovered.(*error)
|
||||
var err error
|
||||
if ok {
|
||||
err = *maybeError
|
||||
} else {
|
||||
err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
|
||||
}
|
||||
*res = c.ErrorResponse(http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
|
||||
const NoticesCookieName = "hmn_notices"
|
||||
|
||||
func getNoticesFromCookie(c *RequestContext) []templates.Notice {
|
||||
cookie, err := c.Req.Cookie(NoticesCookieName)
|
||||
if err != nil {
|
||||
if !errors.Is(err, http.ErrNoCookie) {
|
||||
c.Logger.Warn().Err(err).Msg("failed to get notices cookie")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return deserializeNoticesFromCookie(cookie.Value)
|
||||
}
|
||||
|
||||
func storeNoticesInCookie(c *RequestContext, res *ResponseData) {
|
||||
serialized := serializeNoticesForCookie(c, res.FutureNotices)
|
||||
if serialized != "" {
|
||||
noticesCookie := http.Cookie{
|
||||
Name: NoticesCookieName,
|
||||
Value: serialized,
|
||||
Path: "/",
|
||||
Domain: config.Config.Auth.CookieDomain,
|
||||
Expires: time.Now().Add(time.Minute * 5),
|
||||
Secure: config.Config.Auth.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
res.SetCookie(¬icesCookie)
|
||||
} else if !(res.StatusCode >= 300 && res.StatusCode < 400) {
|
||||
// NOTE(asaf): Don't clear on redirect
|
||||
noticesCookie := http.Cookie{
|
||||
Name: NoticesCookieName,
|
||||
Path: "/",
|
||||
Domain: config.Config.Auth.CookieDomain,
|
||||
MaxAge: -1,
|
||||
}
|
||||
res.SetCookie(¬icesCookie)
|
||||
}
|
||||
}
|
||||
|
||||
func serializeNoticesForCookie(c *RequestContext, notices []templates.Notice) string {
|
||||
var builder strings.Builder
|
||||
maxSize := 1024 // NOTE(asaf): Make sure we don't use too much space for notices.
|
||||
size := 0
|
||||
for i, notice := range notices {
|
||||
sizeIncrease := len(notice.Class) + len(string(notice.Content)) + 1
|
||||
if i != 0 {
|
||||
sizeIncrease += 1
|
||||
}
|
||||
if size+sizeIncrease > maxSize {
|
||||
c.Logger.Warn().Interface("Notices", notices).Msg("Notices too big for cookie")
|
||||
break
|
||||
if !c.CurrentProject.Personal {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
builder.WriteString("\t")
|
||||
if c.PathParams["projectslug"] != models.GeneratePersonalProjectSlug(c.CurrentProject.Name) {
|
||||
return c.Redirect(c.UrlContext.RewriteProjectUrl(c.URL()), http.StatusSeeOther)
|
||||
}
|
||||
builder.WriteString(notice.Class)
|
||||
builder.WriteString("|")
|
||||
builder.WriteString(string(notice.Content))
|
||||
|
||||
size += sizeIncrease
|
||||
return h(c)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func deserializeNoticesFromCookie(cookieVal string) []templates.Notice {
|
||||
var result []templates.Notice
|
||||
notices := strings.Split(cookieVal, "\t")
|
||||
for _, notice := range notices {
|
||||
parts := strings.SplitN(notice, "|", 2)
|
||||
if len(parts) == 2 {
|
||||
result = append(result, templates.Notice{
|
||||
Class: parts[0],
|
||||
Content: template.HTML(parts[1]),
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ func TestLogContextErrors(t *testing.T) {
|
|||
Middleware: func(h Handler) Handler {
|
||||
return func(c *RequestContext) (res ResponseData) {
|
||||
c.Logger = &logger
|
||||
defer LogContextErrorsFromResponse(c, &res)
|
||||
defer logContextErrorsMiddleware(c, &res)
|
||||
return h(c)
|
||||
}
|
||||
},
|
||||
|
|
|
@ -16,7 +16,7 @@ type ShowcaseData struct {
|
|||
}
|
||||
|
||||
func Showcase(c *RequestContext) ResponseData {
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{})
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{})
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippets"))
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ func Snippet(c *RequestContext) ResponseData {
|
|||
return FourOhFour(c)
|
||||
}
|
||||
|
||||
s, err := hmndata.FetchSnippet(c.Context(), c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{})
|
||||
s, err := hmndata.FetchSnippet(c, c.Conn, c.CurrentUser, snippetId, hmndata.SnippetQuery{})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
return FourOhFour(c)
|
||||
|
|
|
@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
func TwitchDebugPage(c *RequestContext) ResponseData {
|
||||
streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn,
|
||||
streams, err := db.Query[models.TwitchStream](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
|
|
@ -52,7 +52,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
|||
if c.CurrentUser != nil && strings.ToLower(c.CurrentUser.Username) == username {
|
||||
profileUser = c.CurrentUser
|
||||
} else {
|
||||
user, err := hmndata.FetchUserByUsername(c.Context(), c.Conn, c.CurrentUser, username, hmndata.UsersQuery{})
|
||||
user, err := hmndata.FetchUserByUsername(c, c.Conn, c.CurrentUser, username, hmndata.UsersQuery{})
|
||||
if err != nil {
|
||||
if errors.Is(err, db.NotFound) {
|
||||
return FourOhFour(c)
|
||||
|
@ -72,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch user links")
|
||||
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
userLinks, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -92,7 +92,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
|||
}
|
||||
c.Perf.EndBlock()
|
||||
|
||||
projectsAndStuff, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
projectsAndStuff, err := hmndata.FetchProjects(c, c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
OwnerIDs: []int{profileUser.ID},
|
||||
Lifecycles: models.AllProjectLifecycles,
|
||||
IncludeHidden: true,
|
||||
|
@ -111,13 +111,13 @@ func UserProfile(c *RequestContext) ResponseData {
|
|||
c.Perf.EndBlock()
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch posts")
|
||||
posts, err := hmndata.FetchPosts(c.Context(), c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
posts, err := hmndata.FetchPosts(c, c.Conn, c.CurrentUser, hmndata.PostsQuery{
|
||||
UserIDs: []int{profileUser.ID},
|
||||
SortDescending: true,
|
||||
})
|
||||
c.Perf.EndBlock()
|
||||
|
||||
snippets, err := hmndata.FetchSnippets(c.Context(), c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
snippets, err := hmndata.FetchSnippets(c, c.Conn, c.CurrentUser, hmndata.SnippetQuery{
|
||||
OwnerIDs: []int{profileUser.ID},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -125,7 +125,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
c.Perf.StartBlock("SQL", "Fetch subforum tree")
|
||||
subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn)
|
||||
subforumTree := models.GetFullSubforumTree(c, c.Conn)
|
||||
lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree)
|
||||
c.Perf.EndBlock()
|
||||
|
||||
|
@ -213,7 +213,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
|||
DiscordShowcaseBacklogUrl string
|
||||
}
|
||||
|
||||
links, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
links, err := db.Query[models.Link](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM link
|
||||
|
@ -230,7 +230,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
|||
|
||||
var tduser *templates.DiscordUser
|
||||
var numUnsavedMessages int
|
||||
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||
duser, err := db.QueryOne[models.DiscordUser](c, c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM discord_user
|
||||
|
@ -246,7 +246,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
|||
tmp := templates.DiscordUserToTemplate(duser)
|
||||
tduser = &tmp
|
||||
|
||||
numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||
numUnsavedMessages, err = db.QueryOneScalar[int](c, c.Conn,
|
||||
`
|
||||
SELECT COUNT(*)
|
||||
FROM
|
||||
|
@ -299,11 +299,11 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse form"))
|
||||
}
|
||||
|
||||
tx, err := c.Conn.Begin(c.Context())
|
||||
tx, err := c.Conn.Begin(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer tx.Rollback(c.Context())
|
||||
defer tx.Rollback(c)
|
||||
|
||||
form, err := c.GetFormValues()
|
||||
if err != nil {
|
||||
|
@ -315,7 +315,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
|
||||
email := form.Get("email")
|
||||
if !hmnemail.IsEmail(email) {
|
||||
return RejectRequest(c, "Your email was not valid.")
|
||||
return c.RejectRequest("Your email was not valid.")
|
||||
}
|
||||
|
||||
showEmail := form.Get("showemail") != ""
|
||||
|
@ -328,7 +328,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
discordShowcaseAuto := form.Get("discord-showcase-auto") != ""
|
||||
discordDeleteSnippetOnMessageDelete := form.Get("discord-snippet-keep") == ""
|
||||
|
||||
_, err = tx.Exec(c.Context(),
|
||||
_, err = tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET
|
||||
|
@ -360,15 +360,15 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
}
|
||||
|
||||
// Process links
|
||||
twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil)
|
||||
twitchLoginsPreChange, preErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
|
||||
linksText := form.Get("links")
|
||||
links := ParseLinks(linksText)
|
||||
_, err = tx.Exec(c.Context(), `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID)
|
||||
_, err = tx.Exec(c, `DELETE FROM link WHERE user_id = $1`, c.CurrentUser.ID)
|
||||
if err != nil {
|
||||
c.Logger.Warn().Err(err).Msg("failed to delete old links")
|
||||
} else {
|
||||
for i, link := range links {
|
||||
_, err := tx.Exec(c.Context(),
|
||||
_, err := tx.Exec(c,
|
||||
`
|
||||
INSERT INTO link (name, url, ordering, user_id)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
|
@ -384,7 +384,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
}
|
||||
twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c.Context(), tx, &c.CurrentUser.ID, nil)
|
||||
twitchLoginsPostChange, postErr := hmndata.FetchTwitchLoginsForUserOrProject(c, tx, &c.CurrentUser.ID, nil)
|
||||
if preErr == nil && postErr == nil {
|
||||
twitch.UserOrProjectLinksUpdated(twitchLoginsPreChange, twitchLoginsPostChange)
|
||||
}
|
||||
|
@ -407,7 +407,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
}
|
||||
var avatarUUID *uuid.UUID
|
||||
if newAvatar.Exists {
|
||||
avatarAsset, err := assets.Create(c.Context(), tx, assets.CreateInput{
|
||||
avatarAsset, err := assets.Create(c, tx, assets.CreateInput{
|
||||
Content: newAvatar.Content,
|
||||
Filename: newAvatar.Filename,
|
||||
ContentType: newAvatar.Mime,
|
||||
|
@ -421,7 +421,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
avatarUUID = &avatarAsset.ID
|
||||
}
|
||||
if newAvatar.Exists || newAvatar.Remove {
|
||||
_, err := tx.Exec(c.Context(),
|
||||
_, err := tx.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET
|
||||
|
@ -437,7 +437,7 @@ func UserSettingsSave(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
|
||||
err = tx.Commit(c.Context())
|
||||
err = tx.Commit(c)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save user settings"))
|
||||
}
|
||||
|
@ -454,7 +454,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
|
|||
userIdStr := c.Req.Form.Get("user_id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
return RejectRequest(c, "No user id provided")
|
||||
return c.RejectRequest("No user id provided")
|
||||
}
|
||||
|
||||
status := c.Req.Form.Get("status")
|
||||
|
@ -469,10 +469,10 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
|
|||
case "banned":
|
||||
desiredStatus = models.UserStatusBanned
|
||||
default:
|
||||
return RejectRequest(c, "No legal user status provided")
|
||||
return c.RejectRequest("No legal user status provided")
|
||||
}
|
||||
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
_, err = c.Conn.Exec(c,
|
||||
`
|
||||
UPDATE hmn_user
|
||||
SET status = $1
|
||||
|
@ -485,7 +485,7 @@ func UserProfileAdminSetStatus(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update user status"))
|
||||
}
|
||||
if desiredStatus == models.UserStatusBanned {
|
||||
err = auth.DeleteSessionForUser(c.Context(), c.Conn, c.Req.Form.Get("username"))
|
||||
err = auth.DeleteSessionForUser(c, c.Conn, c.Req.Form.Get("username"))
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to log out user"))
|
||||
}
|
||||
|
@ -500,10 +500,10 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
|
|||
userIdStr := c.Req.Form.Get("user_id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
return RejectRequest(c, "No user id provided")
|
||||
return c.RejectRequest("No user id provided")
|
||||
}
|
||||
|
||||
err = deleteAllPostsForUser(c.Context(), c.Conn, userId)
|
||||
err = deleteAllPostsForUser(c, c.Conn, userId)
|
||||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to delete user posts"))
|
||||
}
|
||||
|
@ -514,7 +514,7 @@ func UserProfileAdminNuke(c *RequestContext) ResponseData {
|
|||
|
||||
func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *ResponseData {
|
||||
if new != confirm {
|
||||
res := RejectRequest(c, "Your password and password confirmation did not match.")
|
||||
res := c.RejectRequest("Your password and password confirmation did not match.")
|
||||
return &res
|
||||
}
|
||||
|
||||
|
@ -531,12 +531,12 @@ func updatePassword(c *RequestContext, tx pgx.Tx, old, new, confirm string) *Res
|
|||
}
|
||||
|
||||
if !ok {
|
||||
res := RejectRequest(c, "The old password you provided was not correct.")
|
||||
res := c.RejectRequest("The old password you provided was not correct.")
|
||||
return &res
|
||||
}
|
||||
|
||||
newHashedPassword := auth.HashPassword(new)
|
||||
err = auth.UpdatePassword(c.Context(), tx, c.CurrentUser.Username, newHashedPassword)
|
||||
err = auth.UpdatePassword(c, tx, c.CurrentUser.Username, newHashedPassword)
|
||||
if err != nil {
|
||||
res := c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to update password"))
|
||||
return &res
|
||||
|
|
|
@ -33,14 +33,13 @@ var WebsiteCommand = &cobra.Command{
|
|||
logging.Info().Msg("Hello, HMN!")
|
||||
|
||||
backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background())
|
||||
longRequestContext, cancelLongRequests := context.WithCancel(context.Background())
|
||||
|
||||
conn := db.NewConnPool()
|
||||
perfCollector := perf.RunPerfCollector(backgroundJobContext)
|
||||
|
||||
server := http.Server{
|
||||
Addr: config.Config.Addr,
|
||||
Handler: NewWebsiteRoutes(longRequestContext, conn),
|
||||
Handler: NewWebsiteRoutes(conn),
|
||||
}
|
||||
|
||||
backgroundJobsDone := jobs.Zip(
|
||||
|
@ -59,8 +58,6 @@ var WebsiteCommand = &cobra.Command{
|
|||
<-signals
|
||||
logging.Info().Msg("Shutting down the website")
|
||||
go func() {
|
||||
logging.Info().Msg("cancelling long requests")
|
||||
cancelLongRequests()
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
logging.Info().Msg("shutting down web server")
|
||||
|
|
Loading…
Reference in New Issue