From 986a42c1ac19bc51170ee92fc709f04739c6893a Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Mon, 13 Sep 2021 23:13:58 -0500 Subject: [PATCH] Switch to centralized helpers for fetching threads/posts This includes the ability to "shadowban" new users who have not yet been approved. We do not have UI for approving these users. Migrate deserving users to new Approved status Add post fetching helpers as well The logic in the thread/post stuff is definitely getting redundant, but I'm not sure I'm yet ready to try to abstract any of it away. The next thing to do is probably to update blogs and other places that fetch threads/posts, and delete the old helpers. Move forums and blogs fully to new helpers Use the helpers on the landing page too that was easy! Fix up some spots I missed Check user status and use helpers on the profile page --- src/admintools/admintools.go | 2 +- src/auth/session.go | 2 +- src/db/db.go | 16 +- src/db/db_test.go | 35 +- src/db/query_builder.go | 42 ++ src/discord/gateway.go | 4 +- src/discord/showcase.go | 8 +- .../2021-09-23T041521Z_ApproveActiveUsers.go | 89 +++ src/models/post.go | 1 + src/models/user.go | 9 +- src/perf/perf.go | 20 + src/website/auth.go | 16 +- src/website/blogs.go | 186 ++--- src/website/discord.go | 4 +- src/website/forums.go | 406 ++++------- src/website/landing.go | 112 +-- src/website/podcast.go | 4 +- src/website/projects.go | 2 +- src/website/requesthandling.go | 7 +- src/website/routes.go | 18 +- src/website/snippet.go | 2 +- src/website/threads_and_posts_helper.go | 685 ++++++++++++++---- src/website/user.go | 53 +- 23 files changed, 1055 insertions(+), 668 deletions(-) create mode 100644 src/db/query_builder.go create mode 100644 src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go diff --git a/src/admintools/admintools.go b/src/admintools/admintools.go index 63a3993..6f76b28 100644 --- a/src/admintools/admintools.go +++ b/src/admintools/admintools.go @@ -86,7 +86,7 @@ func init() { conn := db.NewConnPool(1, 1) defer conn.Close() - res, err := conn.Exec(ctx, "UPDATE auth_user SET status = $1 WHERE LOWER(username) = LOWER($2);", models.UserStatusActive, username) + res, err := conn.Exec(ctx, "UPDATE auth_user SET status = $1 WHERE LOWER(username) = LOWER($2);", models.UserStatusConfirmed, username) if err != nil { panic(err) } diff --git a/src/auth/session.go b/src/auth/session.go index 0e7e537..0ac2d05 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -47,7 +47,7 @@ var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return nil, ErrNoSession } else { return nil, oops.New(err, "failed to get session") diff --git a/src/db/db.go b/src/db/db.go index cf76995..69defd5 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -269,9 +269,8 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix strin return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) } - for i := 0; i < destType.NumField(); i++ { - field := destType.Field(i) - path := append(pathSoFar, i) + for _, field := range reflect.VisibleFields(destType) { + path := append(pathSoFar, field.Index...) if columnName := field.Tag.Get("db"); columnName != "" { fieldType := field.Type @@ -298,7 +297,12 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix strin return columnNames, fieldPaths, nil } -var ErrNoMatchingRows = errors.New("no matching rows") +/* +A general error to be used when no results are found. This is the error returned +by QueryOne, and can generally be used by other database helpers that fetch a single +result but find nothing. +*/ +var NotFound = errors.New("not found") func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { rows, err := Query(ctx, conn, destExample, query, args...) @@ -309,7 +313,7 @@ func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query result, hasRow := rows.Next() if !hasRow { - return nil, ErrNoMatchingRows + return nil, NotFound } return result, nil @@ -335,7 +339,7 @@ func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...inter return vals[0], nil } - return nil, ErrNoMatchingRows + return nil, NotFound } func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { diff --git a/src/db/db_test.go b/src/db/db_test.go index a969a47..41b9491 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -26,8 +26,12 @@ func TestPaths(t *testing.T) { NoTag S } + type Embedded struct { + NoTag S + Nested + } - names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Nested{}), nil, "") + names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") if assert.Nil(t, err) { assert.Equal(t, []string{ "S.I", "S.PI", @@ -38,16 +42,39 @@ func TestPaths(t *testing.T) { "PS.B", "PS.PB", }, names) assert.Equal(t, [][]int{ - {0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5}, - {1, 0}, {1, 1}, {1, 2}, {1, 3}, {1, 4}, {1, 5}, + {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, + {1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, }, paths) assert.True(t, len(names) == len(paths)) } - testStruct := Nested{} + testStruct := Embedded{} for i, path := range paths { val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) assert.True(t, val.IsValid()) assert.True(t, strings.Contains(names[i], field.Name)) } } + +func TestQueryBuilder(t *testing.T) { + t.Run("happy time", func(t *testing.T) { + var qb QueryBuilder + qb.Add("SELECT stuff FROM thing WHERE foo = $? AND bar = $?", 3, "hello") + qb.Add("AND (baz = $?)", true) + + assert.Equal(t, "SELECT stuff FROM thing WHERE foo = $1 AND bar = $2\nAND (baz = $3)\n", qb.String()) + assert.Equal(t, []interface{}{3, "hello", true}, qb.Args()) + }) + t.Run("too few arguments", func(t *testing.T) { + var qb QueryBuilder + assert.Panics(t, func() { + qb.Add("HELLO $? $? $?", 1, 2) + }) + }) + t.Run("too many arguments", func(t *testing.T) { + var qb QueryBuilder + assert.Panics(t, func() { + qb.Add("HELLO $? $? $?", 1, 2, 3, 4) + }) + }) +} diff --git a/src/db/query_builder.go b/src/db/query_builder.go new file mode 100644 index 0000000..1f794c9 --- /dev/null +++ b/src/db/query_builder.go @@ -0,0 +1,42 @@ +package db + +import ( + "fmt" + "strings" +) + +type QueryBuilder struct { + sql strings.Builder + args []interface{} +} + +/* +Adds the given SQL and arguments to the query. Any occurrences +of `$?` will be replaced with the correct argument number. + +foo $? bar $? baz $? +foo ARG1 bar ARG2 baz $? +foo ARG1 bar ARG2 baz ARG3 +*/ +func (qb *QueryBuilder) Add(sql string, args ...interface{}) { + numPlaceholders := strings.Count(sql, "$?") + if numPlaceholders != len(args) { + panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args))) + } + + for _, arg := range args { + sql = strings.Replace(sql, "$?", fmt.Sprintf("$%d", len(qb.args)+1), 1) + qb.args = append(qb.args, arg) + } + + qb.sql.WriteString(sql) + qb.sql.WriteString("\n") +} + +func (qb *QueryBuilder) String() string { + return qb.sql.String() +} + +func (qb *QueryBuilder) Args() []interface{} { + return qb.args +} diff --git a/src/discord/gateway.go b/src/discord/gateway.go index f48faa4..1e63d22 100644 --- a/src/discord/gateway.go +++ b/src/discord/gateway.go @@ -252,7 +252,7 @@ func (bot *botInstance) connect(ctx context.Context) error { shouldResume := true isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // No session yet! Just identify and get on with it shouldResume = false } else { @@ -630,7 +630,7 @@ func (bot *botInstance) messageDelete(ctx context.Context, msgDelete MessageDele `, msgDelete.ID, msgDelete.ChannelID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return } else if err != nil { log.Error().Err(err).Msg("failed to check for message to delete") diff --git a/src/discord/showcase.go b/src/discord/showcase.go index fa7b267..f2c26c4 100644 --- a/src/discord/showcase.go +++ b/src/discord/showcase.go @@ -133,7 +133,7 @@ func SaveMessage( `, msg.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { if !msg.OriginalHasFields("author", "timestamp") { return nil, errNotEnoughInfo } @@ -213,7 +213,7 @@ func SaveMessageAndContents( `, newMsg.UserID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return newMsg, nil } else if err != nil { return nil, oops.New(err, "failed to look up linked Discord user") @@ -337,7 +337,7 @@ func saveAttachment( ) if err == nil { return iexisting.(*models.DiscordMessageAttachment), nil - } else if errors.Is(err, db.ErrNoMatchingRows) { + } else if errors.Is(err, db.NotFound) { // this is fine, just create it } else { return nil, oops.New(err, "failed to check for existing attachment") @@ -522,7 +522,7 @@ func AllowedToCreateMessageSnippet(ctx context.Context, tx db.ConnOrTx, discordU `, discordUserId, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return false, nil } else if err != nil { return false, oops.New(err, "failed to check if we can save Discord message") diff --git a/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go b/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go new file mode 100644 index 0000000..bf6a624 --- /dev/null +++ b/src/migration/migrations/2021-09-23T041521Z_ApproveActiveUsers.go @@ -0,0 +1,89 @@ +package migrations + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/migration/types" + "git.handmade.network/hmn/hmn/src/oops" + "github.com/jackc/pgx/v4" +) + +func init() { + registerMigration(ApproveActiveUsers{}) +} + +type ApproveActiveUsers struct{} + +func (m ApproveActiveUsers) Version() types.MigrationVersion { + return types.MigrationVersion(time.Date(2021, 9, 23, 4, 15, 21, 0, time.UTC)) +} + +func (m ApproveActiveUsers) Name() string { + return "ApproveActiveUsers" +} + +func (m ApproveActiveUsers) Description() string { + return "Give legit users the Approved status" +} + +func (m ApproveActiveUsers) Up(ctx context.Context, tx pgx.Tx) error { + /* + See models/user.go. + The old statuses were: + 2 = Active + 3 = Banned + The new statuses are: + 2 = Confirmed (valid email) + 3 = Approved (allowed to post) + 4 = Banned + */ + + _, err := tx.Exec(ctx, ` + UPDATE auth_user + SET status = 4 + WHERE status = 3 + `) + if err != nil { + return oops.New(err, "failed to update status of banned users") + } + + _, err = tx.Exec(ctx, ` + UPDATE auth_user + SET status = 3 + WHERE + status = 2 + AND id IN ( + SELECT author_id + FROM handmade_post + WHERE author_id IS NOT NULL + ) + `) + if err != nil { + return oops.New(err, "failed to update user statuses") + } + + return nil +} + +func (m ApproveActiveUsers) Down(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + UPDATE auth_user + SET status = 2 + WHERE status = 3 + `) + if err != nil { + return oops.New(err, "failed to revert approved users back to confirmed") + } + + _, err = tx.Exec(ctx, ` + UPDATE auth_user + SET status = 3 + WHERE status = 4 + `) + if err != nil { + return oops.New(err, "failed to update status of banned users") + } + + return nil +} diff --git a/src/models/post.go b/src/models/post.go index 0799f78..cb4b533 100644 --- a/src/models/post.go +++ b/src/models/post.go @@ -12,6 +12,7 @@ type Post struct { ThreadID int `db:"thread_id"` CurrentID int `db:"current_id"` // The id of the current PostVersion ProjectID int `db:"project_id"` + ReplyID *int `db:"reply_id"` ThreadType ThreadType `db:"thread_type"` diff --git a/src/models/user.go b/src/models/user.go index 58c3a03..4b0699a 100644 --- a/src/models/user.go +++ b/src/models/user.go @@ -10,9 +10,10 @@ var UserType = reflect.TypeOf(User{}) type UserStatus int const ( - UserStatusInactive UserStatus = iota + 1 - UserStatusActive - UserStatusBanned + UserStatusInactive UserStatus = 1 // Default for new users + UserStatusConfirmed = 2 // Confirmed email address + UserStatusApproved = 3 // Approved by an admin and allowed to publicly post + UserStatusBanned = 4 // BALEETED ) type User struct { @@ -54,5 +55,5 @@ func (u *User) BestName() string { } func (u *User) IsActive() bool { - return u.Status == UserStatusActive + return u.Status == UserStatusConfirmed } diff --git a/src/perf/perf.go b/src/perf/perf.go index a7af05f..ee0f55c 100644 --- a/src/perf/perf.go +++ b/src/perf/perf.go @@ -27,12 +27,20 @@ func MakeNewRequestPerf(route string, method string, path string) *RequestPerf { } func (rp *RequestPerf) EndRequest() { + if rp == nil { + return + } + for rp.EndBlock() { } rp.End = time.Now() } func (rp *RequestPerf) Checkpoint(category, description string) { + if rp == nil { + return + } + now := time.Now() checkpoint := PerfBlock{ Start: now, @@ -44,6 +52,10 @@ func (rp *RequestPerf) Checkpoint(category, description string) { } func (rp *RequestPerf) StartBlock(category, description string) { + if rp == nil { + return + } + now := time.Now() checkpoint := PerfBlock{ Start: now, @@ -55,6 +67,10 @@ func (rp *RequestPerf) StartBlock(category, description string) { } func (rp *RequestPerf) EndBlock() bool { + if rp == nil { + return false + } + for i := len(rp.Blocks) - 1; i >= 0; i -= 1 { if rp.Blocks[i].End.Equal(time.Time{}) { rp.Blocks[i].End = time.Now() @@ -65,6 +81,10 @@ func (rp *RequestPerf) EndBlock() bool { } func (rp *RequestPerf) MsFromStart(block *PerfBlock) float64 { + if rp == nil { + return 0 + } + return float64(block.Start.Sub(rp.Start).Nanoseconds()) / 1000 / 1000 } diff --git a/src/website/auth.go b/src/website/auth.go index 57a6c30..5b208fd 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -77,7 +77,7 @@ func Login(c *RequestContext) ResponseData { userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE LOWER(username) = LOWER($1)", username) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return showLoginWithFailure(c, redirect) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) @@ -172,7 +172,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { username, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { userAlreadyExists = false } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) @@ -193,7 +193,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { emailAddress, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { emailAlreadyExists = false } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) @@ -372,7 +372,7 @@ func EmailConfirmationSubmit(c *RequestContext) ResponseData { SET status = $1 WHERE id = $2 `, - models.UserStatusActive, + models.UserStatusConfirmed, validationResult.User.ID, ) if err != nil { @@ -459,7 +459,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } @@ -482,7 +482,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) } } @@ -644,7 +644,7 @@ func DoPasswordResetSubmit(c *RequestContext) ResponseData { SET status = $1 WHERE id = $2 `, - models.UserStatusActive, + models.UserStatusConfirmed, validationResult.User.ID, ) if err != nil { @@ -786,7 +786,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, ) var result validateUserAndTokenResult if err != nil { - if !errors.Is(err, db.ErrNoMatchingRows) { + if !errors.Is(err, db.NotFound) { result.Error = oops.New(err, "failed to fetch user and token from db") return result } diff --git a/src/website/blogs.go b/src/website/blogs.go index 0f01fa7..2ff59ad 100644 --- a/src/website/blogs.go +++ b/src/website/blogs.go @@ -1,6 +1,7 @@ package website import ( + "errors" "fmt" "html/template" "net/http" @@ -34,21 +35,10 @@ func BlogIndex(c *RequestContext) ResponseData { const postsPerPage = 5 - c.Perf.StartBlock("SQL", "Fetch count of posts") - numPosts, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM - handmade_thread - WHERE - project_id = $1 - AND type = $2 - AND NOT deleted - `, - c.CurrentProject.ID, - models.ThreadTypeProjectBlogPost, - ) - c.Perf.EndBlock() + numPosts, err := CountPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch total number of blog posts")) } @@ -59,48 +49,24 @@ func BlogIndex(c *RequestContext) ResponseData { c.Redirect(hmnurl.BuildBlog(c.CurrentProject.Slug, page), http.StatusSeeOther) } - type blogIndexQuery struct { - Thread models.Thread `db:"thread"` - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - } - c.Perf.StartBlock("SQL", "Fetch blog posts") - postsResult, err := db.Query(c.Context(), c.Conn, blogIndexQuery{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS post ON thread.first_id = post.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - WHERE - post.project_id = $1 - AND post.thread_type = $2 - AND NOT thread.deleted - ORDER BY post.postdate DESC - LIMIT $3 OFFSET $4 - `, - c.CurrentProject.ID, - models.ThreadTypeProjectBlogPost, - postsPerPage, - (page-1)*postsPerPage, - ) - c.Perf.EndBlock() + threads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + Limit: postsPerPage, + Offset: (page - 1) * postsPerPage, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch blog posts for index")) } var entries []blogIndexEntry - for _, irow := range postsResult.ToSlice() { - row := irow.(*blogIndexQuery) - + for _, thread := range threads { entries = append(entries, blogIndexEntry{ - Title: row.Thread.Title, - Url: hmnurl.BuildBlogThread(c.CurrentProject.Slug, row.Thread.ID, row.Thread.Title), - Author: templates.UserToTemplate(row.Author, c.Theme), - Date: row.Post.PostDate, - Content: template.HTML(row.CurrentVersion.TextParsed), + Title: thread.Thread.Title, + Url: hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.Thread.ID, thread.Thread.Title), + Author: templates.UserToTemplate(thread.FirstPostAuthor, c.Theme), + Date: thread.FirstPost.PostDate, + Content: template.HTML(thread.FirstPostCurrentVersion.TextParsed), }) } @@ -158,12 +124,15 @@ func BlogThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread, posts, preview := FetchThreadPostsAndStuff( - c.Context(), - c.Conn, - cd.ThreadID, - 0, 0, - ) + thread, posts, err := FetchThreadPosts(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for blog thread")) + } var templatePosts []templates.Post for _, p := range posts { @@ -202,7 +171,7 @@ func BlogThread(c *RequestContext) ResponseData { baseData := getBaseData(c, thread.Title, []templates.Breadcrumb{BlogBreadcrumb(c.CurrentProject.Slug)}) baseData.OpenGraphItems = append(baseData.OpenGraphItems, templates.OpenGraphItem{ Property: "og:description", - Value: preview, + Value: posts[0].Post.Preview, }) var res ResponseData @@ -223,9 +192,17 @@ func BlogPostRedirectToThread(c *RequestContext) ResponseData { return FourOhFour(c) } - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread for blog redirect")) + } - threadUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, thread.Title, cd.PostID) + threadUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, thread.Thread.Title, cd.PostID) return c.Redirect(threadUrl, http.StatusFound) } @@ -305,24 +282,32 @@ func BlogPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to edit")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Editing \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Editing \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } else { title = fmt.Sprintf("Editing Post | %s", c.CurrentProject.Name) } baseData := getBaseData( c, title, - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - editData := getEditorDataForEdit(c.CurrentUser, baseData, postData) + editData := getEditorDataForEdit(c.CurrentUser, baseData, post) editData.SubmitUrl = hmnurl.BuildBlogPostEdit(c.CurrentProject.Slug, cd.ThreadID, cd.PostID) editData.SubmitLabel = "Submit Edited Post" - if postData.Thread.FirstID != postData.Post.ID { + if post.Thread.FirstID != post.Post.ID { editData.SubmitLabel = "Submit Edited Comment" } @@ -347,27 +332,35 @@ func BlogPostEditSubmit(c *RequestContext) ResponseData { } defer tx.Rollback(c.Context()) - postData := FetchPostAndStuff(c.Context(), tx, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to submit edits")) + } c.Req.ParseForm() title := c.Req.Form.Get("title") unparsed := c.Req.Form.Get("body") editReason := c.Req.Form.Get("editreason") - if title != "" && postData.Thread.FirstID != postData.Post.ID { + if title != "" && post.Thread.FirstID != post.Post.ID { return RejectRequest(c, "You can only edit the title by editing the first post.") } if unparsed == "" { return RejectRequest(c, "You must provide a post body.") } - CreatePostVersion(c.Context(), tx, postData.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) + CreatePostVersion(c.Context(), tx, post.Post.ID, unparsed, c.Req.Host, editReason, &c.CurrentUser.ID) err = tx.Commit(c.Context()) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to edit blog post")) } - postUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, postData.Thread.Title, cd.PostID) + postUrl := hmnurl.BuildBlogThreadWithPostHash(c.CurrentProject.Slug, cd.ThreadID, post.Thread.Title, cd.PostID) return c.Redirect(postUrl, http.StatusSeeOther) } @@ -377,16 +370,24 @@ func BlogPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post for reply")) + } baseData := getBaseData( c, - fmt.Sprintf("Replying to comment in \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name), - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + fmt.Sprintf("Replying to comment in \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - replyPost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - replyPost.AddContentVersion(postData.CurrentVersion, postData.Editor) + replyPost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + replyPost.AddContentVersion(post.CurrentVersion, post.Editor) editData := getEditorDataForNew(c.CurrentUser, baseData, &replyPost) editData.SubmitUrl = hmnurl.BuildBlogPostReply(c.CurrentProject.Slug, cd.ThreadID, cd.PostID) @@ -439,23 +440,30 @@ func BlogPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get blog post to delete")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Deleting \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Deleting \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } else { - title = fmt.Sprintf("Deleting comment in \"%s\" | %s", postData.Thread.Title, c.CurrentProject.Name) + title = fmt.Sprintf("Deleting comment in \"%s\" | %s", post.Thread.Title, c.CurrentProject.Name) } baseData := getBaseData( c, title, - BlogThreadBreadcrumbs(c.CurrentProject.Slug, &postData.Thread), + BlogThreadBreadcrumbs(c.CurrentProject.Slug, &post.Thread), ) - // TODO(ben): Set breadcrumbs - templatePost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - templatePost.AddContentVersion(postData.CurrentVersion, postData.Editor) + templatePost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + templatePost.AddContentVersion(post.CurrentVersion, post.Editor) type blogPostDeleteData struct { templates.BaseData @@ -499,8 +507,16 @@ func BlogPostDeleteSubmit(c *RequestContext) ResponseData { projectUrl := hmnurl.BuildProjectHomepage(c.CurrentProject.Slug) return c.Redirect(projectUrl, http.StatusSeeOther) } else { - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) - threadUrl := hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.ID, thread.Title) + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + }) + if errors.Is(err, db.NotFound) { + panic(oops.New(err, "the thread was supposedly not deleted after deleting a post in a blog, but the thread was not found afterwards")) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread after blog post delete")) + } + threadUrl := hmnurl.BuildBlogThread(c.CurrentProject.Slug, thread.Thread.ID, thread.Thread.Title) return c.Redirect(threadUrl, http.StatusSeeOther) } } diff --git a/src/website/discord.go b/src/website/discord.go index 9a75fba..d5439ee 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -97,7 +97,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { c.CurrentUser.ID, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return c.Redirect(hmnurl.BuildUserSettings("discord"), http.StatusSeeOther) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink")) @@ -134,7 +134,7 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { `SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`, c.CurrentUser.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // Nothing to do c.Logger.Warn().Msg("could not do showcase backlog because no discord user exists") return c.Redirect(hmnurl.BuildUserProfile(c.CurrentUser.Username), http.StatusSeeOther) diff --git a/src/website/forums.go b/src/website/forums.go index ed3f9ec..35debc8 100644 --- a/src/website/forums.go +++ b/src/website/forums.go @@ -70,7 +70,7 @@ func getEditorDataForNew(currentUser *models.User, baseData templates.BaseData, return result } -func getEditorDataForEdit(currentUser *models.User, baseData templates.BaseData, p postAndRelatedModels) editorData { +func getEditorDataForEdit(currentUser *models.User, baseData templates.BaseData, p PostAndStuff) editorData { return editorData{ BaseData: baseData, Title: p.Thread.Title, @@ -92,21 +92,14 @@ func Forum(c *RequestContext) ResponseData { currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) - c.Perf.StartBlock("SQL", "Fetch count of page threads") - numThreads, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_thread AS thread - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - `, - cd.SubforumID, - ) + numThreads, err := CountThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{cd.SubforumID}, + }) if err != nil { panic(oops.New(err, "failed to get count of threads")) } - c.Perf.EndBlock() numPages := utils.NumPages(numThreads, threadsPerPage) page, ok := ParsePageNumber(c, "page", numPages) @@ -115,78 +108,28 @@ func Forum(c *RequestContext) ResponseData { } howManyThreadsToSkip := (page - 1) * threadsPerPage - var currentUserId *int - if c.CurrentUser != nil { - currentUserId = &c.CurrentUser.ID - } - - c.Perf.StartBlock("SQL", "Fetch page threads") - type threadQueryResult struct { - Thread models.Thread `db:"thread"` - FirstPost models.Post `db:"firstpost"` - LastPost models.Post `db:"lastpost"` - FirstUser *models.User `db:"firstuser"` - LastUser *models.User `db:"lastuser"` - ThreadLastReadTime *time.Time `db:"tlri.lastread"` - ForumLastReadTime *time.Time `db:"slri.lastread"` - } - itMainThreads, err := db.Query(c.Context(), c.Conn, threadQueryResult{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS firstpost ON thread.first_id = firstpost.id - JOIN handmade_post AS lastpost ON thread.last_id = lastpost.id - LEFT JOIN auth_user AS firstuser ON firstpost.author_id = firstuser.id - LEFT JOIN auth_user AS lastuser ON lastpost.author_id = lastuser.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = thread.id - AND tlri.user_id = $2 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = $1 - AND slri.user_id = $2 - ) - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - ORDER BY lastpost.postdate DESC - LIMIT $3 OFFSET $4 - `, - cd.SubforumID, - currentUserId, - threadsPerPage, - howManyThreadsToSkip, - ) - if err != nil { - panic(oops.New(err, "failed to fetch threads")) - } - c.Perf.EndBlock() - defer itMainThreads.Close() - - makeThreadListItem := func(row *threadQueryResult) templates.ThreadListItem { - hasRead := false - if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.LastPost.PostDate) { - hasRead = true - } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.LastPost.PostDate) { - hasRead = true - } + mainThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{cd.SubforumID}, + Limit: threadsPerPage, + Offset: howManyThreadsToSkip, + }) + makeThreadListItem := func(row ThreadAndStuff) templates.ThreadListItem { return templates.ThreadListItem{ Title: row.Thread.Title, Url: hmnurl.BuildForumThread(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(*row.Thread.SubforumID), row.Thread.ID, row.Thread.Title, 1), - FirstUser: templates.UserToTemplate(row.FirstUser, c.Theme), + FirstUser: templates.UserToTemplate(row.FirstPostAuthor, c.Theme), FirstDate: row.FirstPost.PostDate, - LastUser: templates.UserToTemplate(row.LastUser, c.Theme), + LastUser: templates.UserToTemplate(row.LastPostAuthor, c.Theme), LastDate: row.LastPost.PostDate, - - Unread: !hasRead, + Unread: row.Unread, } } var threads []templates.ThreadListItem - for _, irow := range itMainThreads.ToSlice() { - row := irow.(*threadQueryResult) + for _, row := range mainThreads { threads = append(threads, makeThreadListItem(row)) } @@ -199,59 +142,25 @@ func Forum(c *RequestContext) ResponseData { subforumNodes := cd.SubforumTree[cd.SubforumID].Children for _, sfNode := range subforumNodes { - c.Perf.StartBlock("SQL", "Fetch count of subforum threads") - numThreads, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_thread AS thread - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - `, - sfNode.ID, - ) + numThreads, err := CountThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{sfNode.ID}, + }) if err != nil { panic(oops.New(err, "failed to get count of threads")) } - c.Perf.EndBlock() - c.Perf.StartBlock("SQL", "Fetch subforum threads") - itThreads, err := db.Query(c.Context(), c.Conn, threadQueryResult{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS firstpost ON thread.first_id = firstpost.id - JOIN handmade_post AS lastpost ON thread.last_id = lastpost.id - LEFT JOIN auth_user AS firstuser ON firstpost.author_id = firstuser.id - LEFT JOIN auth_user AS lastuser ON lastpost.author_id = lastuser.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = thread.id - AND tlri.user_id = $2 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = $1 - AND slri.user_id = $2 - ) - WHERE - thread.subforum_id = $1 - AND NOT thread.deleted - ORDER BY lastpost.postdate DESC - LIMIT 3 - `, - sfNode.ID, - currentUserId, - ) - if err != nil { - panic(err) - } - defer itThreads.Close() - c.Perf.EndBlock() + subforumThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + SubforumIDs: []int{sfNode.ID}, + Limit: 3, + }) var threads []templates.ThreadListItem - for _, irow := range itThreads.ToSlice() { - threadRow := irow.(*threadQueryResult) - threads = append(threads, makeThreadListItem(threadRow)) + for _, row := range subforumThreads { + threads = append(threads, makeThreadListItem(row)) } subforums = append(subforums, forumSubforumData{ @@ -415,7 +324,8 @@ type forumThreadData struct { Pagination templates.Pagination } -var threadViewPostsPerPage = 15 +// How many posts to display on a single page of a forum thread. +var threadPostsPerPage = 15 func ForumThread(c *RequestContext) ResponseData { cd, ok := getCommonForumData(c) @@ -425,22 +335,28 @@ func ForumThread(c *RequestContext) ResponseData { currentSubforumSlugs := cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID) - thread := FetchThread(c.Context(), c.Conn, cd.ThreadID) + threads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadIDs: []int{cd.ThreadID}, + }) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get thread")) + } + if len(threads) == 0 { + return FourOhFour(c) + } + threadResult := threads[0] + thread := threadResult.Thread - numPosts, err := db.QueryInt(c.Context(), c.Conn, - ` - SELECT COUNT(*) - FROM handmade_post - WHERE - thread_id = $1 - AND NOT deleted - `, - thread.ID, - ) + numPosts, err := CountPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + ThreadIDs: []int{cd.ThreadID}, + }) if err != nil { panic(oops.New(err, "failed to get count of posts for thread")) } - page, numPages, ok := getPageInfo(c.PathParams["page"], numPosts, threadViewPostsPerPage) + page, numPages, ok := getPageInfo(c.PathParams["page"], numPosts, threadPostsPerPage) if !ok { urlNoPage := hmnurl.BuildForumThread(c.CurrentProject.Slug, currentSubforumSlugs, thread.ID, thread.Title, 1) return c.Redirect(urlNoPage, http.StatusSeeOther) @@ -455,16 +371,16 @@ func ForumThread(c *RequestContext) ResponseData { PreviousUrl: hmnurl.BuildForumThread(c.CurrentProject.Slug, currentSubforumSlugs, thread.ID, thread.Title, utils.IntClamp(1, page-1, numPages)), } - c.Perf.StartBlock("SQL", "Fetch posts") - _, postsAndStuff, preview := FetchThreadPostsAndStuff( - c.Context(), - c.Conn, - cd.ThreadID, - page, threadViewPostsPerPage, - ) - c.Perf.EndBlock() + postsAndStuff, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadIDs: []int{thread.ID}, + Limit: threadPostsPerPage, + Offset: (page - 1) * threadPostsPerPage, + }) + if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread posts")) + } - c.Perf.StartBlock("TEMPLATE", "Create template posts") var posts []templates.Post for _, p := range postsAndStuff { post := templates.PostToTemplate(&p.Post, p.Author, c.Theme) @@ -481,18 +397,17 @@ func ForumThread(c *RequestContext) ResponseData { posts = append(posts, post) } - c.Perf.EndBlock() // Update thread last read info if c.CurrentUser != nil { c.Perf.StartBlock("SQL", "Update TLRI") _, err = c.Conn.Exec(c.Context(), ` - INSERT INTO handmade_threadlastreadinfo (thread_id, user_id, lastread) - VALUES ($1, $2, $3) - ON CONFLICT (thread_id, user_id) DO UPDATE - SET lastread = EXCLUDED.lastread - `, + INSERT INTO handmade_threadlastreadinfo (thread_id, user_id, lastread) + VALUES ($1, $2, $3) + ON CONFLICT (thread_id, user_id) DO UPDATE + SET lastread = EXCLUDED.lastread + `, cd.ThreadID, c.CurrentUser.ID, time.Now(), @@ -506,7 +421,7 @@ func ForumThread(c *RequestContext) ResponseData { baseData := getBaseData(c, thread.Title, SubforumBreadcrumbs(cd.LineageBuilder, c.CurrentProject, cd.SubforumID)) baseData.OpenGraphItems = append(baseData.OpenGraphItems, templates.OpenGraphItem{ Property: "og:description", - Value: preview, + Value: threadResult.FirstPost.Preview, }) var res ResponseData @@ -527,30 +442,20 @@ func ForumPostRedirect(c *RequestContext) ResponseData { return FourOhFour(c) } - c.Perf.StartBlock("SQL", "Fetch post ids for thread") - type postQuery struct { - PostID int `db:"post.id"` - } - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - WHERE - post.thread_id = $1 - AND NOT post.deleted - ORDER BY postdate - `, - cd.ThreadID, - ) + posts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + ThreadIDs: []int{cd.ThreadID}, + }) if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post ids")) + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for redirect")) } - postQuerySlice := postQueryResult.ToSlice() - c.Perf.EndBlock() + + var post PostAndStuff postIdx := -1 - for i, id := range postQuerySlice { - if id.(*postQuery).PostID == cd.PostID { + for i, p := range posts { + if p.Post.ID == cd.PostID { + post = p postIdx = i break } @@ -559,31 +464,13 @@ func ForumPostRedirect(c *RequestContext) ResponseData { return FourOhFour(c) } - c.Perf.StartBlock("SQL", "Fetch thread title") - type threadTitleQuery struct { - ThreadTitle string `db:"thread.title"` - } - threadTitleQueryResult, err := db.QueryOne(c.Context(), c.Conn, threadTitleQuery{}, - ` - SELECT $columns - FROM handmade_thread AS thread - WHERE thread.id = $1 - `, - cd.ThreadID, - ) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch thread title")) - } - c.Perf.EndBlock() - threadTitle := threadTitleQueryResult.(*threadTitleQuery).ThreadTitle - - page := (postIdx / threadViewPostsPerPage) + 1 + page := (postIdx / threadPostsPerPage) + 1 return c.Redirect(hmnurl.BuildForumThreadWithPostHash( c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, - threadTitle, + post.Thread.Title, page, cd.PostID, ), http.StatusSeeOther) @@ -672,16 +559,24 @@ func ForumPostReply(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for reply")) + } baseData := getBaseData( c, fmt.Sprintf("Replying to post | %s", cd.SubforumTree[cd.SubforumID].Name), - ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread), + ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread), ) - replyPost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - replyPost.AddContentVersion(postData.CurrentVersion, postData.Editor) + replyPost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + replyPost.AddContentVersion(post.CurrentVersion, post.Editor) editData := getEditorDataForNew(c.CurrentUser, baseData, &replyPost) editData.SubmitUrl = hmnurl.BuildForumPostReply(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, cd.PostID) @@ -713,11 +608,17 @@ func ForumPostReplySubmit(c *RequestContext) ResponseData { return RejectRequest(c, "Your reply cannot be empty.") } - thread := FetchThread(c.Context(), tx, cd.ThreadID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } // Replies to the OP should not be considered replies var replyPostId *int - if cd.PostID != thread.FirstID { + if cd.PostID != post.Thread.FirstID { replyPostId = &cd.PostID } @@ -742,17 +643,25 @@ func ForumPostEdit(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for editing")) + } title := "" - if postData.Thread.FirstID == postData.Post.ID { - title = fmt.Sprintf("Editing \"%s\" | %s", postData.Thread.Title, cd.SubforumTree[cd.SubforumID].Name) + if post.Thread.FirstID == post.Post.ID { + title = fmt.Sprintf("Editing \"%s\" | %s", post.Thread.Title, cd.SubforumTree[cd.SubforumID].Name) } else { title = fmt.Sprintf("Editing Post | %s", cd.SubforumTree[cd.SubforumID].Name) } - baseData := getBaseData(c, title, ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread)) + baseData := getBaseData(c, title, ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread)) - editData := getEditorDataForEdit(c.CurrentUser, baseData, postData) + editData := getEditorDataForEdit(c.CurrentUser, baseData, post) editData.SubmitUrl = hmnurl.BuildForumPostEdit(c.CurrentProject.Slug, cd.LineageBuilder.GetSubforumLineageSlugs(cd.SubforumID), cd.ThreadID, cd.PostID) editData.SubmitLabel = "Submit Edited Post" @@ -805,16 +714,24 @@ func ForumPostDelete(c *RequestContext) ResponseData { return FourOhFour(c) } - postData := FetchPostAndStuff(c.Context(), c.Conn, cd.ThreadID, cd.PostID) + post, err := FetchThreadPost(c.Context(), c.Conn, c.CurrentUser, cd.ThreadID, cd.PostID, PostsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeForumPost}, + }) + if errors.Is(err, db.NotFound) { + return FourOhFour(c) + } else if err != nil { + return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch post for delete")) + } baseData := getBaseData( c, - fmt.Sprintf("Deleting post in \"%s\" | %s", postData.Thread.Title, cd.SubforumTree[cd.SubforumID].Name), - ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &postData.Thread), + fmt.Sprintf("Deleting post in \"%s\" | %s", post.Thread.Title, cd.SubforumTree[cd.SubforumID].Name), + ForumThreadBreadcrumbs(cd.LineageBuilder, c.CurrentProject, &post.Thread), ) - templatePost := templates.PostToTemplate(&postData.Post, postData.Author, c.Theme) - templatePost.AddContentVersion(postData.CurrentVersion, postData.Editor) + templatePost := templates.PostToTemplate(&post.Post, post.Author, c.Theme) + templatePost.AddContentVersion(post.CurrentVersion, post.Editor) type forumPostDeleteData struct { templates.BaseData @@ -870,31 +787,22 @@ func WikiArticleRedirect(c *RequestContext) ResponseData { panic(err) } - ithread, err := db.QueryOne(c.Context(), c.Conn, models.Thread{}, - ` - SELECT $columns - FROM handmade_thread - WHERE - id = $1 - AND project_id = $2 - AND NOT deleted - `, - threadId, - c.CurrentProject.ID, - ) - if errors.Is(err, db.ErrNoMatchingRows) { + thread, err := FetchThread(c.Context(), c.Conn, c.CurrentUser, threadId, ThreadsQuery{ + ProjectIDs: []int{c.CurrentProject.ID}, + // This is the rare query where we want all thread types! + }) + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up wiki thread")) } - thread := ithread.(*models.Thread) c.Perf.StartBlock("SQL", "Fetch subforum tree") subforumTree := models.GetFullSubforumTree(c.Context(), c.Conn) lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() - dest := UrlForGenericThread(thread, lineageBuilder, c.CurrentProject.Slug) + dest := UrlForGenericThread(&thread.Thread, lineageBuilder, c.CurrentProject.Slug) return c.Redirect(dest, http.StatusFound) } @@ -911,9 +819,11 @@ type commonForumData struct { /* Gets data that is used on basically every forums-related route. Parses path params for subforum, -thread, and post ids and validates that all those resources do in fact exist. +thread, and post ids. -Returns false if any data is invalid and you should return a 404. +Does NOT validate that the requested thread and post ID are valid. + +If this returns false, then something was malformed and you should 404. */ func getCommonForumData(c *RequestContext) (commonForumData, bool) { c.Perf.StartBlock("FORUMS", "Fetch common forum data") @@ -936,8 +846,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.SubforumID = sfId - - // No need to validate that subforum exists here; it's handled by validateSubforums. } if threadIdStr, hasThreadId := c.PathParams["threadid"]; hasThreadId { @@ -946,27 +854,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.ThreadID = threadId - - c.Perf.StartBlock("SQL", "Verify that the thread exists") - threadExists, err := db.QueryBool(c.Context(), c.Conn, - ` - SELECT COUNT(*) > 0 - FROM handmade_thread - WHERE - id = $1 - AND subforum_id = $2 - AND NOT deleted - `, - res.ThreadID, - res.SubforumID, - ) - c.Perf.EndBlock() - if err != nil { - panic(err) - } - if !threadExists { - return commonForumData{}, false - } } if postIdStr, hasPostId := c.PathParams["postid"]; hasPostId { @@ -975,27 +862,6 @@ func getCommonForumData(c *RequestContext) (commonForumData, bool) { return commonForumData{}, false } res.PostID = postId - - c.Perf.StartBlock("SQL", "Verify that the post exists") - postExists, err := db.QueryBool(c.Context(), c.Conn, - ` - SELECT COUNT(*) > 0 - FROM handmade_post - WHERE - id = $1 - AND thread_id = $2 - AND NOT deleted - `, - res.PostID, - res.ThreadID, - ) - c.Perf.EndBlock() - if err != nil { - panic(err) - } - if !postExists { - return commonForumData{}, false - } } return res, true @@ -1042,6 +908,8 @@ func addForumUrlsToPost(p *templates.Post, projectSlug string, subforums []strin p.ReplyUrl = hmnurl.BuildForumPostReply(projectSlug, subforums, threadId, postId) } +// Takes a template post and adds information about how many posts the user has made +// on the site. func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) { numPosts, err := db.QueryInt(ctx, conn, ` diff --git a/src/website/landing.go b/src/website/landing.go index fef6044..badc34b 100644 --- a/src/website/landing.go +++ b/src/website/landing.go @@ -81,56 +81,22 @@ func Index(c *RequestContext) ResponseData { lineageBuilder := models.MakeSubforumLineageBuilder(subforumTree) c.Perf.EndBlock() - var currentUserId *int - if c.CurrentUser != nil { - currentUserId = &c.CurrentUser.ID - } - c.Perf.StartBlock("LANDING", "Process projects") for _, projRow := range allProjects { proj := projRow.(*models.Project) c.Perf.StartBlock("SQL", fmt.Sprintf("Fetch posts for %s", proj.Name)) - type projectPostQuery struct { - Post models.Post `db:"post"` - Thread models.Thread `db:"thread"` - User models.User `db:"auth_user"` - ThreadLastReadTime *time.Time `db:"tlri.lastread"` - SubforumLastReadTime *time.Time `db:"slri.lastread"` - } - projectPostIter, err := db.Query(c.Context(), c.Conn, projectPostQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - JOIN handmade_thread AS thread ON post.thread_id = thread.id - LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( - tlri.thread_id = post.thread_id - AND tlri.user_id = $1 - ) - LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( - slri.subforum_id = thread.subforum_id - AND slri.user_id = $1 - ) - LEFT JOIN auth_user ON post.author_id = auth_user.id - WHERE - post.project_id = $2 - AND post.thread_type = ANY ($3) - AND post.deleted = FALSE - ORDER BY postdate DESC - LIMIT $4 - `, - currentUserId, - proj.ID, - []models.ThreadType{models.ThreadTypeProjectBlogPost, models.ThreadTypeForumPost}, - maxPosts, - ) + projectPosts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + ProjectIDs: []int{proj.ID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost, models.ThreadTypeForumPost}, + Limit: maxPosts, + SortDescending: true, + }) c.Perf.EndBlock() if err != nil { c.Logger.Error().Err(err).Msg("failed to fetch project posts") continue } - projectPosts := projectPostIter.ToSlice() forumsUrl := "" if proj.ForumID != nil { @@ -144,18 +110,7 @@ func Index(c *RequestContext) ResponseData { ForumsUrl: forumsUrl, } - for _, projectPostRow := range projectPosts { - projectPost := projectPostRow.(*projectPostQuery) - - hasRead := false - if c.CurrentUser != nil && c.CurrentUser.MarkedAllReadAt.After(projectPost.Post.PostDate) { - hasRead = true - } else if projectPost.ThreadLastReadTime != nil && projectPost.ThreadLastReadTime.After(projectPost.Post.PostDate) { - hasRead = true - } else if projectPost.SubforumLastReadTime != nil && projectPost.SubforumLastReadTime.After(projectPost.Post.PostDate) { - hasRead = true - } - + for _, projectPost := range projectPosts { featurable := (!proj.IsHMN() && projectPost.Post.ThreadType == models.ThreadTypeProjectBlogPost && projectPost.Thread.FirstID == projectPost.Post.ID && @@ -185,9 +140,9 @@ func Index(c *RequestContext) ResponseData { landingPageProject.FeaturedPost = &LandingPageFeaturedPost{ Title: projectPost.Thread.Title, Url: hmnurl.BuildBlogThread(proj.Slug, projectPost.Thread.ID, projectPost.Thread.Title), - User: templates.UserToTemplate(&projectPost.User, c.Theme), + User: templates.UserToTemplate(projectPost.Author, c.Theme), Date: projectPost.Post.PostDate, - Unread: !hasRead, + Unread: projectPost.ThreadUnread, Content: template.HTML(content), } } else { @@ -198,8 +153,8 @@ func Index(c *RequestContext) ResponseData { proj, &projectPost.Thread, &projectPost.Post, - &projectPost.User, - !hasRead, + projectPost.Author, + projectPost.ThreadUnread, false, c.Theme, ), @@ -248,35 +203,18 @@ func Index(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Get news") - type newsPostQuery struct { - Post models.Post `db:"post"` - PostVersion models.PostVersion `db:"ver"` - Thread models.Thread `db:"thread"` - User models.User `db:"auth_user"` - } - newsPostRow, err := db.QueryOne(c.Context(), c.Conn, newsPostQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - JOIN handmade_thread AS thread ON post.thread_id = thread.id - JOIN auth_user ON post.author_id = auth_user.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - WHERE - post.project_id = $1 - AND thread.type = $2 - AND post.id = thread.first_id - AND NOT thread.deleted - ORDER BY post.postdate DESC - LIMIT 1 - `, - models.HMNProjectID, - models.ThreadTypeProjectBlogPost, - ) + newsThreads, err := FetchThreads(c.Context(), c.Conn, c.CurrentUser, ThreadsQuery{ + ProjectIDs: []int{models.HMNProjectID}, + ThreadTypes: []models.ThreadType{models.ThreadTypeProjectBlogPost}, + Limit: 1, + }) if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch news post")) } - newsPostResult := newsPostRow.(*newsPostQuery) + var newsThread ThreadAndStuff + if len(newsThreads) > 0 { + newsThread = newsThreads[0] + } c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetch showcase snippets") @@ -334,12 +272,12 @@ func Index(c *RequestContext) ResponseData { ShowUrl: "https://handmadedev.show/", ShowcaseUrl: hmnurl.BuildShowcase(), NewsPost: LandingPageFeaturedPost{ - Title: newsPostResult.Thread.Title, - Url: hmnurl.BuildBlogThread(models.HMNProjectSlug, newsPostResult.Thread.ID, newsPostResult.Thread.Title), - User: templates.UserToTemplate(&newsPostResult.User, c.Theme), - Date: newsPostResult.Post.PostDate, + Title: newsThread.Thread.Title, + Url: hmnurl.BuildBlogThread(models.HMNProjectSlug, newsThread.Thread.ID, newsThread.Thread.Title), + User: templates.UserToTemplate(newsThread.FirstPostAuthor, c.Theme), + Date: newsThread.FirstPost.PostDate, Unread: true, - Content: template.HTML(newsPostResult.PostVersion.TextParsed), + Content: template.HTML(newsThread.FirstPostCurrentVersion.TextParsed), }, PostColumns: cols, ShowcaseTimelineJson: showcaseJson, diff --git a/src/website/podcast.go b/src/website/podcast.go index a8c21ff..40feb4a 100644 --- a/src/website/podcast.go +++ b/src/website/podcast.go @@ -565,7 +565,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return result, nil } else { return result, oops.New(err, "failed to fetch podcast") @@ -615,7 +615,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return result, nil } else { return result, oops.New(err, "failed to fetch podcast episode") diff --git a/src/website/projects.go b/src/website/projects.go index 9eeeb5a..de55395 100644 --- a/src/website/projects.go +++ b/src/website/projects.go @@ -234,7 +234,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch project by slug")) diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index 208efd3..6ca84ec 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -128,10 +128,15 @@ type RequestContext struct { Theme string Perf *perf.RequestPerf + + ctx context.Context } func (c *RequestContext) Context() context.Context { - return c.Req.Context() + if c.ctx == nil { + c.ctx = c.Req.Context() + } + return c.ctx } func (c *RequestContext) URL() *url.URL { diff --git a/src/website/routes.go b/src/website/routes.go index 5c358de..95ff9e9 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -376,7 +376,7 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* if err == nil { subdomainProject := subdomainProjectRow.(*models.Project) return subdomainProject, nil - } else if !errors.Is(err, db.ErrNoMatchingRows) { + } else if !errors.Is(err, db.NotFound) { return nil, oops.New(err, "failed to get projects by slug") } else { return nil, nil @@ -384,7 +384,7 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* } else { defaultProjectRow, err := db.QueryOne(ctx, conn, models.Project{}, "SELECT $columns FROM handmade_project WHERE id = $1", models.HMNProjectID) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return nil, oops.New(nil, "default project didn't exist in the database") } else { return nil, oops.New(err, "failed to get default project") @@ -546,7 +546,7 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found") return nil, nil, nil // user was deleted or something } else { @@ -558,8 +558,12 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User return user, session, nil } +const PerfContextKey = "HMNPerf" + func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (after func()) { c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path) + c.ctx = context.WithValue(c.Context(), PerfContextKey, c.Perf) + return func() { c.Perf.EndRequest() log := logging.Info() @@ -576,6 +580,14 @@ func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (aft } } +func ExtractPerf(ctx context.Context) *perf.RequestPerf { + iperf := ctx.Value(PerfContextKey) + if iperf == nil { + return nil + } + return iperf.(*perf.RequestPerf) +} + func LogContextErrors(c *RequestContext, errs ...error) { for _, err := range errs { c.Logger.Error().Timestamp().Stack().Str("Requested", c.FullUrl()).Err(err).Msg("error occurred during request") diff --git a/src/website/snippet.go b/src/website/snippet.go index a539723..1d9ce7f 100644 --- a/src/website/snippet.go +++ b/src/website/snippet.go @@ -50,7 +50,7 @@ func Snippet(c *RequestContext) ResponseData { snippetId, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch snippet")) diff --git a/src/website/threads_and_posts_helper.go b/src/website/threads_and_posts_helper.go index 87c6008..ea65786 100644 --- a/src/website/threads_and_posts_helper.go +++ b/src/website/threads_and_posts_helper.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math" "net" "strings" "time" @@ -19,198 +18,574 @@ import ( "github.com/jackc/pgx/v4" ) -type postAndRelatedModels struct { - Thread models.Thread - Post models.Post - CurrentVersion models.PostVersion +type ThreadsQuery struct { + // Available on all thread queries. + ProjectIDs []int // if empty, all projects + ThreadTypes []models.ThreadType // if empty, all types (you do not want to do this) + SubforumIDs []int // if empty, all subforums - Author *models.User - Editor *models.User + // Ignored when using FetchThread. + ThreadIDs []int - ReplyPost *models.Post - ReplyAuthor *models.User + // Ignored when using FetchThread or CountThreads. + Limit, Offset int // if empty, no pagination +} + +type ThreadAndStuff struct { + Project models.Project `db:"project"` + Thread models.Thread `db:"thread"` + FirstPost models.Post `db:"first_post"` + LastPost models.Post `db:"last_post"` + FirstPostCurrentVersion models.PostVersion `db:"first_version"` + LastPostCurrentVersion models.PostVersion `db:"last_version"` + FirstPostAuthor *models.User `db:"first_author"` // Can be nil in case of a deleted user + LastPostAuthor *models.User `db:"last_author"` // Can be nil in case of a deleted user + Unread bool } /* -Fetches the thread defined by your (already parsed) path params. - -YOU MUST VERIFY THAT THE THREAD ID IS VALID BEFORE CALLING THIS FUNCTION. It will -not check, for example, that the thread belongs to the correct subforum. +Fetches threads and related models from the database according to all the given +query params. For the most correct results, provide as much information as you have +on hand. */ -func FetchThread(ctx context.Context, connOrTx db.ConnOrTx, threadId int) models.Thread { - type threadQueryResult struct { - Thread models.Thread `db:"thread"` +func FetchThreads( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q ThreadsQuery, +) ([]ThreadAndStuff, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Fetch threads") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID } - irow, err := db.QueryOne(ctx, connOrTx, threadQueryResult{}, + + qb.Add( ` SELECT $columns FROM handmade_thread AS thread + JOIN handmade_project AS project ON thread.project_id = project.id + JOIN handmade_post AS first_post ON first_post.id = thread.first_id + JOIN handmade_post AS last_post ON last_post.id = thread.last_id + JOIN handmade_postversion AS first_version ON first_version.id = first_post.current_id + JOIN handmade_postversion AS last_version ON last_version.id = last_post.current_id + LEFT JOIN auth_user AS first_author ON first_author.id = first_post.author_id + LEFT JOIN auth_user AS last_author ON last_author.id = last_post.author_id + LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( + tlri.thread_id = thread.id + AND tlri.user_id = $? + ) + LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( + slri.subforum_id = thread.subforum_id + AND slri.user_id = $? + ) WHERE - id = $1 - AND NOT deleted + NOT thread.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) `, - threadId, + currentUserID, + currentUserID, + models.VisibleProjectLifecycles, + models.HMNProjectID, ) - if err != nil { - // We shouldn't encounter db.ErrNoMatchingRows, because validation should have verified that everything exists. - panic(oops.New(err, "failed to fetch thread")) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.SubforumIDs) > 0 { + qb.Add(`AND thread.subforum_id = ANY ($?)`, q.SubforumIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND thread.id = ANY ($?)`, q.ThreadIDs) + } + if currentUser == nil { + qb.Add( + `AND first_author.status = $? -- thread author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + first_author.status = $? -- thread author is Approved + OR first_author.id = $? -- current user is the thread author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + qb.Add(`ORDER BY last_post.postdate DESC`) + if q.Limit > 0 { + qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - thread := irow.(*threadQueryResult).Thread - return thread -} - -/* -Fetches the post, the thread, and author / editor information for the post defined in -your path params. - -YOU MUST VERIFY THAT THE THREAD ID AND POST ID ARE VALID BEFORE CALLING THIS FUNCTION. -It will not check that the post belongs to the correct subforum, for example, or the -correct project blog. This logic varies per route and per use of threads, so it doesn't -happen here. -*/ -func FetchPostAndStuff( - ctx context.Context, - connOrTx db.ConnOrTx, - threadId, postId int, -) postAndRelatedModels { type resultRow struct { - Thread models.Thread `db:"thread"` - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - Editor *models.User `db:"editor"` - ReplyPost *models.Post `db:"reply"` - ReplyAuthor *models.User `db:"reply_author"` - } - postQueryResult, err := db.QueryOne(ctx, connOrTx, resultRow{}, - ` - SELECT $columns - FROM - handmade_thread AS thread - JOIN handmade_post AS post ON post.thread_id = thread.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id - LEFT JOIN handmade_post AS reply ON post.reply_id = reply.id - LEFT JOIN auth_user AS reply_author ON reply.author_id = reply_author.id - WHERE - post.thread_id = $1 - AND post.id = $2 - AND NOT post.deleted - `, - threadId, - postId, - ) - if err != nil { - // We shouldn't encounter db.ErrNoMatchingRows, because validation should have verified that everything exists. - panic(oops.New(err, "failed to fetch post and related data")) + ThreadAndStuff + ThreadLastReadTime *time.Time `db:"tlri.lastread"` + ForumLastReadTime *time.Time `db:"slri.lastread"` } - result := postQueryResult.(*resultRow) - return postAndRelatedModels{ - Thread: result.Thread, - Post: result.Post, - CurrentVersion: result.CurrentVersion, - Author: result.Author, - Editor: result.Editor, - ReplyPost: result.ReplyPost, - ReplyAuthor: result.ReplyAuthor, + it, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + if err != nil { + return nil, oops.New(err, "failed to fetch threads") } + iresults := it.ToSlice() + + result := make([]ThreadAndStuff, len(iresults)) + for i, iresult := range iresults { + row := *iresult.(*resultRow) + + hasRead := false + if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.LastPost.PostDate) { + hasRead = true + } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.LastPost.PostDate) { + hasRead = true + } + row.Unread = !hasRead + + result[i] = row.ThreadAndStuff + } + + return result, nil } /* -Fetches all the posts (and related models) for a given thread. +Fetches a single thread and related data. A wrapper around FetchThreads. +As with FetchThreads, provide as much information as you know to get the +most correct results. -YOU MUST VERIFY THAT THE THREAD ID IS VALID BEFORE CALLING THIS FUNCTION. It will -not check, for example, that the thread belongs to the correct subforum. +Returns db.NotFound if no result is found. */ -func FetchThreadPostsAndStuff( +func FetchThread( ctx context.Context, - connOrTx db.ConnOrTx, - threadId int, - page, postsPerPage int, -) (models.Thread, []postAndRelatedModels, string) { - limit := postsPerPage - offset := (page - 1) * postsPerPage - if postsPerPage == 0 { - limit = math.MaxInt32 - offset = 0 + dbConn db.ConnOrTx, + currentUser *models.User, + threadID int, + q ThreadsQuery, +) (ThreadAndStuff, error) { + q.ThreadIDs = []int{threadID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchThreads(ctx, dbConn, currentUser, q) + if err != nil { + return ThreadAndStuff{}, oops.New(err, "failed to fetch thread") } - thread := FetchThread(ctx, connOrTx, threadId) - - type postResult struct { - Post models.Post `db:"post"` - CurrentVersion models.PostVersion `db:"ver"` - Author *models.User `db:"author"` - Editor *models.User `db:"editor"` - ReplyPost *models.Post `db:"reply"` - ReplyAuthor *models.User `db:"reply_author"` + if len(res) == 0 { + return ThreadAndStuff{}, db.NotFound } - itPosts, err := db.Query(ctx, connOrTx, postResult{}, + + return res[0], nil +} + +func CountThreads( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q ThreadsQuery, +) (int, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Count threads") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + qb.Add( + ` + SELECT COUNT(*) + FROM + handmade_thread AS thread + JOIN handmade_project AS project ON thread.project_id = project.id + JOIN handmade_post AS first_post ON first_post.id = thread.first_id + LEFT JOIN auth_user AS first_author ON first_author.id = first_post.author_id + WHERE + NOT thread.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) + `, + models.VisibleProjectLifecycles, + models.HMNProjectID, + ) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.SubforumIDs) > 0 { + qb.Add(`AND thread.subforum_id = ANY ($?)`, q.SubforumIDs) + } + if currentUser == nil { + qb.Add( + `AND first_author.status = $? -- thread author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + first_author.status = $? -- thread author is Approved + OR first_author.id = $? -- current user is the thread author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + + count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + if err != nil { + return 0, oops.New(err, "failed to fetch count of threads") + } + + return count, nil +} + +type PostsQuery struct { + // Available on all post queries. + ProjectIDs []int + UserIDs []int + ThreadTypes []models.ThreadType + + // Ignored when using FetchPost. + ThreadIDs []int + PostIDs []int + + // Ignored when using FetchPost or CountPosts. + Limit, Offset int + SortDescending bool +} + +type PostAndStuff struct { + Project models.Project `db:"project"` + Thread models.Thread `db:"thread"` + ThreadUnread bool + Post models.Post `db:"post"` + CurrentVersion models.PostVersion `db:"ver"` + Author *models.User `db:"author"` // Can be nil in case of a deleted user + Editor *models.User `db:"editor"` + ReplyPost *models.Post `db:"reply_post"` + ReplyAuthor *models.User `db:"reply_author"` +} + +/* +Fetches posts and related models from the database according to all the given +query params. For the most correct results, provide as much information as you have +on hand. +*/ +func FetchPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q PostsQuery, +) ([]PostAndStuff, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Fetch posts") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + type resultRow struct { + PostAndStuff + ThreadLastReadTime *time.Time `db:"tlri.lastread"` + ForumLastReadTime *time.Time `db:"slri.lastread"` + } + + qb.Add( ` SELECT $columns - FROM - handmade_post AS post - JOIN handmade_postversion AS ver ON post.current_id = ver.id - LEFT JOIN auth_user AS author ON post.author_id = author.id - LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id - LEFT JOIN handmade_post AS reply ON post.reply_id = reply.id - LEFT JOIN auth_user AS reply_author ON reply.author_id = reply_author.id - WHERE - post.thread_id = $1 - AND NOT post.deleted - ORDER BY post.postdate - LIMIT $2 OFFSET $3 - `, - thread.ID, - limit, - offset, - ) - if err != nil { - panic(oops.New(err, "failed to fetch posts for thread")) - } - defer itPosts.Close() - - var posts []postAndRelatedModels - for { - irow, hasNext := itPosts.Next() - if !hasNext { - break - } - - row := irow.(*postResult) - posts = append(posts, postAndRelatedModels{ - Thread: thread, - Post: row.Post, - CurrentVersion: row.CurrentVersion, - Author: row.Author, - Editor: row.Editor, - ReplyPost: row.ReplyPost, - ReplyAuthor: row.ReplyAuthor, - }) - } - - preview, err := db.QueryString(ctx, connOrTx, - ` - SELECT post.preview FROM handmade_post AS post JOIN handmade_thread AS thread ON post.thread_id = thread.id - JOIN handmade_postversion AS ver ON post.current_id = ver.id + JOIN handmade_project AS project ON post.project_id = project.id + JOIN handmade_postversion AS ver ON ver.id = post.current_id + LEFT JOIN auth_user AS author ON author.id = post.author_id + LEFT JOIN auth_user AS editor ON ver.editor_id = editor.id + LEFT JOIN handmade_threadlastreadinfo AS tlri ON ( + tlri.thread_id = thread.id + AND tlri.user_id = $? + ) + LEFT JOIN handmade_subforumlastreadinfo AS slri ON ( + slri.subforum_id = thread.subforum_id + AND slri.user_id = $? + ) + -- Unconditionally fetch reply info, but make sure to check it + -- later and possibly remove these fields if the permission + -- check fails. + LEFT JOIN handmade_post AS reply_post ON reply_post.id = post.reply_id + LEFT JOIN auth_user AS reply_author ON reply_post.author_id = reply_author.id WHERE - post.thread_id = $1 - AND thread.first_id = post.id + NOT thread.deleted + AND NOT post.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) `, - thread.ID, + currentUserID, + currentUserID, + models.VisibleProjectLifecycles, + models.HMNProjectID, ) - if err != nil && !errors.Is(err, db.ErrNoMatchingRows) { - panic(oops.New(err, "failed to fetch posts for thread")) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.UserIDs) > 0 { + qb.Add(`AND post.author_id = ANY ($?)`, q.UserIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND post.thread_id = ANY ($?)`, q.ThreadIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if len(q.PostIDs) > 0 { + qb.Add(`AND post.id = ANY ($?)`, q.PostIDs) + } + if currentUser == nil { + qb.Add( + `AND author.status = $? -- post author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + author.status = $? -- post author is Approved + OR author.id = $? -- current user is the post author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + qb.Add(`ORDER BY post.postdate`) + if q.SortDescending { + qb.Add(`DESC`) + } + if q.Limit > 0 { + qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - return thread, posts, preview + it, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + if err != nil { + return nil, oops.New(err, "failed to fetch posts") + } + iresults := it.ToSlice() + + result := make([]PostAndStuff, len(iresults)) + for i, iresult := range iresults { + row := *iresult.(*resultRow) + + hasRead := false + if row.ThreadLastReadTime != nil && row.ThreadLastReadTime.After(row.Post.PostDate) { + hasRead = true + } else if row.ForumLastReadTime != nil && row.ForumLastReadTime.After(row.Post.PostDate) { + hasRead = true + } + row.ThreadUnread = !hasRead + + if row.ReplyPost != nil && row.ReplyAuthor != nil { + replyAuthorIsNotApproved := row.ReplyAuthor.Status != models.UserStatusApproved + canSeeUnapprovedReply := currentUser != nil && (row.ReplyAuthor.ID == currentUser.ID || currentUser.IsStaff) + if replyAuthorIsNotApproved && !canSeeUnapprovedReply { + row.ReplyPost = nil + row.ReplyAuthor = nil + } + } + + result[i] = row.PostAndStuff + } + + return result, nil +} + +/* +Fetches posts for a given thread. A convenient wrapper around FetchPosts that returns +the posts and the actual thread model. + +Return db.NotFound if nothing is found (no thread or no posts). +*/ +func FetchThreadPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + threadID int, + q PostsQuery, +) (models.Thread, []PostAndStuff, error) { + q.ThreadIDs = []int{threadID} + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return models.Thread{}, nil, oops.New(err, "failed to fetch posts for thread") + } + + if len(res) == 0 { + // We shouldn't have threads without posts anyway. + return models.Thread{}, nil, db.NotFound + } + + return res[0].Thread, res, nil +} + +/* +Fetches a single post for a thread and its related data. A wrapper +around FetchPosts. As with FetchPosts, provide as much information +as you know to get the most correct results. + +Returns db.NotFound if no result is found. +*/ +func FetchThreadPost( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + threadID, postID int, + q PostsQuery, +) (PostAndStuff, error) { + q.ThreadIDs = []int{threadID} + q.PostIDs = []int{postID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return PostAndStuff{}, oops.New(err, "failed to fetch post") + } + + if len(res) == 0 { + return PostAndStuff{}, db.NotFound + } + + return res[0], nil +} + +/* +Fetches a single post and its related data. A wrapper +around FetchPosts. As with FetchPosts, provide as much information +as you know to get the most correct results. + +Returns db.NotFound if no result is found. +*/ +func FetchPost( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + postID int, + q PostsQuery, +) (PostAndStuff, error) { + q.PostIDs = []int{postID} + q.Limit = 1 + q.Offset = 0 + + res, err := FetchPosts(ctx, dbConn, currentUser, q) + if err != nil { + return PostAndStuff{}, oops.New(err, "failed to fetch post") + } + + if len(res) == 0 { + return PostAndStuff{}, db.NotFound + } + + return res[0], nil +} + +func CountPosts( + ctx context.Context, + dbConn db.ConnOrTx, + currentUser *models.User, + q PostsQuery, +) (int, error) { + perf := ExtractPerf(ctx) + perf.StartBlock("SQL", "Count posts") + defer perf.EndBlock() + + var qb db.QueryBuilder + + var currentUserID *int + if currentUser != nil { + currentUserID = ¤tUser.ID + } + + qb.Add( + ` + SELECT COUNT(*) + FROM + handmade_post AS post + JOIN handmade_thread AS thread ON post.thread_id = thread.id + JOIN handmade_project AS project ON post.project_id = project.id + LEFT JOIN auth_user AS author ON author.id = post.author_id + WHERE + NOT thread.deleted + AND NOT post.deleted + AND ( -- project has valid lifecycle + project.flags = 0 AND project.lifecycle = ANY($?) + OR project.id = $? + ) + `, + models.VisibleProjectLifecycles, + models.HMNProjectID, + ) + if len(q.ProjectIDs) > 0 { + qb.Add(`AND project.id = ANY ($?)`, q.ProjectIDs) + } + if len(q.UserIDs) > 0 { + qb.Add(`AND post.author_id = ANY ($?)`, q.UserIDs) + } + if len(q.ThreadIDs) > 0 { + qb.Add(`AND post.thread_id = ANY ($?)`, q.ThreadIDs) + } + if len(q.ThreadTypes) > 0 { + qb.Add(`AND thread.type = ANY ($?)`, q.ThreadTypes) + } + if currentUser == nil { + qb.Add( + `AND author.status = $? -- post author is Approved`, + models.UserStatusApproved, + ) + } else if !currentUser.IsStaff { + qb.Add( + ` + AND ( + author.status = $? -- post author is Approved + OR author.id = $? -- current user is the post author + ) + `, + models.UserStatusApproved, + currentUserID, + ) + } + + count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + if err != nil { + return 0, oops.New(err, "failed to count posts") + } + + return count, nil } func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User, postId int) bool { @@ -233,7 +608,7 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User postId, ) if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return false } else { panic(oops.New(err, "failed to get author of post when checking permissions")) diff --git a/src/website/user.go b/src/website/user.go index af3afd5..d07b9e5 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -57,7 +57,7 @@ func UserProfile(c *RequestContext) ResponseData { ) c.Perf.EndBlock() if err != nil { - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { return FourOhFour(c) } else { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username)) @@ -65,6 +65,15 @@ func UserProfile(c *RequestContext) ResponseData { } profileUser = userResult.(*models.User) } + + { + userIsUnapproved := profileUser.Status != models.UserStatusApproved + canViewUnapprovedUser := c.CurrentUser != nil && (c.CurrentUser.ID == profileUser.ID || c.CurrentUser.IsStaff) + if userIsUnapproved && !canViewUnapprovedUser { + return FourOhFour(c) + } + } + c.Perf.StartBlock("SQL", "Fetch user links") type userLinkQuery struct { UserLink models.Link `db:"link"` @@ -119,30 +128,11 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.EndBlock() - type postQuery struct { - Post models.Post `db:"post"` - Thread models.Thread `db:"thread"` - Project models.Project `db:"project"` - } c.Perf.StartBlock("SQL", "Fetch posts") - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, - ` - SELECT $columns - FROM - handmade_post AS post - INNER JOIN handmade_thread AS thread ON thread.id = post.thread_id - INNER JOIN handmade_project AS project ON project.id = post.project_id - WHERE - post.author_id = $1 - AND project.lifecycle = ANY ($2) - `, - profileUser.ID, - models.VisibleProjectLifecycles, - ) - if err != nil { - return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch posts for user: %s", username)) - } - postQuerySlice := postQueryResult.ToSlice() + posts, err := FetchPosts(c.Context(), c.Conn, c.CurrentUser, PostsQuery{ + UserIDs: []int{profileUser.ID}, + SortDescending: true, + }) c.Perf.EndBlock() type snippetQuery struct { @@ -175,18 +165,17 @@ func UserProfile(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("PROFILE", "Construct timeline items") - timelineItems := make([]templates.TimelineItem, 0, len(postQuerySlice)+len(snippetQuerySlice)) + timelineItems := make([]templates.TimelineItem, 0, len(posts)+len(snippetQuerySlice)) numForums := 0 numBlogs := 0 numSnippets := len(snippetQuerySlice) - for _, postRow := range postQuerySlice { - postData := postRow.(*postQuery) + for _, post := range posts { timelineItem := PostToTimelineItem( lineageBuilder, - &postData.Post, - &postData.Thread, - &postData.Project, + &post.Post, + &post.Thread, + &post.Project, profileUser, c.Theme, ) @@ -199,7 +188,7 @@ func UserProfile(c *RequestContext) ResponseData { if timelineItem.Type != templates.TimelineTypeUnknown { timelineItems = append(timelineItems, timelineItem) } else { - c.Logger.Warn().Int("post ID", postData.Post.ID).Msg("Unknown timeline item type for post") + c.Logger.Warn().Int("post ID", post.Post.ID).Msg("Unknown timeline item type for post") } } @@ -292,7 +281,7 @@ func UserSettings(c *RequestContext) ResponseData { `, c.CurrentUser.ID, ) - if errors.Is(err, db.ErrNoMatchingRows) { + if errors.Is(err, db.NotFound) { // this is fine, but don't fetch any more messages } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))