diff --git a/src/admintools/admintools.go b/src/admintools/admintools.go index 63a3993..6f76b28 100644 --- a/src/admintools/admintools.go +++ b/src/admintools/admintools.go @@ -86,7 +86,7 @@ func init() { conn := db.NewConnPool(1, 1) defer conn.Close() - res, err := conn.Exec(ctx, "UPDATE auth_user SET status = $1 WHERE LOWER(username) = LOWER($2);", models.UserStatusActive, username) + res, err := conn.Exec(ctx, "UPDATE auth_user SET status = $1 WHERE LOWER(username) = LOWER($2);", models.UserStatusConfirmed, username) if err != nil { panic(err) } diff --git a/src/auth/session.go b/src/auth/session.go index 0e7e537..0ac2d05 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -47,7 +47,7 @@ var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return nil, ErrNoSession } else { return nil, oops.New(err, "failed to get session") diff --git a/src/db/db.go b/src/db/db.go index cf76995..69defd5 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -269,9 +269,8 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix strin return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) } - for i := 0; i < destType.NumField(); i++ { - field := destType.Field(i) - path := append(pathSoFar, i) + for _, field := range reflect.VisibleFields(destType) { + path := append(pathSoFar, field.Index...) if columnName := field.Tag.Get("db"); columnName != "" { fieldType := field.Type @@ -298,7 +297,12 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix strin return columnNames, fieldPaths, nil } -var ErrNoMatchingRows = errors.New("no matching rows") +/* +A general error to be used when no results are found. This is the error returned +by QueryOne, and can generally be used by other database helpers that fetch a single +result but find nothing. +*/ +var NotFound = errors.New("not found") func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { rows, err := Query(ctx, conn, destExample, query, args...) @@ -309,7 +313,7 @@ func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query result, hasRow := rows.Next() if !hasRow { - return nil, ErrNoMatchingRows + return nil, NotFound } return result, nil @@ -335,7 +339,7 @@ func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...inter return vals[0], nil } - return nil, ErrNoMatchingRows + return nil, NotFound } func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { diff --git a/src/db/db_test.go b/src/db/db_test.go index a969a47..41b9491 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -26,8 +26,12 @@ func TestPaths(t *testing.T) { NoTag S } + type Embedded struct { + NoTag S + Nested + } - names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Nested{}), nil, "") + names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") if assert.Nil(t, err) { assert.Equal(t, []string{ "S.I", "S.PI", @@ -38,16 +42,39 @@ func TestPaths(t *testing.T) { "PS.B", "PS.PB", }, names) assert.Equal(t, [][]int{ - {0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5}, - {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 5}, + {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, + {1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, }, paths) assert.True(t, len(names) == len(paths)) } - testStruct := Nested{} + testStruct := Embedded{} for i, path := range paths { val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) assert.True(t, val.IsValid()) assert.True(t, strings.Contains(names[i], field.Name)) } } + +func TestQueryBuilder(t *testing.T) { + t.Run("happy time", func(t *testing.T) { + var qb QueryBuilder + qb.Add("SELECT stuff FROM thing WHERE foo = $? AND bar = $?", 3, "hello") + qb.Add("AND (baz = $?)", true) + + assert.Equal(t, "SELECT stuff FROM thing WHERE foo = $1 AND bar = $2\nAND (baz = $3)\n", qb.String()) + assert.Equal(t, []interface{}{3, "hello", true}, qb.Args()) + }) + t.Run("too few arguments", func(t *testing.T) { + var qb QueryBuilder + assert.Panics(t, func() { + qb.Add("HELLO $? $? $?", 1, 2) + }) + }) + t.Run("too many arguments", func(t *testing.T) { + var qb QueryBuilder + assert.Panics(t, func() { + qb.Add("HELLO $? $? $?", 1, 2, 3, 4) + }) + }) +} diff --git a/src/db/query_builder.go b/src/db/query_builder.go new file mode 100644 index 0000000..1f794c9 --- /dev/null +++ b/src/db/query_builder.go @@ -0,0 +1,42 @@ +package db + +import ( + "fmt" + "strings" +) + +type QueryBuilder struct { + sql strings.Builder + args []interface{} +} + +/* +Adds the given SQL and arguments to the query. Any occurrences +of `$?` will be replaced with the correct argument number. + +foo $? bar $? baz $? +foo ARG1 bar ARG2 baz $? +foo ARG1 bar ARG2 baz ARG3 +*/ +func (qb *QueryBuilder) Add(sql string, args ...interface{}) { + numPlaceholders := strings.Count(sql, "$?") + if numPlaceholders != len(args) { + panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args))) + } + + for _, arg := range args { + sql = strings.Replace(sql, "$?", fmt.Sprintf("$%d", len(qb.args)+1), 1) + qb.args = append(qb.args, arg) + } + + qb.sql.WriteString(sql) + qb.sql.WriteString("\n") +} + +func (qb *QueryBuilder) String() string { + return qb.sql.String() +} + +func (qb *QueryBuilder) Args() []interface{} { + return qb.args +} diff --git a/src/discord/gateway.go b/src/discord/gateway.go index f48faa4..1e63d22 100644 --- a/src/discord/gateway.go +++ b/src/discord/gateway.go @@ -252,7 +252,7 @@ func (bot *botInstance) connect(ctx context.Context) error { shouldResume := true isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // No session yet! Just identify and get on with it shouldResume = false } else { @@ -630,7 +630,7 @@ func (bot *botInstance) messageDelete(ctx context.Context, msgDelete MessageDele `, msgDelete.ID, msgDelete.ChannelID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return } else if err != nil { log.Error().Err(err).Msg("failed to check for message to delete") diff --git a/src/discord/showcase.go b/src/discord/showcase.go index fa7b267..f2c26c4 100644 --- a/src/discord/showcase.go +++ b/src/discord/showcase.go @@ -133,7 +133,7 @@ func SaveMessage( `, msg.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { if !msg.OriginalHasFields("author", "timestamp") { return nil, errNotEnoughInfo } @@ -213,7 +213,7 @@ func SaveMessageAndContents( `, newMsg.UserID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return newMsg, nil } else if err != nil { return nil, oops.New(err, "failed to look up linked Discord user") @@ -337,7 +337,7 @@ func saveAttachment( ) if err == nil { return iexisting.(*models.DiscordMessageAttachment), nil - } else if errors.Is(err, db.ErrNoMatchingRows) { + } else if errors.Is(err, db.NotFound) { // this is fine, just create it } else { return nil, oops.New(err, "failed to check for existing attachment") @@ -522,7 +522,7 @@ func AllowedToCreateMessageSnippet(ctx context.Context, tx db.ConnOrTx, discordU `, discordUserId, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return false, nil } else if err != nil { return false, oops.New(err, "failed to check if we can save Discord message") diff --git a/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go b/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go new file mode 100644 index 0000000..bf6a624 --- /dev/null +++ b/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go @@ -0,0 +1,89 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "git.handmade.network/hmn/hmn/src/oops" + "github.com/jackc/pgx/v4" +) + +func init() { + registerMigration(ApproveActiveUsers{}) +} + +type ApproveActiveUsers struct{} + +func (m ApproveActiveUsers) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2021, 9, 23, 4, 15, 21, 0, time.UTC)) +} + +func (m ApproveActiveUsers) Name() string { + return "ApproveActiveUsers" +} + +func (m ApproveActiveUsers) Description() string { + return "Give legit users the Approved status" +} + +func (m ApproveActiveUsers) Up(ctx context.Context, tx pgx.Tx) error { + /* + See models/user.go. + The old statuses were: + 2 = Active + 3 = Banned + The new statuses are: + 2 = Confirmed (valid email) + 3 = Approved (allowed to post) + 4 = Banned + */ + + _, err := tx.Exec(ctx, ` + UPDATE auth_user + SET status = 4 + WHERE status = 3 + `) + if err != nil { + return oops.New(err, "failed to update status of banned users") + } + + _, err = tx.Exec(ctx, ` + UPDATE auth_user + SET status = 3 + WHERE + status = 2 + AND id IN ( + SELECT author_id + FROM handmade_post + WHERE author_id IS NOT NULL + ) + `) + if err != nil { + return oops.New(err, "failed to update user statuses") + } + + return nil +} + +func (m ApproveActiveUsers) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + UPDATE auth_user + SET status = 2 + WHERE status = 3 + `) + if err != nil { + return oops.New(err, "failed to revert approved users back to confirmed") + } + + _, err = tx.Exec(ctx, ` + UPDATE auth_user + SET status = 3 + WHERE status = 4 + `) + if err != nil { + return oops.New(err, "failed to update status of banned users") + } + + return nil +} diff --git a/src/models/post.go b/src/models/post.go index 0799f78..cb4b533 100644 --- a/src/models/post.go +++ b/src/models/post.go @@ -12,6 +12,7 @@ type Post struct { ThreadID int `db:"thread_id"` CurrentID int `db:"current_id"` // The id of the current PostVersion ProjectID int `db:"project_id"` + ReplyID *int `db:"reply_id"` ThreadType ThreadType `db:"thread_type"` diff --git a/src/models/user.go b/src/models/user.go index 58c3a03..4b0699a 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -10,9 +10,10 @@ var UserType = reflect.TypeOf(User{}) type UserStatus int const ( - UserStatusInactive UserStatus = iota + 1 - UserStatusActive - UserStatusBanned + UserStatusInactive UserStatus = 1 // Default for new users + UserStatusConfirmed = 2 // Confirmed email address + UserStatusApproved = 3 // Approved by an admin and allowed to publicly post + UserStatusBanned = 4 // BALEETED ) type User struct { @@ -54,5 +55,5 @@ func (u *User) BestName() string { } func (u *User) IsActive() bool { - return u.Status == UserStatusActive + return u.Status == UserStatusConfirmed } diff --git a/src/perf/perf.go b/src/perf/perf.go index a7af05f..ee0f55c 100644 --- a/src/perf/perf.go +++ b/src/perf/perf.go @@ -27,12 +27,20 @@ func MakeNewRequestPerf(route string, method string, path string) *RequestPerf { } func (rp *RequestPerf) EndRequest() { + if rp == nil { + return + } + for rp.EndBlock() { } rp.End = time.Now() } func (rp *RequestPerf) Checkpoint(category, description string) { + if rp == nil { + return + } + now := time.Now() checkpoint := PerfBlock{ Start: now, @@ -44,6 +52,10 @@ func (rp *RequestPerf) Checkpoint(category, description string) { } func (rp *RequestPerf) StartBlock(category, description string) { + if rp == nil { + return + } + now := time.Now() checkpoint := PerfBlock{ Start: now, @@ -55,6 +67,10 @@ func (rp *RequestPerf) StartBlock(category, description string) { } func (rp *RequestPerf) EndBlock() bool { + if rp == nil { + return false + } + for i := len(rp.Blocks) - 1; i >= 0; i -= 1 { if rp.Blocks[i].End.Equal(time.Time{}) { rp.Blocks[i].End = time.Now() @@ -65,6 +81,10 @@ func (rp *RequestPerf) EndBlock() bool { } func (rp *RequestPerf) MsFromStart(block *PerfBlock) float64 { + if rp == nil { + return 0 + } + return float64(block.Start.Sub(rp.Start).Nanoseconds()) / 1000 / 1000 } diff --git a/src/website/auth.go b/src/website/auth.go index 57a6c30..5b208fd 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -77,7 +77,7 @@ func Login(c *RequestContext) ResponseData { userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE LOWER(username) = LOWER($1)", username) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return showLoginWithFailure(c, redirect) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) @@ -172,7 +172,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { username, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { userAlreadyExists = false } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) @@ -193,7 +193,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { emailAddress, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { emailAlreadyExists = false } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) @@ -372,7 +372,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { SET status = $1 WHERE id = $2 `, - models.UserStatusActive, + models.UserStatusConfirmed, validationResult.User.ID, ) if err != nil { @@ -459,7 +459,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } @@ -482,7 +482,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) } } @@ -644,7 +644,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { SET status = $1 WHERE id = $2 `, - models.UserStatusActive, + models.UserStatusConfirmed, validationResult.User.ID, ) if err != nil { @@ -786,7 +786,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, ) var result validateUserAndTokenResult if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { result.Error = oops.New(err, "failed to fetch user and token from db") return result } diff --git a/src/website/blogs.go b/src/website/blogs.go index 0f01fa7..2ff59ad 100644 --- a/src/website/blogs.go +++ b/src/website/blogs.go @@ -1,6 +1,7 @@ package website import ( + "errors" "fmt" "html/template" "net/http" @@ -34,21 +35,10 @@ func BlogIndex(c *RequestContext) ResponseData { const postsPerPage = 5 - c.Perf.StartBlock("SQL", "Fetch count of posts") - numPosts, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM - handmade_thread - WHERE - project_id = $1 - AND type = $2 - AND NOT deleted - `, - c.CurrentProject.ID, - models.ThreadTypeProjectBlogPost, - ) - c.Perf.EndBlock() + numPosts, err := CountPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch total number of blog posts")) } @@ -59,48 +49,24 @@ func BlogIndex(c *RequestContext) ResponseData { c.Redirect(hmnurl.BuildBlog(c.CurrentProject.Slug, page), http.StatusSeeOther) } - type blogIndexQuery struct { - Thread models.Thread `db:"thread"` - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - } - c.Perf.StartBlock("SQL", "Fetch blog posts") - postsResult, err := db.Query(c.Context(), c.Conn, blogIndexQuery{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS post ON thread.first_id = post.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - WHERE - post.project_id = $1 - AND post.thread_type = $2 - AND NOT thread.deleted - ORDER BY post.postdate DESC - LIMIT $3 OFFSET $4 - `, - c.CurrentProject.ID, - models.ThreadTypeProjectBlogPost, - postsPerPage, - (page-1)*postsPerPage, - ) - c.Perf.EndBlock() + threads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + Limit: postsPerPage, + Offset: (page - 1) * postsPerPage, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch blog posts for index")) } var entries []blogIndexEntry - for _, irow := range postsResult.ToSlice() { - row := irow.(*blogIndexQuery) - + for _, thread := range threads { entries = append(entries, blogIndexEntry{ - Title: row.Thread.Title, - Url: hmnurl.BuildBlogThread(c.CurrentProject.Slug, row.Thread.ID, row.Thread.Title), - Author: templates.UserToTemplate(row.Author, c.Theme), - Date: row.Post.PostDate, - Content: template.HTML(row.CurrentVersion.TextParsed), + Title: thread.Thread.Title, + Url: hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.Thread.ID, thread.Thread.Title), + Author: templates.UserToTemplate(thread.FirstPostAuthor, c.Theme), + Date: thread.FirstPost.PostDate, + Content: template.HTML(thread.FirstPostCurrentVersion.TextParsed), }) } @@ -158,12 +124,15 @@ func BlogThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread, posts, preview := FetchThreadPostsAndStuff( - c.Context(), - c.Conn, - cd.ThreadID, - 0, 0, - ) + thread, posts, err := FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for blog thread")) + } var templatePosts []templates.Post for _, p := range posts { @@ -202,7 +171,7 @@ func BlogThread(c *RequestContext) ResponseData { baseData := getBaseData(c, thread.Title, []templates.Breadcrumb{BlogBreadcrumb(c.CurrentProject.Slug)}) baseData.OpenGraphItems = append(baseData.OpenGraphItems, templates.OpenGraphItem{ Property: "og:description", - Value: preview, + Value: posts[0].Post.Preview, }) var res ResponseData @@ -223,9 +192,17 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread for blog redirect")) + } - threadUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, thread.Title, cd.PostID) + threadUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, thread.Thread.Title, cd.PostID) return c.Redirect(threadUrl, http.StatusFound) } @@ -305,24 +282,32 @@ func BlogPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to edit")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Editing \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Editing \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } else { title = fmt.Sprintf("Editing Post | %s", c.CurrentProject.Name) } baseData := getBaseData( c, title, - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - editData := getEditorDataForEdit(c.CurrentUser, baseData, postData) + editData := getEditorDataForEdit(c.CurrentUser, baseData, post) editData.SubmitUrl = hmnurl.BuildBlogPostEdit(c.CurrentProject.Slug, cd.ThreadID, cd.PostID) editData.SubmitLabel = "Submit Edited Post" - if postData.Thread.FirstID != postData.Post.ID { + if post.Thread.FirstID != post.Post.ID { editData.SubmitLabel = "Submit Edited Comment" } @@ -347,27 +332,35 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData { } defer tx.Rollback(c.Context()) - postData := FetchPostAndStuff(c.Context(), tx, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to submit edits")) + } c.Req.ParseForm() title := c.Req.Form.Get("title") unparsed := c.Req.Form.Get("body") editReason := c.Req.Form.Get("editreason") - if title != "" && postData.Thread.FirstID != postData.Post.ID { + if title != "" && post.Thread.FirstID != post.Post.ID { return RejectRequest(c, "You can only edit the title by editing the first post.") } if unparsed == "" { return RejectRequest(c, "You must provide a post body.") } - CreatePostVersion(c.Context(), tx, postData.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) + CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) err = tx.Commit(c.Context()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post")) } - postUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, postData.Thread.Title, cd.PostID) + postUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, post.Thread.Title, cd.PostID) return c.Redirect(postUrl, http.StatusSeeOther) } @@ -377,16 +370,24 @@ func BlogPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post for reply")) + } baseData := getBaseData( c, - fmt.Sprintf("Replying to comment in \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name), - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + fmt.Sprintf("Replying to comment in \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - replyPost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - replyPost.AddContentVersion(postData.CurrentVersion, postData.Editor) + replyPost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + replyPost.AddContentVersion(post.CurrentVersion, post.Editor) editData := getEditorDataForNew(c.CurrentUser, baseData, &replyPost) editData.SubmitUrl = hmnurl.BuildBlogPostReply(c.CurrentProject.Slug, cd.ThreadID, cd.PostID) @@ -439,23 +440,30 @@ func BlogPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to delete")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Deleting \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Deleting \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } else { - title = fmt.Sprintf("Deleting comment in \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + title = fmt.Sprintf("Deleting comment in \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } baseData := getBaseData( c, title, - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - // TODO(ben): Set breadcrumbs - templatePost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - templatePost.AddContentVersion(postData.CurrentVersion, postData.Editor) + templatePost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + templatePost.AddContentVersion(post.CurrentVersion, post.Editor) type blogPostDeleteData struct { templates.BaseData @@ -499,8 +507,16 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData { projectUrl := hmnurl.BuildProjectHomepage(c.CurrentProject.Slug) return c.Redirect(projectUrl, http.StatusSeeOther) } else { - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) - threadUrl := hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.ID, thread.Title) + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + panic(oops.New(err, "the thread was supposedly not deleted after deleting a post in a blog, but the thread was not found afterwards")) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread after blog post delete")) + } + threadUrl := hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.Thread.ID, thread.Thread.Title) return c.Redirect(threadUrl, http.StatusSeeOther) } } diff --git a/src/website/discord.go b/src/website/discord.go index 9a75fba..d5439ee 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -97,7 +97,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { c.CurrentUser.ID, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink")) @@ -134,7 +134,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { `SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`, c.CurrentUser.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // Nothing to do c.Logger.Warn().Msg("could not do showcase backlog because no discord user exists") return c.Redirect(hmnurl.BuildUserProfile(c.CurrentUser.Username), http.StatusSeeOther) diff --git a/src/website/forums.go b/src/website/forums.go index ed3f9ec..35debc8 100644 --- a/src/website/forums.go +++ b/src/website/forums.go @@ -70,7 +70,7 @@ func getEditorDataForNew(currentUser *models.User, baseData templates.BaseData, return result } -func getEditorDataForEdit(currentUser *models.User, baseData templates.BaseData, p postAndRelatedModels) editorData { +func getEditorDataForEdit(currentUser *models.User, baseData templates.BaseData, p PostAndStuff) editorData { return editorData{ BaseData: baseData, Title: p.Thread.Title, @@ -92,21 +92,14 @@ func Forum(c *RequestContext) ResponseData { currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) - c.Perf.StartBlock("SQL", "Fetch count of page threads") - numThreads, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_thread AS thread - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - `, - cd.SubforumID, - ) + numThreads, err := CountThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{cd.SubforumID}, + }) if err != nil { panic(oops.New(err, "failed to get count of threads")) } - c.Perf.EndBlock() numPages := utils.NumPages(numThreads, threadsPerPage) page, ok := ParsePageNumber(c, "page", numPages) @@ -115,78 +108,28 @@ func Forum(c *RequestContext) ResponseData { } howManyThreadsToSkip := (page - 1) * threadsPerPage - var currentUserId *int - if c.CurrentUser != nil { - currentUserId = &c.CurrentUser.ID - } - - c.Perf.StartBlock("SQL", "Fetch page threads") - type threadQueryResult struct { - Thread models.Thread `db:"thread"` - FirstPost models.Post `db:"firstpost"` - LastPost models.Post `db:"lastpost"` - FirstUser *models.User `db:"firstuser"` - LastUser *models.User `db:"lastuser"` - ThreadLastReadTime *time.Time `db:"tlri.lastread"` - ForumLastReadTime *time.Time `db:"slri.lastread"` - } - itMainThreads, err := db.Query(c.Context(), c.Conn, threadQueryResult{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS firstpost ON thread.first_id = firstpost.id - JOIN handmade_post AS lastpost ON thread.last_id = lastpost.id - LEFT JOIN auth_user AS firstuser ON firstpost.author_id = firstuser.id - LEFT JOIN auth_user AS lastuser ON lastpost.author_id = lastuser.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = thread.id - AND tlri.user_id = $2 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = $1 - AND slri.user_id = $2 - ) - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - ORDER BY lastpost.postdate DESC - LIMIT $3 OFFSET $4 - `, - cd.SubforumID, - currentUserId, - threadsPerPage, - howManyThreadsToSkip, - ) - if err != nil { - panic(oops.New(err, "failed to fetch threads")) - } - c.Perf.EndBlock() - defer itMainThreads.Close() - - makeThreadListItem := func(row *threadQueryResult) templates.ThreadListItem { - hasRead := false - if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.LastPost.PostDate) { - hasRead = true - } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.LastPost.PostDate) { - hasRead = true - } + mainThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{cd.SubforumID}, + Limit: threadsPerPage, + Offset: howManyThreadsToSkip, + }) + makeThreadListItem := func(row ThreadAndStuff) templates.ThreadListItem { return templates.ThreadListItem{ Title: row.Thread.Title, Url: hmnurl.BuildForumThread(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(*row.Thread.SubforumID), row.Thread.ID, row.Thread.Title, 1), - FirstUser: templates.UserToTemplate(row.FirstUser, c.Theme), + FirstUser: templates.UserToTemplate(row.FirstPostAuthor, c.Theme), FirstDate: row.FirstPost.PostDate, - LastUser: templates.UserToTemplate(row.LastUser, c.Theme), + LastUser: templates.UserToTemplate(row.LastPostAuthor, c.Theme), LastDate: row.LastPost.PostDate, - - Unread: !hasRead, + Unread: row.Unread, } } var threads []templates.ThreadListItem - for _, irow := range itMainThreads.ToSlice() { - row := irow.(*threadQueryResult) + for _, row := range mainThreads { threads = append(threads, makeThreadListItem(row)) } @@ -199,59 +142,25 @@ func Forum(c *RequestContext) ResponseData { subforumNodes := cd.SubforumTree[cd.SubforumID].Children for _, sfNode := range subforumNodes { - c.Perf.StartBlock("SQL", "Fetch count of subforum threads") - numThreads, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_thread AS thread - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - `, - sfNode.ID, - ) + numThreads, err := CountThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{sfNode.ID}, + }) if err != nil { panic(oops.New(err, "failed to get count of threads")) } - c.Perf.EndBlock() - c.Perf.StartBlock("SQL", "Fetch subforum threads") - itThreads, err := db.Query(c.Context(), c.Conn, threadQueryResult{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS firstpost ON thread.first_id = firstpost.id - JOIN handmade_post AS lastpost ON thread.last_id = lastpost.id - LEFT JOIN auth_user AS firstuser ON firstpost.author_id = firstuser.id - LEFT JOIN auth_user AS lastuser ON lastpost.author_id = lastuser.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = thread.id - AND tlri.user_id = $2 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = $1 - AND slri.user_id = $2 - ) - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - ORDER BY lastpost.postdate DESC - LIMIT 3 - `, - sfNode.ID, - currentUserId, - ) - if err != nil { - panic(err) - } - defer itThreads.Close() - c.Perf.EndBlock() + subforumThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{sfNode.ID}, + Limit: 3, + }) var threads []templates.ThreadListItem - for _, irow := range itThreads.ToSlice() { - threadRow := irow.(*threadQueryResult) - threads = append(threads, makeThreadListItem(threadRow)) + for _, row := range subforumThreads { + threads = append(threads, makeThreadListItem(row)) } subforums = append(subforums, forumSubforumData{ @@ -415,7 +324,8 @@ type forumThreadData struct { Pagination templates.Pagination } -var threadViewPostsPerPage = 15 +// How many posts to display on a single page of a forum thread. +var threadPostsPerPage = 15 func ForumThread(c *RequestContext) ResponseData { cd, ok := getCommonForumData(c) @@ -425,22 +335,28 @@ func ForumThread(c *RequestContext) ResponseData { currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) + threads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadIDs: []int{cd.ThreadID}, + }) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get thread")) + } + if len(threads) == 0 { + return FourOhFour(c) + } + threadResult := threads[0] + thread := threadResult.Thread - numPosts, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_post - WHERE - thread_id = $1 - AND NOT deleted - `, - thread.ID, - ) + numPosts, err := CountPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + ThreadIDs: []int{cd.ThreadID}, + }) if err != nil { panic(oops.New(err, "failed to get count of posts for thread")) } - page, numPages, ok := getPageInfo(c.PathParams["page"], numPosts, threadViewPostsPerPage) + page, numPages, ok := getPageInfo(c.PathParams["page"], numPosts, threadPostsPerPage) if !ok { urlNoPage := hmnurl.BuildForumThread(c.CurrentProject.Slug, currentSubforumSlugs, thread.ID, thread.Title, 1) return c.Redirect(urlNoPage, http.StatusSeeOther) @@ -455,16 +371,16 @@ func ForumThread(c *RequestContext) ResponseData { PreviousUrl: hmnurl.BuildForumThread(c.CurrentProject.Slug, currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)), } - c.Perf.StartBlock("SQL", "Fetch posts") - _, postsAndStuff, preview := FetchThreadPostsAndStuff( - c.Context(), - c.Conn, - cd.ThreadID, - page, threadViewPostsPerPage, - ) - c.Perf.EndBlock() + postsAndStuff, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadIDs: []int{thread.ID}, + Limit: threadPostsPerPage, + Offset: (page - 1) * threadPostsPerPage, + }) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread posts")) + } - c.Perf.StartBlock("TEMPLATE", "Create template posts") var posts []templates.Post for _, p := range postsAndStuff { post := templates.PostToTemplate(&p.Post, p.Author, c.Theme) @@ -481,18 +397,17 @@ func ForumThread(c *RequestContext) ResponseData { posts = append(posts, post) } - c.Perf.EndBlock() // Update thread last read info if c.CurrentUser != nil { c.Perf.StartBlock("SQL", "Update TLRI") _, err = c.Conn.Exec(c.Context(), ` - INSERT INTO handmade_threadlastreadinfo (thread_id, user_id, lastread) - VALUES ($1, $2, $3) - ON CONFLICT (thread_id, user_id) DO UPDATE - SET lastread = EXCLUDED.lastread - `, + INSERT INTO handmade_threadlastreadinfo (thread_id, user_id, lastread) + VALUES ($1, $2, $3) + ON CONFLICT (thread_id, user_id) DO UPDATE + SET lastread = EXCLUDED.lastread + `, cd.ThreadID, c.CurrentUser.ID, time.Now(), @@ -506,7 +421,7 @@ func ForumThread(c *RequestContext) ResponseData { baseData := getBaseData(c, thread.Title, SubforumBreadcrumbs(cd.LineageBuilder, c.CurrentProject, cd.SubforumID)) baseData.OpenGraphItems = append(baseData.OpenGraphItems, templates.OpenGraphItem{ Property: "og:description", - Value: preview, + Value: threadResult.FirstPost.Preview, }) var res ResponseData @@ -527,30 +442,20 @@ func ForumPostRedirect(c *RequestContext) ResponseData { return FourOhFour(c) } - c.Perf.StartBlock("SQL", "Fetch post ids for thread") - type postQuery struct { - PostID int `db:"post.id"` - } - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - WHERE - post.thread_id = $1 - AND NOT post.deleted - ORDER BY postdate - `, - cd.ThreadID, - ) + posts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + ThreadIDs: []int{cd.ThreadID}, + }) if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post ids")) + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for redirect")) } - postQuerySlice := postQueryResult.ToSlice() - c.Perf.EndBlock() + + var post PostAndStuff postIdx := -1 - for i, id := range postQuerySlice { - if id.(*postQuery).PostID == cd.PostID { + for i, p := range posts { + if p.Post.ID == cd.PostID { + post = p postIdx = i break } @@ -559,31 +464,13 @@ func ForumPostRedirect(c *RequestContext) ResponseData { return FourOhFour(c) } - c.Perf.StartBlock("SQL", "Fetch thread title") - type threadTitleQuery struct { - ThreadTitle string `db:"thread.title"` - } - threadTitleQueryResult, err := db.QueryOne(c.Context(), c.Conn, threadTitleQuery{}, - ` - SELECT $columns - FROM handmade_thread AS thread - WHERE thread.id = $1 - `, - cd.ThreadID, - ) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread title")) - } - c.Perf.EndBlock() - threadTitle := threadTitleQueryResult.(*threadTitleQuery).ThreadTitle - - page := (postIdx / threadViewPostsPerPage) + 1 + page := (postIdx / threadPostsPerPage) + 1 return c.Redirect(hmnurl.BuildForumThreadWithPostHash( c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, - threadTitle, + post.Thread.Title, page, cd.PostID, ), http.StatusSeeOther) @@ -672,16 +559,24 @@ func ForumPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for reply")) + } baseData := getBaseData( c, fmt.Sprintf("Replying to post | %s", cd.SubforumTree[cd.SubforumID].Name), - ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread), + ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread), ) - replyPost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - replyPost.AddContentVersion(postData.CurrentVersion, postData.Editor) + replyPost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + replyPost.AddContentVersion(post.CurrentVersion, post.Editor) editData := getEditorDataForNew(c.CurrentUser, baseData, &replyPost) editData.SubmitUrl = hmnurl.BuildForumPostReply(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, cd.PostID) @@ -713,11 +608,17 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData { return RejectRequest(c, "Your reply cannot be empty.") } - thread := FetchThread(c.Context(), tx, cd.ThreadID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } // Replies to the OP should not be considered replies var replyPostId *int - if cd.PostID != thread.FirstID { + if cd.PostID != post.Thread.FirstID { replyPostId = &cd.PostID } @@ -742,17 +643,25 @@ func ForumPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for editing")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Editing \"%s\" | %s", postData.Thread.Title, cd.SubforumTree[cd.SubforumID].Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Editing \"%s\" | %s", post.Thread.Title, cd.SubforumTree[cd.SubforumID].Name) } else { title = fmt.Sprintf("Editing Post | %s", cd.SubforumTree[cd.SubforumID].Name) } - baseData := getBaseData(c, title, ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread)) + baseData := getBaseData(c, title, ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread)) - editData := getEditorDataForEdit(c.CurrentUser, baseData, postData) + editData := getEditorDataForEdit(c.CurrentUser, baseData, post) editData.SubmitUrl = hmnurl.BuildForumPostEdit(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, cd.PostID) editData.SubmitLabel = "Submit Edited Post" @@ -805,16 +714,24 @@ func ForumPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for delete")) + } baseData := getBaseData( c, - fmt.Sprintf("Deleting post in \"%s\" | %s", postData.Thread.Title, cd.SubforumTree[cd.SubforumID].Name), - ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread), + fmt.Sprintf("Deleting post in \"%s\" | %s", post.Thread.Title, cd.SubforumTree[cd.SubforumID].Name), + ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread), ) - templatePost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - templatePost.AddContentVersion(postData.CurrentVersion, postData.Editor) + templatePost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + templatePost.AddContentVersion(post.CurrentVersion, post.Editor) type forumPostDeleteData struct { templates.BaseData @@ -870,31 +787,22 @@ func WikiArticleRedirect(c *RequestContext) ResponseData { panic(err) } - ithread, err := db.QueryOne(c.Context(), c.Conn, models.Thread{}, - ` - SELECT $columns - FROM handmade_thread - WHERE - id = $1 - AND project_id = $2 - AND NOT deleted - `, - threadId, - c.CurrentProject.ID, - ) - if errors.Is(err, db.ErrNoMatchingRows) { + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + // This is the rare query where we want all thread types! + }) + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up wiki thread")) } - thread := ithread.(*models.Thread) c.Perf.StartBlock("SQL", "Fetch subforum tree") subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() - dest := UrlForGenericThread(thread, lineageBuilder, c.CurrentProject.Slug) + dest := UrlForGenericThread(&thread.Thread, lineageBuilder, c.CurrentProject.Slug) return c.Redirect(dest, http.StatusFound) } @@ -911,9 +819,11 @@ type commonForumData struct { /* Gets data that is used on basically every forums-related route. Parses path params for subforum, -thread, and post ids and validates that all those resources do in fact exist. +thread, and post ids. -Returns false if any data is invalid and you should return a 404. +Does NOT validate that the requested thread and post ID are valid. + +If this returns false, then something was malformed and you should 404. */ func getCommonForumData(c *RequestContext) (commonForumData, bool) { c.Perf.StartBlock("FORUMS", "Fetch common forum data") @@ -936,8 +846,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.SubforumID = sfId - - // No need to validate that subforum exists here; it's handled by validateSubforums. } if threadIdStr, hasThreadId := c.PathParams["threadid"]; hasThreadId { @@ -946,27 +854,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.ThreadID = threadId - - c.Perf.StartBlock("SQL", "Verify that the thread exists") - threadExists, err := db.QueryBool(c.Context(), c.Conn, - ` - SELECT COUNT(*) > 0 - FROM handmade_thread - WHERE - id = $1 - AND subforum_id = $2 - AND NOT deleted - `, - res.ThreadID, - res.SubforumID, - ) - c.Perf.EndBlock() - if err != nil { - panic(err) - } - if !threadExists { - return commonForumData{}, false - } } if postIdStr, hasPostId := c.PathParams["postid"]; hasPostId { @@ -975,27 +862,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.PostID = postId - - c.Perf.StartBlock("SQL", "Verify that the post exists") - postExists, err := db.QueryBool(c.Context(), c.Conn, - ` - SELECT COUNT(*) > 0 - FROM handmade_post - WHERE - id = $1 - AND thread_id = $2 - AND NOT deleted - `, - res.PostID, - res.ThreadID, - ) - c.Perf.EndBlock() - if err != nil { - panic(err) - } - if !postExists { - return commonForumData{}, false - } } return res, true @@ -1042,6 +908,8 @@ func addForumUrlsToPost(p *templates.Post, projectSlug string, subforums []strin p.ReplyUrl = hmnurl.BuildForumPostReply(projectSlug, subforums, threadId, postId) } +// Takes a template post and adds information about how many posts the user has made +// on the site. func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) { numPosts, err := db.QueryInt(ctx, conn, ` diff --git a/src/website/landing.go b/src/website/landing.go index fef6044..badc34b 100644 --- a/src/website/landing.go +++ b/src/website/landing.go @@ -81,56 +81,22 @@ func Index(c *RequestContext) ResponseData { lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() - var currentUserId *int - if c.CurrentUser != nil { - currentUserId = &c.CurrentUser.ID - } - c.Perf.StartBlock("LANDING", "Process projects") for _, projRow := range allProjects { proj := projRow.(*models.Project) c.Perf.StartBlock("SQL", fmt.Sprintf("Fetch posts for %s", proj.Name)) - type projectPostQuery struct { - Post models.Post `db:"post"` - Thread models.Thread `db:"thread"` - User models.User `db:"auth_user"` - ThreadLastReadTime *time.Time `db:"tlri.lastread"` - SubforumLastReadTime *time.Time `db:"slri.lastread"` - } - projectPostIter, err := db.Query(c.Context(), c.Conn, projectPostQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - JOIN handmade_thread AS thread ON post.thread_id = thread.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = post.thread_id - AND tlri.user_id = $1 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = thread.subforum_id - AND slri.user_id = $1 - ) - LEFT JOIN auth_user ON post.author_id = auth_user.id - WHERE - post.project_id = $2 - AND post.thread_type = ANY ($3) - AND post.deleted = FALSE - ORDER BY postdate DESC - LIMIT $4 - `, - currentUserId, - proj.ID, - []models.ThreadType{models.ThreadTypeProjectBlogPost, models.ThreadTypeForumPost}, - maxPosts, - ) + projectPosts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{proj.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost, models.ThreadTypeForumPost}, + Limit: maxPosts, + SortDescending: true, + }) c.Perf.EndBlock() if err != nil { c.Logger.Error().Err(err).Msg("failed to fetch project posts") continue } - projectPosts := projectPostIter.ToSlice() forumsUrl := "" if proj.ForumID != nil { @@ -144,18 +110,7 @@ func Index(c *RequestContext) ResponseData { ForumsUrl: forumsUrl, } - for _, projectPostRow := range projectPosts { - projectPost := projectPostRow.(*projectPostQuery) - - hasRead := false - if c.CurrentUser != nil && c.CurrentUser.MarkedAllReadAt.After(projectPost.Post.PostDate) { - hasRead = true - } else if projectPost.ThreadLastReadTime != nil && projectPost.ThreadLastReadTime.After(projectPost.Post.PostDate) { - hasRead = true - } else if projectPost.SubforumLastReadTime != nil && projectPost.SubforumLastReadTime.After(projectPost.Post.PostDate) { - hasRead = true - } - + for _, projectPost := range projectPosts { featurable := (!proj.IsHMN() && projectPost.Post.ThreadType == models.ThreadTypeProjectBlogPost && projectPost.Thread.FirstID == projectPost.Post.ID && @@ -185,9 +140,9 @@ func Index(c *RequestContext) ResponseData { landingPageProject.FeaturedPost = &LandingPageFeaturedPost{ Title: projectPost.Thread.Title, Url: hmnurl.BuildBlogThread(proj.Slug, projectPost.Thread.ID, projectPost.Thread.Title), - User: templates.UserToTemplate(&projectPost.User, c.Theme), + User: templates.UserToTemplate(projectPost.Author, c.Theme), Date: projectPost.Post.PostDate, - Unread: !hasRead, + Unread: projectPost.ThreadUnread, Content: template.HTML(content), } } else { @@ -198,8 +153,8 @@ func Index(c *RequestContext) ResponseData { proj, &projectPost.Thread, &projectPost.Post, - &projectPost.User, - !hasRead, + projectPost.Author, + projectPost.ThreadUnread, false, c.Theme, ), @@ -248,35 +203,18 @@ func Index(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Get news") - type newsPostQuery struct { - Post models.Post `db:"post"` - PostVersion models.PostVersion `db:"ver"` - Thread models.Thread `db:"thread"` - User models.User `db:"auth_user"` - } - newsPostRow, err := db.QueryOne(c.Context(), c.Conn, newsPostQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - JOIN handmade_thread AS thread ON post.thread_id = thread.id - JOIN auth_user ON post.author_id = auth_user.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - WHERE - post.project_id = $1 - AND thread.type = $2 - AND post.id = thread.first_id - AND NOT thread.deleted - ORDER BY post.postdate DESC - LIMIT 1 - `, - models.HMNProjectID, - models.ThreadTypeProjectBlogPost, - ) + newsThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{models.HMNProjectID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + Limit: 1, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch news post")) } - newsPostResult := newsPostRow.(*newsPostQuery) + var newsThread ThreadAndStuff + if len(newsThreads) > 0 { + newsThread = newsThreads[0] + } c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetch showcase snippets") @@ -334,12 +272,12 @@ func Index(c *RequestContext) ResponseData { ShowUrl: "https://handmadedev.show/", ShowcaseUrl: hmnurl.BuildShowcase(), NewsPost: LandingPageFeaturedPost{ - Title: newsPostResult.Thread.Title, - Url: hmnurl.BuildBlogThread(models.HMNProjectSlug, newsPostResult.Thread.ID, newsPostResult.Thread.Title), - User: templates.UserToTemplate(&newsPostResult.User, c.Theme), - Date: newsPostResult.Post.PostDate, + Title: newsThread.Thread.Title, + Url: hmnurl.BuildBlogThread(models.HMNProjectSlug, newsThread.Thread.ID, newsThread.Thread.Title), + User: templates.UserToTemplate(newsThread.FirstPostAuthor, c.Theme), + Date: newsThread.FirstPost.PostDate, Unread: true, - Content: template.HTML(newsPostResult.PostVersion.TextParsed), + Content: template.HTML(newsThread.FirstPostCurrentVersion.TextParsed), }, PostColumns: cols, ShowcaseTimelineJson: showcaseJson, diff --git a/src/website/podcast.go b/src/website/podcast.go index a8c21ff..40feb4a 100644 --- a/src/website/podcast.go +++ b/src/website/podcast.go @@ -565,7 +565,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return result, nil } else { return result, oops.New(err, "failed to fetch podcast") @@ -615,7 +615,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return result, nil } else { return result, oops.New(err, "failed to fetch podcast episode") diff --git a/src/website/projects.go b/src/website/projects.go index 9eeeb5a..de55395 100644 --- a/src/website/projects.go +++ b/src/website/projects.go @@ -234,7 +234,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project by slug")) diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index 208efd3..6ca84ec 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -128,10 +128,15 @@ type RequestContext struct { Theme string Perf *perf.RequestPerf + + ctx context.Context } func (c *RequestContext) Context() context.Context { - return c.Req.Context() + if c.ctx == nil { + c.ctx = c.Req.Context() + } + return c.ctx } func (c *RequestContext) URL() *url.URL { diff --git a/src/website/routes.go b/src/website/routes.go index 5c358de..95ff9e9 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -376,7 +376,7 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* if err == nil { subdomainProject := subdomainProjectRow.(*models.Project) return subdomainProject, nil - } else if !errors.Is(err, db.ErrNoMatchingRows) { + } else if !errors.Is(err, db.NotFound) { return nil, oops.New(err, "failed to get projects by slug") } else { return nil, nil @@ -384,7 +384,7 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* } else { defaultProjectRow, err := db.QueryOne(ctx, conn, models.Project{}, "SELECT $columns FROM handmade_project WHERE id = $1", models.HMNProjectID) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return nil, oops.New(nil, "default project didn't exist in the database") } else { return nil, oops.New(err, "failed to get default project") @@ -546,7 +546,7 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found") return nil, nil, nil // user was deleted or something } else { @@ -558,8 +558,12 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User return user, session, nil } +const PerfContextKey = "HMNPerf" + func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (after func()) { c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path) + c.ctx = context.WithValue(c.Context(), PerfContextKey, c.Perf) + return func() { c.Perf.EndRequest() log := logging.Info() @@ -576,6 +580,14 @@ func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (aft } } +func ExtractPerf(ctx context.Context) *perf.RequestPerf { + iperf := ctx.Value(PerfContextKey) + if iperf == nil { + return nil + } + return iperf.(*perf.RequestPerf) +} + func LogContextErrors(c *RequestContext, errs ...error) { for _, err := range errs { c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request") diff --git a/src/website/snippet.go b/src/website/snippet.go index a539723..1d9ce7f 100644 --- a/src/website/snippet.go +++ b/src/website/snippet.go @@ -50,7 +50,7 @@ func Snippet(c *RequestContext) ResponseData { snippetId, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippet")) diff --git a/src/website/threads_and_posts_helper.go b/src/website/threads_and_posts_helper.go index 87c6008..ea65786 100644 --- a/src/website/threads_and_posts_helper.go +++ b/src/website/threads_and_posts_helper.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math" "net" "strings" "time" @@ -19,198 +18,574 @@ import ( "github.com/jackc/pgx/v4" ) -type postAndRelatedModels struct { - Thread models.Thread - Post models.Post - CurrentVersion models.PostVersion +type ThreadsQuery struct { + // Available on all thread queries. + ProjectIDs []int // if empty, all projects + ThreadTypes []models.ThreadType // if empty, all types (you do not want to do this) + SubforumIDs []int // if empty, all subforums - Author *models.User - Editor *models.User + // Ignored when using FetchThread. + ThreadIDs []int - ReplyPost *models.Post - ReplyAuthor *models.User + // Ignored when using FetchThread or CountThreads. + Limit, Offset int // if empty, no pagination +} + +type ThreadAndStuff struct { + Project models.Project `db:"project"` + Thread models.Thread `db:"thread"` + FirstPost models.Post `db:"first_post"` + LastPost models.Post `db:"last_post"` + FirstPostCurrentVersion models.PostVersion `db:"first_version"` + LastPostCurrentVersion models.PostVersion `db:"last_version"` + FirstPostAuthor *models.User `db:"first_author"` // Can be nil in case of a deleted user + LastPostAuthor *models.User `db:"last_author"` // Can be nil in case of a deleted user + Unread bool } /* -Fetches the thread defined by your (already parsed) path params. - -YOU MUST VERIFY THAT THE THREAD ID IS VALID BEFORE CALLING THIS FUNCTION. It will -not check, for example, that the thread belongs to the correct subforum. +Fetches threads and related models from the database according to all the given +query params. For the most correct results, provide as much information as you have +on hand. */ -func FetchThread(ctx context.Context, connOrTx db.ConnOrTx, threadId int) models.Thread { - type threadQueryResult struct { - Thread models.Thread `db:"thread"` +func FetchThreads( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q ThreadsQuery, +) ([]ThreadAndStuff, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Fetch threads") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID } - irow, err := db.QueryOne(ctx, connOrTx, threadQueryResult{}, + + qb.Add( ` SELECT $columns FROM handmade_thread AS thread + JOIN handmade_project AS project ON thread.project_id = project.id + JOIN handmade_post AS first_post ON first_post.id = thread.first_id + JOIN handmade_post AS last_post ON last_post.id = thread.last_id + JOIN handmade_postversion AS first_version ON first_version.id = first_post.current_id + JOIN handmade_postversion AS last_version ON last_version.id = last_post.current_id + LEFT JOIN auth_user AS first_author ON first_author.id = first_post.author_id + LEFT JOIN auth_user AS last_author ON last_author.id = last_post.author_id + LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( + tlri.thread_id = thread.id + AND tlri.user_id = $? + ) + LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( + slri.subforum_id = thread.subforum_id + AND slri.user_id = $? + ) WHERE - id = $1 - AND NOT deleted + NOT thread.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) `, - threadId, + currentUserID, + currentUserID, + models.VisibleProjectLifecycles, + models.HMNProjectID, ) - if err != nil { - // We shouldn't encounter db.ErrNoMatchingRows, because validation should have verified that everything exists. - panic(oops.New(err, "failed to fetch thread")) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.SubforumIDs) > 0 { + qb.Add(`AND thread.subforum_id = ANY ($?)`, q.SubforumIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND thread.id = ANY ($?)`, q.ThreadIDs) + } + if currentUser == nil { + qb.Add( + `AND first_author.status = $? -- thread author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + first_author.status = $? -- thread author is Approved + OR first_author.id = $? -- current user is the thread author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + qb.Add(`ORDER BY last_post.postdate DESC`) + if q.Limit > 0 { + qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - thread := irow.(*threadQueryResult).Thread - return thread -} - -/* -Fetches the post, the thread, and author / editor information for the post defined in -your path params. - -YOU MUST VERIFY THAT THE THREAD ID AND POST ID ARE VALID BEFORE CALLING THIS FUNCTION. -It will not check that the post belongs to the correct subforum, for example, or the -correct project blog. This logic varies per route and per use of threads, so it doesn't -happen here. -*/ -func FetchPostAndStuff( - ctx context.Context, - connOrTx db.ConnOrTx, - threadId, postId int, -) postAndRelatedModels { type resultRow struct { - Thread models.Thread `db:"thread"` - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - Editor *models.User `db:"editor"` - ReplyPost *models.Post `db:"reply"` - ReplyAuthor *models.User `db:"reply_author"` - } - postQueryResult, err := db.QueryOne(ctx, connOrTx, resultRow{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS post ON post.thread_id = thread.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id - LEFT JOIN handmade_post AS reply ON post.reply_id = reply.id - LEFT JOIN auth_user AS reply_author ON reply.author_id = reply_author.id - WHERE - post.thread_id = $1 - AND post.id = $2 - AND NOT post.deleted - `, - threadId, - postId, - ) - if err != nil { - // We shouldn't encounter db.ErrNoMatchingRows, because validation should have verified that everything exists. - panic(oops.New(err, "failed to fetch post and related data")) + ThreadAndStuff + ThreadLastReadTime *time.Time `db:"tlri.lastread"` + ForumLastReadTime *time.Time `db:"slri.lastread"` } - result := postQueryResult.(*resultRow) - return postAndRelatedModels{ - Thread: result.Thread, - Post: result.Post, - CurrentVersion: result.CurrentVersion, - Author: result.Author, - Editor: result.Editor, - ReplyPost: result.ReplyPost, - ReplyAuthor: result.ReplyAuthor, + it, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + if err != nil { + return nil, oops.New(err, "failed to fetch threads") } + iresults := it.ToSlice() + + result := make([]ThreadAndStuff, len(iresults)) + for i, iresult := range iresults { + row := *iresult.(*resultRow) + + hasRead := false + if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.LastPost.PostDate) { + hasRead = true + } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.LastPost.PostDate) { + hasRead = true + } + row.Unread = !hasRead + + result[i] = row.ThreadAndStuff + } + + return result, nil } /* -Fetches all the posts (and related models) for a given thread. +Fetches a single thread and related data. A wrapper around FetchThreads. +As with FetchThreads, provide as much information as you know to get the +most correct results. -YOU MUST VERIFY THAT THE THREAD ID IS VALID BEFORE CALLING THIS FUNCTION. It will -not check, for example, that the thread belongs to the correct subforum. +Returns db.NotFound if no result is found. */ -func FetchThreadPostsAndStuff( +func FetchThread( ctx context.Context, - connOrTx db.ConnOrTx, - threadId int, - page, postsPerPage int, -) (models.Thread, []postAndRelatedModels, string) { - limit := postsPerPage - offset := (page - 1) * postsPerPage - if postsPerPage == 0 { - limit = math.MaxInt32 - offset = 0 + dbConn db.ConnOrTx, + currentUser *models.User, + threadID int, + q ThreadsQuery, +) (ThreadAndStuff, error) { + q.ThreadIDs = []int{threadID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchThreads(ctx, dbConn, currentUser, q) + if err != nil { + return ThreadAndStuff{}, oops.New(err, "failed to fetch thread") } - thread := FetchThread(ctx, connOrTx, threadId) - - type postResult struct { - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - Editor *models.User `db:"editor"` - ReplyPost *models.Post `db:"reply"` - ReplyAuthor *models.User `db:"reply_author"` + if len(res) == 0 { + return ThreadAndStuff{}, db.NotFound } - itPosts, err := db.Query(ctx, connOrTx, postResult{}, + + return res[0], nil +} + +func CountThreads( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q ThreadsQuery, +) (int, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Count threads") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + qb.Add( + ` + SELECT COUNT(*) + FROM + handmade_thread AS thread + JOIN handmade_project AS project ON thread.project_id = project.id + JOIN handmade_post AS first_post ON first_post.id = thread.first_id + LEFT JOIN auth_user AS first_author ON first_author.id = first_post.author_id + WHERE + NOT thread.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) + `, + models.VisibleProjectLifecycles, + models.HMNProjectID, + ) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.SubforumIDs) > 0 { + qb.Add(`AND thread.subforum_id = ANY ($?)`, q.SubforumIDs) + } + if currentUser == nil { + qb.Add( + `AND first_author.status = $? -- thread author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + first_author.status = $? -- thread author is Approved + OR first_author.id = $? -- current user is the thread author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + + count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + if err != nil { + return 0, oops.New(err, "failed to fetch count of threads") + } + + return count, nil +} + +type PostsQuery struct { + // Available on all post queries. + ProjectIDs []int + UserIDs []int + ThreadTypes []models.ThreadType + + // Ignored when using FetchPost. + ThreadIDs []int + PostIDs []int + + // Ignored when using FetchPost or CountPosts. + Limit, Offset int + SortDescending bool +} + +type PostAndStuff struct { + Project models.Project `db:"project"` + Thread models.Thread `db:"thread"` + ThreadUnread bool + Post models.Post `db:"post"` + CurrentVersion models.PostVersion `db:"ver"` + Author *models.User `db:"author"` // Can be nil in case of a deleted user + Editor *models.User `db:"editor"` + ReplyPost *models.Post `db:"reply_post"` + ReplyAuthor *models.User `db:"reply_author"` +} + +/* +Fetches posts and related models from the database according to all the given +query params. For the most correct results, provide as much information as you have +on hand. +*/ +func FetchPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q PostsQuery, +) ([]PostAndStuff, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Fetch posts") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + type resultRow struct { + PostAndStuff + ThreadLastReadTime *time.Time `db:"tlri.lastread"` + ForumLastReadTime *time.Time `db:"slri.lastread"` + } + + qb.Add( ` SELECT $columns - FROM - handmade_post AS post - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id - LEFT JOIN handmade_post AS reply ON post.reply_id = reply.id - LEFT JOIN auth_user AS reply_author ON reply.author_id = reply_author.id - WHERE - post.thread_id = $1 - AND NOT post.deleted - ORDER BY post.postdate - LIMIT $2 OFFSET $3 - `, - thread.ID, - limit, - offset, - ) - if err != nil { - panic(oops.New(err, "failed to fetch posts for thread")) - } - defer itPosts.Close() - - var posts []postAndRelatedModels - for { - irow, hasNext := itPosts.Next() - if !hasNext { - break - } - - row := irow.(*postResult) - posts = append(posts, postAndRelatedModels{ - Thread: thread, - Post: row.Post, - CurrentVersion: row.CurrentVersion, - Author: row.Author, - Editor: row.Editor, - ReplyPost: row.ReplyPost, - ReplyAuthor: row.ReplyAuthor, - }) - } - - preview, err := db.QueryString(ctx, connOrTx, - ` - SELECT post.preview FROM handmade_post AS post JOIN handmade_thread AS thread ON post.thread_id = thread.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id + JOIN handmade_project AS project ON post.project_id = project.id + JOIN handmade_postversion AS ver ON ver.id = post.current_id + LEFT JOIN auth_user AS author ON author.id = post.author_id + LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id + LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( + tlri.thread_id = thread.id + AND tlri.user_id = $? + ) + LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( + slri.subforum_id = thread.subforum_id + AND slri.user_id = $? + ) + -- Unconditionally fetch reply info, but make sure to check it + -- later and possibly remove these fields if the permission + -- check fails. + LEFT JOIN handmade_post AS reply_post ON reply_post.id = post.reply_id + LEFT JOIN auth_user AS reply_author ON reply_post.author_id = reply_author.id WHERE - post.thread_id = $1 - AND thread.first_id = post.id + NOT thread.deleted + AND NOT post.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) `, - thread.ID, + currentUserID, + currentUserID, + models.VisibleProjectLifecycles, + models.HMNProjectID, ) - if err != nil && !errors.Is(err, db.ErrNoMatchingRows) { - panic(oops.New(err, "failed to fetch posts for thread")) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.UserIDs) > 0 { + qb.Add(`AND post.author_id = ANY ($?)`, q.UserIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND post.thread_id = ANY ($?)`, q.ThreadIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.PostIDs) > 0 { + qb.Add(`AND post.id = ANY ($?)`, q.PostIDs) + } + if currentUser == nil { + qb.Add( + `AND author.status = $? -- post author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + author.status = $? -- post author is Approved + OR author.id = $? -- current user is the post author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + qb.Add(`ORDER BY post.postdate`) + if q.SortDescending { + qb.Add(`DESC`) + } + if q.Limit > 0 { + qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - return thread, posts, preview + it, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + if err != nil { + return nil, oops.New(err, "failed to fetch posts") + } + iresults := it.ToSlice() + + result := make([]PostAndStuff, len(iresults)) + for i, iresult := range iresults { + row := *iresult.(*resultRow) + + hasRead := false + if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.Post.PostDate) { + hasRead = true + } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.Post.PostDate) { + hasRead = true + } + row.ThreadUnread = !hasRead + + if row.ReplyPost != nil && row.ReplyAuthor != nil { + replyAuthorIsNotApproved := row.ReplyAuthor.Status != models.UserStatusApproved + canSeeUnapprovedReply := currentUser != nil && (row.ReplyAuthor.ID == currentUser.ID || currentUser.IsStaff) + if replyAuthorIsNotApproved && !canSeeUnapprovedReply { + row.ReplyPost = nil + row.ReplyAuthor = nil + } + } + + result[i] = row.PostAndStuff + } + + return result, nil +} + +/* +Fetches posts for a given thread. A convenient wrapper around FetchPosts that returns +the posts and the actual thread model. + +Return db.NotFound if nothing is found (no thread or no posts). +*/ +func FetchThreadPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + threadID int, + q PostsQuery, +) (models.Thread, []PostAndStuff, error) { + q.ThreadIDs = []int{threadID} + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return models.Thread{}, nil, oops.New(err, "failed to fetch posts for thread") + } + + if len(res) == 0 { + // We shouldn't have threads without posts anyway. + return models.Thread{}, nil, db.NotFound + } + + return res[0].Thread, res, nil +} + +/* +Fetches a single post for a thread and its related data. A wrapper +around FetchPosts. As with FetchPosts, provide as much information +as you know to get the most correct results. + +Returns db.NotFound if no result is found. +*/ +func FetchThreadPost( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + threadID, postID int, + q PostsQuery, +) (PostAndStuff, error) { + q.ThreadIDs = []int{threadID} + q.PostIDs = []int{postID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return PostAndStuff{}, oops.New(err, "failed to fetch post") + } + + if len(res) == 0 { + return PostAndStuff{}, db.NotFound + } + + return res[0], nil +} + +/* +Fetches a single post and its related data. A wrapper +around FetchPosts. As with FetchPosts, provide as much information +as you know to get the most correct results. + +Returns db.NotFound if no result is found. +*/ +func FetchPost( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + postID int, + q PostsQuery, +) (PostAndStuff, error) { + q.PostIDs = []int{postID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return PostAndStuff{}, oops.New(err, "failed to fetch post") + } + + if len(res) == 0 { + return PostAndStuff{}, db.NotFound + } + + return res[0], nil +} + +func CountPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q PostsQuery, +) (int, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Count posts") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + qb.Add( + ` + SELECT COUNT(*) + FROM + handmade_post AS post + JOIN handmade_thread AS thread ON post.thread_id = thread.id + JOIN handmade_project AS project ON post.project_id = project.id + LEFT JOIN auth_user AS author ON author.id = post.author_id + WHERE + NOT thread.deleted + AND NOT post.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) + `, + models.VisibleProjectLifecycles, + models.HMNProjectID, + ) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.UserIDs) > 0 { + qb.Add(`AND post.author_id = ANY ($?)`, q.UserIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND post.thread_id = ANY ($?)`, q.ThreadIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if currentUser == nil { + qb.Add( + `AND author.status = $? -- post author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + author.status = $? -- post author is Approved + OR author.id = $? -- current user is the post author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + + count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + if err != nil { + return 0, oops.New(err, "failed to count posts") + } + + return count, nil } func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User, postId int) bool { @@ -233,7 +608,7 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User postId, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return false } else { panic(oops.New(err, "failed to get author of post when checking permissions")) diff --git a/src/website/user.go b/src/website/user.go index af3afd5..d07b9e5 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -57,7 +57,7 @@ func UserProfile(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username)) @@ -65,6 +65,15 @@ func UserProfile(c *RequestContext) ResponseData { } profileUser = userResult.(*models.User) } + + { + userIsUnapproved := profileUser.Status != models.UserStatusApproved + canViewUnapprovedUser := c.CurrentUser != nil && (c.CurrentUser.ID == profileUser.ID || c.CurrentUser.IsStaff) + if userIsUnapproved && !canViewUnapprovedUser { + return FourOhFour(c) + } + } + c.Perf.StartBlock("SQL", "Fetch user links") type userLinkQuery struct { UserLink models.Link `db:"link"` @@ -119,30 +128,11 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.EndBlock() - type postQuery struct { - Post models.Post `db:"post"` - Thread models.Thread `db:"thread"` - Project models.Project `db:"project"` - } c.Perf.StartBlock("SQL", "Fetch posts") - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - INNER JOIN handmade_thread AS thread ON thread.id = post.thread_id - INNER JOIN handmade_project AS project ON project.id = post.project_id - WHERE - post.author_id = $1 - AND project.lifecycle = ANY ($2) - `, - profileUser.ID, - models.VisibleProjectLifecycles, - ) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for user: %s", username)) - } - postQuerySlice := postQueryResult.ToSlice() + posts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + UserIDs: []int{profileUser.ID}, + SortDescending: true, + }) c.Perf.EndBlock() type snippetQuery struct { @@ -175,18 +165,17 @@ func UserProfile(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("PROFILE", "Construct timeline items") - timelineItems := make([]templates.TimelineItem, 0, len(postQuerySlice)+len(snippetQuerySlice)) + timelineItems := make([]templates.TimelineItem, 0, len(posts)+len(snippetQuerySlice)) numForums := 0 numBlogs := 0 numSnippets := len(snippetQuerySlice) - for _, postRow := range postQuerySlice { - postData := postRow.(*postQuery) + for _, post := range posts { timelineItem := PostToTimelineItem( lineageBuilder, - &postData.Post, - &postData.Thread, - &postData.Project, + &post.Post, + &post.Thread, + &post.Project, profileUser, c.Theme, ) @@ -199,7 +188,7 @@ func UserProfile(c *RequestContext) ResponseData { if timelineItem.Type != templates.TimelineTypeUnknown { timelineItems = append(timelineItems, timelineItem) } else { - c.Logger.Warn().Int("post ID", postData.Post.ID).Msg("Unknown timeline item type for post") + c.Logger.Warn().Int("post ID", post.Post.ID).Msg("Unknown timeline item type for post") } } @@ -292,7 +281,7 @@ func UserSettings(c *RequestContext) ResponseData { `, c.CurrentUser.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // this is fine, but don't fetch any more messages } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))