From 2229ac85d5d5be51e73d01cacbd4bec6431acf61 Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Sat, 16 Apr 2022 20:19:07 -0500 Subject: [PATCH] Update all callsites of db functions Finish converting all callsites Not too bad actually! Centralizing access into the helpers makes a big difference. --- src/admintools/adminproject.go | 2 +- src/admintools/admintools.go | 12 ++--- src/assets/assets.go | 4 +- src/auth/session.go | 3 +- src/discord/commands.go | 12 ++--- src/discord/gateway.go | 11 ++--- src/discord/history.go | 15 ++---- src/discord/message_handling.go | 62 ++++++++++--------------- src/hmndata/project_helper.go | 62 +++++++++++-------------- src/hmndata/snippet_helper.go | 27 ++++------- src/hmndata/tag_helper.go | 11 +---- src/hmndata/threads_and_posts_helper.go | 50 +++++++------------- src/hmndata/twitch.go | 18 +++---- src/twitch/twitch.go | 2 +- src/website/admin.go | 55 +++++++--------------- src/website/api.go | 9 ++-- src/website/auth.go | 44 +++++++----------- src/website/blogs.go | 4 +- src/website/discord.go | 17 ++----- src/website/forums.go | 4 +- src/website/imagefile_helper.go | 5 +- src/website/links_helper.go | 5 +- src/website/podcast.go | 25 ++++------ src/website/projects.go | 50 +++++++++----------- src/website/routes.go | 11 ++--- src/website/twitch.go | 5 +- src/website/user.go | 27 ++++------- 27 files changed, 205 insertions(+), 347 deletions(-) diff --git a/src/admintools/adminproject.go b/src/admintools/adminproject.go index 4f1910e..ed80699 100644 --- a/src/admintools/adminproject.go +++ b/src/admintools/adminproject.go @@ -55,7 +55,7 @@ func addCreateProjectCommand(projectCommand *cobra.Command) { } hmn := p.Project - newProjectID, err := db.QueryInt(ctx, tx, + newProjectID, err := db.QueryOneScalar[int](ctx, tx, ` INSERT INTO handmade_project ( slug, diff --git a/src/admintools/admintools.go b/src/admintools/admintools.go index d88d1e4..23bf624 100644 --- a/src/admintools/admintools.go +++ b/src/admintools/admintools.go @@ -210,7 +210,7 @@ func init() { } defer tx.Rollback(ctx) - projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) + projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) if err != nil { panic(err) } @@ -218,7 +218,7 @@ func init() { var parentId *int if parentSlug == "" { // Select the root subforum - id, err := db.QueryInt(ctx, tx, + id, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`, projectId, ) @@ -228,7 +228,7 @@ func init() { parentId = &id } else { // Select the parent - id, err := db.QueryInt(ctx, tx, + id, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`, parentSlug, projectId, ) @@ -238,7 +238,7 @@ func init() { parentId = &id } - newId, err := db.QueryInt(ctx, tx, + newId, err := db.QueryOneScalar[int](ctx, tx, ` INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id) VALUES ($1, $2, $3, $4, $5) @@ -289,12 +289,12 @@ func init() { } defer tx.Rollback(ctx) - projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) + projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) if err != nil { panic(err) } - subforumId, err := db.QueryInt(ctx, tx, + subforumId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`, subforumSlug, projectId, ) diff --git a/src/assets/assets.go b/src/assets/assets.go index 3b077b5..4ab9dff 100644 --- a/src/assets/assets.go +++ b/src/assets/assets.go @@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As } // Fetch and return the new record - iasset, err := db.QueryOne(ctx, dbConn, models.Asset{}, + asset, err := db.QueryOne[models.Asset](ctx, dbConn, ` SELECT $columns FROM handmade_asset @@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As return nil, oops.New(err, "failed to fetch newly-created asset") } - return iasset.(*models.Asset), nil + return asset, nil } diff --git a/src/auth/session.go b/src/auth/session.go index 84c103b..c42d7d5 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -45,7 +45,7 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses return nil, oops.New(err, "failed to get session") } } - sess := row.(*models.Session) return sess, nil } diff --git a/src/discord/commands.go b/src/discord/commands.go index 5918bd6..eef3cba 100644 --- a/src/discord/commands.go +++ b/src/discord/commands.go @@ -90,12 +90,9 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction return } - type profileResult struct { - HMNUser models.User `db:"auth_user"` - } - ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{}, + hmnUser, err := db.QueryOne[models.User](ctx, bot.dbConn, ` - SELECT $columns + SELECT $columns{auth_user} FROM handmade_discorduser AS duser JOIN auth_user ON duser.hmn_user_id = auth_user.id @@ -122,16 +119,15 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction } return } - res := ires.(*profileResult) projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{ - OwnerIDs: []int{res.HMNUser.ID}, + OwnerIDs: []int{hmnUser.ID}, }) if err != nil { logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects") } - url := hmnurl.BuildUserProfile(res.HMNUser.Username) + url := hmnurl.BuildUserProfile(hmnUser.Username) msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url) if len(projectsAndStuff) > 0 { projectNoun := "projects" diff --git a/src/discord/gateway.go b/src/discord/gateway.go index 55e63cd..9fb2b96 100644 --- a/src/discord/gateway.go +++ b/src/discord/gateway.go @@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error { // an old one or starting a new one. shouldResume := true - isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) + session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`) if err != nil { if errors.Is(err, db.NotFound) { // No session yet! Just identify and get on with it @@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error { if shouldResume { // Reconnect to the previous session - session := isession.(*models.DiscordSession) - err := bot.sendGatewayMessage(ctx, GatewayMessage{ Opcode: OpcodeResume, Data: Resume{ @@ -356,7 +354,7 @@ func (bot *botInstance) doSender(ctx context.Context) { } bot.didAckHeartbeat = false - latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) + latestSequenceNumber, err := db.QueryOneScalar[int](ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) if err != nil { log.Error().Err(err).Msg("failed to fetch latest sequence number from the db") return false @@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) { } defer tx.Rollback(ctx) - msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` + msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, ` SELECT $columns FROM discord_outgoingmessages ORDER BY id ASC @@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) { return } - for _, imsg := range msgs { - msg := imsg.(*models.DiscordOutgoingMessage) + for _, msg := range msgs { if time.Now().After(msg.ExpiresAt) { continue } diff --git a/src/discord/history.go b/src/discord/history.go index 138bae7..3a69d06 100644 --- a/src/discord/history.go +++ b/src/discord/history.go @@ -73,12 +73,9 @@ func RunHistoryWatcher(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { log := logging.ExtractLogger(ctx) - type query struct { - Message models.DiscordMessage `db:"msg"` - } - imessagesWithoutContent, err := db.Query(ctx, dbConn, query{}, + messagesWithoutContent, err := db.Query[models.DiscordMessage](ctx, dbConn, ` - SELECT $columns + SELECT $columns{msg} FROM handmade_discordmessage AS msg JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users @@ -95,10 +92,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { return } - if len(imessagesWithoutContent) > 0 { - log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent)) + if len(messagesWithoutContent) > 0 { + log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent)) msgloop: - for _, imsg := range imessagesWithoutContent { + for _, msg := range messagesWithoutContent { select { case <-ctx.Done(): log.Info().Msg("Scrape was canceled") @@ -106,8 +103,6 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { default: } - msg := imsg.(*query).Message - discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID) if errors.Is(err, NotFound) { // This message has apparently been deleted; delete it from our database diff --git a/src/discord/message_handling.go b/src/discord/message_handling.go index 29c511e..e9172f0 100644 --- a/src/discord/message_handling.go +++ b/src/discord/message_handling.go @@ -165,7 +165,7 @@ func InternMessage( dbConn db.ConnOrTx, msg *Message, ) error { - _, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{}, + _, err := db.QueryOne[models.DiscordMessage](ctx, dbConn, ` SELECT $columns FROM handmade_discordmessage @@ -219,7 +219,7 @@ type InternedMessage struct { } func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) { - result, err := db.QueryOne(ctx, dbConn, InternedMessage{}, + interned, err := db.QueryOne[InternedMessage](ctx, dbConn, ` SELECT $columns FROM @@ -235,8 +235,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) if err != nil { return nil, err } - - interned := result.(*InternedMessage) return interned, nil } @@ -283,7 +281,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message } func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error { - isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -294,10 +292,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In if err != nil && !errors.Is(err, db.NotFound) { return oops.New(err, "failed to fetch snippet for discord message") } - var snippet *models.Snippet - if !errors.Is(err, db.NotFound) { - snippet = isnippet.(*models.Snippet) - } // NOTE(asaf): Also deletes the following through a db cascade: // * handmade_discordmessageattachment @@ -367,7 +361,7 @@ func SaveMessageContents( return oops.New(err, "failed to create or update message contents") } - icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{}, + content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn, ` SELECT $columns FROM @@ -380,7 +374,7 @@ func SaveMessageContents( if err != nil { return oops.New(err, "failed to fetch message contents") } - interned.MessageContent = icontent.(*models.DiscordMessageContent) + interned.MessageContent = content } // Save attachments @@ -395,12 +389,12 @@ func SaveMessageContents( // Save / delete embeds if msg.OriginalHasFields("embeds") { - numSavedEmbeds, err := db.QueryInt(ctx, dbConn, + numSavedEmbeds, err := db.QueryOneScalar[int](ctx, dbConn, ` - SELECT COUNT(*) - FROM handmade_discordmessageembed - WHERE message_id = $1 - `, + SELECT COUNT(*) + FROM handmade_discordmessageembed + WHERE message_id = $1 + `, msg.ID, ) if err != nil { @@ -472,7 +466,7 @@ func saveAttachment( hmnUserID int, discordMessageID string, ) (*models.DiscordMessageAttachment, error) { - iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -481,7 +475,7 @@ func saveAttachment( attachment.ID, ) if err == nil { - return iexisting.(*models.DiscordMessageAttachment), nil + return existing, nil } else if errors.Is(err, db.NotFound) { // this is fine, just create it } else { @@ -534,7 +528,7 @@ func saveAttachment( return nil, oops.New(err, "failed to save Discord attachment data") } - iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -546,7 +540,7 @@ func saveAttachment( return nil, oops.New(err, "failed to fetch new Discord attachment data") } - return iDiscordAttachment.(*models.DiscordMessageAttachment), nil + return discordAttachment, nil } // Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it @@ -636,7 +630,7 @@ func saveEmbed( return nil, oops.New(err, "failed to insert new embed") } - iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{}, + discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -648,11 +642,11 @@ func saveEmbed( return nil, oops.New(err, "failed to fetch new Discord embed data") } - return iDiscordEmbed.(*models.DiscordMessageEmbed), nil + return discordEmbed, nil } func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) { - iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -669,7 +663,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin } } - return iresult.(*models.Snippet), nil + return snippet, nil } /* @@ -805,12 +799,9 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in projectIDs[i] = p.Project.ID } - type tagsRow struct { - Tag models.Tag `db:"tags"` - } - iUserTags, err := db.Query(ctx, tx, tagsRow{}, + userTags, err := db.Query[models.Tag](ctx, tx, ` - SELECT $columns + SELECT $columns{tags} FROM tags JOIN handmade_project AS project ON project.tag = tags.id @@ -823,8 +814,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in return oops.New(err, "failed to fetch tags for user projects") } - for _, itag := range iUserTags { - tag := itag.(*tagsRow).Tag + for _, tag := range userTags { allTags = append(allTags, tag.ID) for _, messageTag := range messageTags { if strings.EqualFold(tag.Text, messageTag) { @@ -890,7 +880,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) { // Check attachments - attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{}, + attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -901,13 +891,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch message attachments") } - for _, iattachment := range attachments { - attachment := iattachment.(*models.DiscordMessageAttachment) + for _, attachment := range attachments { return &attachment.AssetID, nil, nil } // Check embeds - embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{}, + embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -918,8 +907,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch discord embeds") } - for _, iembed := range embeds { - embed := iembed.(*models.DiscordMessageEmbed) + for _, embed := range embeds { if embed.VideoID != nil { return embed.VideoID, nil, nil } else if embed.ImageID != nil { diff --git a/src/hmndata/project_helper.go b/src/hmndata/project_helper.go index 7341ae7..128665a 100644 --- a/src/hmndata/project_helper.go +++ b/src/hmndata/project_helper.go @@ -140,15 +140,15 @@ func FetchProjects( } // Do the query - iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...) + projectRows, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch projects") } // Fetch project owners to do permission checks - projectIds := make([]int, len(iprojects)) - for i, iproject := range iprojects { - projectIds[i] = iproject.(*projectRow).Project.ID + projectIds := make([]int, len(projectRows)) + for i, p := range projectRows { + projectIds[i] = p.Project.ID } projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds) if err != nil { @@ -156,8 +156,7 @@ func FetchProjects( } var res []ProjectAndStuff - for i, iproject := range iprojects { - row := iproject.(*projectRow) + for i, p := range projectRows { owners := projectOwners[i].Owners /* @@ -191,10 +190,10 @@ func FetchProjects( } projectGenerallyVisible := true && - row.Project.Lifecycle.In(models.VisibleProjectLifecycles) && - !row.Project.Hidden && - (!row.Project.Personal || allOwnersApproved || row.Project.IsHMN()) - if row.Project.IsHMN() { + p.Project.Lifecycle.In(models.VisibleProjectLifecycles) && + !p.Project.Hidden && + (!p.Project.Personal || allOwnersApproved || p.Project.IsHMN()) + if p.Project.IsHMN() { projectGenerallyVisible = true // hard override } @@ -205,11 +204,11 @@ func FetchProjects( if projectVisible { res = append(res, ProjectAndStuff{ - Project: row.Project, - LogoLightAsset: row.LogoLightAsset, - LogoDarkAsset: row.LogoDarkAsset, + Project: p.Project, + LogoLightAsset: p.LogoLightAsset, + LogoDarkAsset: p.LogoDarkAsset, Owners: owners, - Tag: row.Tag, + Tag: p.Tag, }) } } @@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners( UserID int `db:"user_id"` ProjectID int `db:"project_id"` } - iuserprojects, err := db.Query(ctx, tx, userProject{}, + userProjects, err := db.Query[userProject](ctx, tx, ` SELECT $columns FROM handmade_user_projects @@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners( // Get the unique user IDs from this set and fetch the users from the db var userIds []int - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { addUserId := true for _, uid := range userIds { if uid == userProject.UserID { @@ -361,14 +358,12 @@ func FetchMultipleProjectsOwners( userIds = append(userIds, userProject.UserID) } } - type userQuery struct { - User models.User `db:"auth_user"` - } - iusers, err := db.Query(ctx, tx, userQuery{}, + users, err := db.Query[models.User](ctx, tx, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE auth_user.id = ANY($1) `, @@ -383,9 +378,7 @@ func FetchMultipleProjectsOwners( for i, pid := range projectIds { res[i] = ProjectOwners{ProjectID: pid} } - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { // Get a pointer to the existing record in the result var projectOwners *ProjectOwners for i := range res { @@ -396,10 +389,9 @@ func FetchMultipleProjectsOwners( // Get the full user record we fetched var user *models.User - for _, iuser := range iusers { - u := iuser.(*userQuery).User + for _, u := range users { if u.ID == userProject.UserID { - user = &u + user = u } } if user == nil { @@ -473,7 +465,7 @@ func SetProjectTag( resultTag = p.Tag } else if p.Project.TagID == nil { // Create a tag - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` INSERT INTO tags (text) VALUES ($1) RETURNING $columns @@ -483,7 +475,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to create new tag for project") } - resultTag = itag.(*models.Tag) + resultTag = tag // Attach it to the project _, err = tx.Exec(ctx, @@ -499,7 +491,7 @@ func SetProjectTag( } } else { // Update the text of an existing one - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` UPDATE tags SET text = $1 @@ -511,7 +503,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to update existing tag") } - resultTag = itag.(*models.Tag) + resultTag = tag } err = tx.Commit(ctx) diff --git a/src/hmndata/snippet_helper.go b/src/hmndata/snippet_helper.go index 99356de..ecdd3a8 100644 --- a/src/hmndata/snippet_helper.go +++ b/src/hmndata/snippet_helper.go @@ -44,10 +44,7 @@ func FetchSnippets( if len(q.Tags) > 0 { // Get snippet IDs with this tag, then use that in the main query - type snippetIDRow struct { - SnippetID int `db:"snippet_id"` - } - iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{}, + snippetIDs, err := db.QueryScalar[int](ctx, tx, ` SELECT DISTINCT snippet_id FROM @@ -63,14 +60,11 @@ func FetchSnippets( } // special early-out: no snippets found for these tags at all - if len(iSnippetIDs) == 0 { + if len(snippetIDs) == 0 { return nil, nil } - q.IDs = make([]int, len(iSnippetIDs)) - for i := range iSnippetIDs { - q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID - } + q.IDs = snippetIDs } var qb db.QueryBuilder @@ -125,16 +119,14 @@ func FetchSnippets( DiscordMessage *models.DiscordMessage `db:"discord_message"` } - iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...) + results, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not - snippetIDs := make([]int, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]SnippetAndStuff, len(results)) // allocate extra space because why not + snippetIDs := make([]int, len(results)) + for i, row := range results { result[i] = SnippetAndStuff{ Snippet: row.Snippet, Owner: row.Owner, @@ -150,7 +142,7 @@ func FetchSnippets( SnippetID int `db:"snippet_tags.snippet_id"` Tag *models.Tag `db:"tags"` } - iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{}, + snippetTags, err := db.Query[snippetTagRow](ctx, tx, ` SELECT $columns FROM @@ -170,8 +162,7 @@ func FetchSnippets( for i := range result { resultBySnippetId[result[i].Snippet.ID] = &result[i] } - for _, iSnippetTag := range iSnippetTags { - snippetTag := iSnippetTag.(*snippetTagRow) + for _, snippetTag := range snippetTags { item := resultBySnippetId[snippetTag.SnippetID] item.Tags = append(item.Tags, snippetTag.Tag) } diff --git a/src/hmndata/tag_helper.go b/src/hmndata/tag_helper.go index fab5b94..8c0fe63 100644 --- a/src/hmndata/tag_helper.go +++ b/src/hmndata/tag_helper.go @@ -40,18 +40,11 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...) + tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch tags") } - - res := make([]*models.Tag, len(itags)) - for i, itag := range itags { - tag := itag.(*models.Tag) - res[i] = tag - } - - return res, nil + return tags, nil } func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) { diff --git a/src/hmndata/threads_and_posts_helper.go b/src/hmndata/threads_and_posts_helper.go index a51cbc0..6525e4e 100644 --- a/src/hmndata/threads_and_posts_helper.go +++ b/src/hmndata/threads_and_posts_helper.go @@ -145,15 +145,13 @@ func FetchThreads( ForumLastReadTime *time.Time `db:"slri.lastread"` } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]ThreadAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]ThreadAndStuff, len(rows)) + for i, row := range rows { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) { hasRead = true @@ -263,7 +261,7 @@ func CountThreads( ) } - count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return 0, oops.New(err, "failed to fetch count of threads") } @@ -405,15 +403,13 @@ func FetchPosts( qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch posts") } - result := make([]PostAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]PostAndStuff, len(rows)) + for i, row := range rows { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) { hasRead = true @@ -595,7 +591,7 @@ func CountPosts( ) } - count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return 0, oops.New(err, "failed to count posts") } @@ -608,12 +604,9 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User return true } - type postResult struct { - AuthorID *int `db:"post.author_id"` - } - iresult, err := db.QueryOne(ctx, connOrTx, postResult{}, + authorID, err := db.QueryOneScalar[*int](ctx, connOrTx, ` - SELECT $columns + SELECT post.author_id FROM handmade_post AS post WHERE @@ -629,9 +622,8 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User panic(oops.New(err, "failed to get author of post when checking permissions")) } } - result := iresult.(*postResult) - return result.AuthorID != nil && *result.AuthorID == user.ID + return authorID != nil && *authorID == user.ID } func CreateNewPost( @@ -709,7 +701,7 @@ func DeletePost( FirstPostID int `db:"first_id"` Deleted bool `db:"deleted"` } - ti, err := db.QueryOne(ctx, tx, threadInfo{}, + info, err := db.QueryOne[threadInfo](ctx, tx, ` SELECT $columns FROM @@ -722,7 +714,6 @@ func DeletePost( if err != nil { panic(oops.New(err, "failed to fetch thread info")) } - info := ti.(*threadInfo) if info.Deleted { return true } @@ -848,12 +839,9 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte keys = append(keys, key) } - type assetId struct { - AssetID uuid.UUID `db:"id"` - } - assetResult, err := db.Query(ctx, tx, assetId{}, + assetIDs, err := db.QueryScalar[uuid.UUID](ctx, tx, ` - SELECT $columns + SELECT id FROM handmade_asset WHERE s3_key = ANY($1) `, @@ -865,8 +853,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte var values [][]interface{} - for _, asset := range assetResult { - values = append(values, []interface{}{postId, asset.(*assetId).AssetID}) + for _, assetID := range assetIDs { + values = append(values, []interface{}{postId, assetID}) } _, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values)) @@ -886,7 +874,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more. You should probably mark the thread as deleted in this case. */ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { - postsIter, err := db.Query(ctx, tx, models.Post{}, + posts, err := db.Query[models.Post](ctx, tx, ` SELECT $columns FROM handmade_post @@ -901,9 +889,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { } var firstPost, lastPost *models.Post - for _, ipost := range postsIter { - post := ipost.(*models.Post) - + for _, post := range posts { if firstPost == nil || post.PostDate.Before(firstPost.PostDate) { firstPost = post } diff --git a/src/hmndata/twitch.go b/src/hmndata/twitch.go index 3ad4d96..f52bc5f 100644 --- a/src/hmndata/twitch.go +++ b/src/hmndata/twitch.go @@ -22,12 +22,9 @@ type TwitchStreamer struct { var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P[^/]+)$`) func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) { - type linkResult struct { - Link models.Link `db:"link"` - } - streamers, err := db.Query(ctx, dbConn, linkResult{}, + dbStreamers, err := db.Query[models.Link](ctx, dbConn, ` - SELECT $columns + SELECT $columns{link} FROM handmade_links AS link LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id @@ -49,10 +46,8 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre return nil, oops.New(err, "failed to fetch twitch links") } - result := make([]TwitchStreamer, 0, len(streamers)) - for _, s := range streamers { - dbStreamer := s.(*linkResult).Link - + result := make([]TwitchStreamer, 0, len(dbStreamers)) + for _, dbStreamer := range dbStreamers { streamer := TwitchStreamer{ UserID: dbStreamer.UserID, ProjectID: dbStreamer.ProjectID, @@ -81,7 +76,7 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre } func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) { - links, err := db.Query(ctx, dbConn, models.Link{}, + links, err := db.Query[models.Link](ctx, dbConn, ` SELECT $columns FROM @@ -100,8 +95,7 @@ func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, result := make([]string, 0, len(links)) for _, l := range links { - url := l.(*models.Link).URL - match := twitchRegex.FindStringSubmatch(url) + match := twitchRegex.FindStringSubmatch(l.URL) if match != nil { login := strings.ToLower(match[twitchRegex.SubexpIndex("login")]) result = append(result, login) diff --git a/src/twitch/twitch.go b/src/twitch/twitch.go index 99279cb..62eaf96 100644 --- a/src/twitch/twitch.go +++ b/src/twitch/twitch.go @@ -440,7 +440,7 @@ func updateStreamStatusInDB(ctx context.Context, conn db.ConnOrTx, status *strea inserted := false if isStatusRelevant(status) { log.Debug().Msg("Status relevant") - _, err := db.QueryOne(ctx, conn, models.TwitchStream{}, + _, err := db.QueryOne[models.TwitchStream](ctx, conn, ` SELECT $columns FROM twitch_streams diff --git a/src/website/admin.go b/src/website/admin.go index 0bdacd6..7133dd4 100644 --- a/src/website/admin.go +++ b/src/website/admin.go @@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { userIds = append(userIds, u.User.ID) } - userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, + userLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) } - for _, ul := range userLinks { - link := ul.(*models.Link) + for _, link := range userLinks { userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) } @@ -257,12 +256,9 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { return RejectRequest(c, "User id can't be parsed") } - type userQuery struct { - User models.User `db:"auth_user"` - } - u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE auth_user.id = $1 @@ -276,7 +272,6 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) } } - user := u.(*userQuery).User whatHappened := "" if action == ApprovalQueueActionApprove { @@ -337,7 +332,7 @@ type UnapprovedPost struct { } func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { - it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, + posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn, ` SELECT $columns FROM @@ -358,11 +353,7 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved posts") } - var res []*UnapprovedPost - for _, iresult := range it { - res = append(res, iresult.(*UnapprovedPost)) - } - return res, nil + return posts, nil } type UnapprovedProject struct { @@ -372,12 +363,9 @@ type UnapprovedProject struct { } func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { - type unapprovedUser struct { - ID int `db:"id"` - } - it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, + ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn, ` - SELECT $columns + SELECT id FROM auth_user AS u WHERE @@ -388,10 +376,6 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved users") } - ownerIDs := make([]int, 0, len(it)) - for _, uid := range it { - ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) - } projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: ownerIDs, @@ -406,7 +390,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { projectIDs = append(projectIDs, p.Project.ID) } - projectLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -425,8 +409,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { for idx, proj := range projects { links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects. - for _, l := range projectLinks { - link := l.(*models.Link) + for _, link := range projectLinks { if *link.ProjectID == proj.Project.ID { links = append(links, link) } @@ -455,7 +438,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int) ThreadID int `db:"thread.id"` PostID int `db:"post.id"` } - it, err := db.Query(ctx, tx, toDelete{}, + rows, err := db.Query[toDelete](ctx, tx, ` SELECT $columns FROM @@ -471,8 +454,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int) return oops.New(err, "failed to fetch posts to delete for user") } - for _, iResult := range it { - row := iResult.(*toDelete) + for _, row := range rows { hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID) } err = tx.Commit(ctx) @@ -489,9 +471,9 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in } defer tx.Rollback(ctx) - toDelete, err := db.Query(ctx, tx, models.Project{}, + projectIDsToDelete, err := db.QueryScalar[int](ctx, tx, ` - SELECT $columns + SELECT project.id FROM handmade_project AS project JOIN handmade_user_projects AS up ON up.project_id = project.id @@ -504,17 +486,12 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in return oops.New(err, "failed to fetch user's projects") } - var projectIds []int - for _, p := range toDelete { - projectIds = append(projectIds, p.(*models.Project).ID) - } - - if len(projectIds) > 0 { + if len(projectIDsToDelete) > 0 { _, err = tx.Exec(ctx, ` DELETE FROM handmade_project WHERE id = ANY($1) `, - projectIds, + projectIDsToDelete, ) if err != nil { return oops.New(err, "failed to delete user's projects") diff --git a/src/website/api.go b/src/website/api.go index f1c5b33..2462782 100644 --- a/src/website/api.go +++ b/src/website/api.go @@ -19,12 +19,9 @@ func APICheckUsername(c *RequestContext) ResponseData { requestedUsername := usernameArgs[0] found = true c.Perf.StartBlock("SQL", "Fetch user") - type userQuery struct { - User models.User `db:"auth_user"` - } - userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id @@ -43,7 +40,7 @@ func APICheckUsername(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) } } else { - canonicalUsername = userResult.(*userQuery).User.Username + canonicalUsername = user.Username } } diff --git a/src/website/auth.go b/src/website/auth.go index 991f403..4dedbba 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -75,13 +75,11 @@ func Login(c *RequestContext) ResponseData { return res } - type userQuery struct { - User models.User `db:"auth_user"` - } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user + SELECT $columns{auth_user} + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = LOWER($1) `, @@ -94,7 +92,6 @@ func Login(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } - user := &userRow.(*userQuery).User success, err := tryLogin(c, user, password) @@ -174,7 +171,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Check for existing usernames and emails") userAlreadyExists := true - _, err := db.QueryInt(c.Context(), c.Conn, + _, err := db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT id FROM auth_user @@ -195,7 +192,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { } emailAlreadyExists := true - _, err = db.QueryInt(c.Context(), c.Conn, + _, err = db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT id FROM auth_user @@ -454,17 +451,16 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return RejectRequest(c, "You must provide a username and an email address.") } - var user *models.User - c.Perf.StartBlock("SQL", "Fetching user") type userQuery struct { User models.User `db:"auth_user"` } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = LOWER($1) AND LOWER(email) = LOWER($2) @@ -478,13 +474,10 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } - if userRow != nil { - user = &userRow.(*userQuery).User - } if user != nil { c.Perf.StartBlock("SQL", "Fetching existing token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` SELECT $columns FROM handmade_onetimetoken @@ -501,10 +494,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) } } - var resetToken *models.OneTimeToken - if tokenRow != nil { - resetToken = tokenRow.(*models.OneTimeToken) - } now := time.Now() if resetToken != nil { @@ -527,7 +516,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if resetToken == nil { c.Perf.StartBlock("SQL", "Creating new token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) VALUES ($1, $2, $3, $4, $5) @@ -543,7 +532,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) } - resetToken = tokenRow.(*models.OneTimeToken) + resetToken = newToken err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) if err != nil { @@ -787,7 +776,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, User models.User `db:"auth_user"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"` } - row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, + data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, ` SELECT $columns FROM auth_user @@ -807,8 +796,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, return result } } - if row != nil { - data := row.(*userAndTokenQuery) + if data != nil { result.User = &data.User result.OneTimeToken = data.OneTimeToken if result.OneTimeToken != nil { diff --git a/src/website/blogs.go b/src/website/blogs.go index 70d0af3..0ff0a9e 100644 --- a/src/website/blogs.go +++ b/src/website/blogs.go @@ -558,7 +558,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.ThreadID = threadId c.Perf.StartBlock("SQL", "Verify that the thread exists") - threadExists, err := db.QueryBool(c.Context(), c.Conn, + threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, ` SELECT COUNT(*) > 0 FROM handmade_thread @@ -586,7 +586,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.PostID = postId c.Perf.StartBlock("SQL", "Verify that the post exists") - postExists, err := db.QueryBool(c.Context(), c.Conn, + postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, ` SELECT COUNT(*) > 0 FROM handmade_post diff --git a/src/website/discord.go b/src/website/discord.go index 8554da9..ae0911f 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -104,7 +104,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } defer tx.Rollback(c.Context()) - iDiscordUser, err := db.QueryOne(c.Context(), tx, models.DiscordUser{}, + discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx, ` SELECT $columns FROM handmade_discorduser @@ -119,7 +119,6 @@ func DiscordUnlink(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink")) } } - discordUser := iDiscordUser.(*models.DiscordUser) _, err = tx.Exec(c.Context(), ` @@ -146,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } func DiscordShowcaseBacklog(c *RequestContext) ResponseData { - iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{}, + duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, `SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`, c.CurrentUser.ID, ) @@ -157,14 +156,10 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user")) } - duser := iduser.(*models.DiscordUser) - type messageIdQuery struct { - MessageID string `db:"msg.id"` - } - iMsgIDs, err := db.Query(c.Context(), c.Conn, messageIdQuery{}, + msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn, ` - SELECT $columns + SELECT msg.id FROM handmade_discordmessage AS msg WHERE @@ -178,10 +173,6 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, err) } - var msgIDs []string - for _, imsgId := range iMsgIDs { - msgIDs = append(msgIDs, imsgId.(*messageIdQuery).MessageID) - } for _, msgID := range msgIDs { interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID) if err != nil && !errors.Is(err, db.NotFound) { diff --git a/src/website/forums.go b/src/website/forums.go index ea6b8bb..6b60390 100644 --- a/src/website/forums.go +++ b/src/website/forums.go @@ -936,7 +936,7 @@ func addForumUrlsToPost(urlContext *hmnurl.UrlContext, p *templates.Post, subfor // Takes a template post and adds information about how many posts the user has made // on the site. func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) { - numPosts, err := db.QueryInt(ctx, conn, + numPosts, err := db.QueryOneScalar[int](ctx, conn, ` SELECT COUNT(*) FROM @@ -956,7 +956,7 @@ func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.P p.AuthorNumPosts = numPosts } - numProjects, err := db.QueryInt(ctx, conn, + numProjects, err := db.QueryOneScalar[int](ctx, conn, ` SELECT COUNT(*) FROM diff --git a/src/website/imagefile_helper.go b/src/website/imagefile_helper.go index 7d44bf3..aaf2ae7 100644 --- a/src/website/imagefile_helper.go +++ b/src/website/imagefile_helper.go @@ -89,8 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string, img.Seek(0, io.SeekStart) io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs sha1sum := hasher.Sum(nil) - // TODO(db): Should use insert helper - imageFile, err := db.QueryOne(c.Context(), dbConn, models.ImageFile{}, + imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn, ` INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height) VALUES ($1, $2, $3, $4, $5, $6) @@ -105,7 +104,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string, } return SaveImageFileResult{ - ImageFile: imageFile.(*models.ImageFile), + ImageFile: imageFile, } } diff --git a/src/website/links_helper.go b/src/website/links_helper.go index 93eeef1..156058e 100644 --- a/src/website/links_helper.go +++ b/src/website/links_helper.go @@ -30,10 +30,9 @@ func ParseLinks(text string) []ParsedLink { return res } -func LinksToText(links []interface{}) string { +func LinksToText(links []*models.Link) string { linksText := "" - for _, l := range links { - link := l.(*models.Link) + for _, link := range links { linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name) } return linksText diff --git a/src/website/podcast.go b/src/website/podcast.go index e6e800e..c542bcb 100644 --- a/src/website/podcast.go +++ b/src/website/podcast.go @@ -532,11 +532,12 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG Podcast models.Podcast `db:"podcast"` ImageFilename string `db:"imagefile.file"` } - podcastQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastQuery{}, + podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn, ` SELECT $columns - FROM handmade_podcast AS podcast - LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id + FROM + handmade_podcast AS podcast + LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id WHERE podcast.project_id = $1 `, projectId, @@ -549,18 +550,15 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG return result, oops.New(err, "failed to fetch podcast") } } - podcast := podcastQueryResult.(*podcastQuery).Podcast - podcastImageFilename := podcastQueryResult.(*podcastQuery).ImageFilename + podcast := podcastQueryResult.Podcast + podcastImageFilename := podcastQueryResult.ImageFilename result.Podcast = &podcast result.ImageFile = podcastImageFilename if fetchEpisodes { - type podcastEpisodeQuery struct { - Episode models.PodcastEpisode `db:"episode"` - } if episodeGUID == "" { c.Perf.StartBlock("SQL", "Fetch podcast episodes") - podcastEpisodeQueryResult, err := db.Query(c.Context(), c.Conn, podcastEpisodeQuery{}, + episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn, ` SELECT $columns FROM handmade_podcastepisode AS episode @@ -573,16 +571,14 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG if err != nil { return result, oops.New(err, "failed to fetch podcast episodes") } - for _, episodeRow := range podcastEpisodeQueryResult { - result.Episodes = append(result.Episodes, &episodeRow.(*podcastEpisodeQuery).Episode) - } + result.Episodes = episodes } else { guid, err := uuid.Parse(episodeGUID) if err != nil { return result, err } c.Perf.StartBlock("SQL", "Fetch podcast episode") - podcastEpisodeQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastEpisodeQuery{}, + episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn, ` SELECT $columns FROM handmade_podcastepisode AS episode @@ -599,8 +595,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG return result, oops.New(err, "failed to fetch podcast episode") } } - episode := podcastEpisodeQueryResult.(*podcastEpisodeQuery).Episode - result.Episodes = append(result.Episodes, &episode) + result.Episodes = append(result.Episodes, episode) } } diff --git a/src/website/projects.go b/src/website/projects.go index 84f113c..bb83508 100644 --- a/src/website/projects.go +++ b/src/website/projects.go @@ -187,12 +187,9 @@ func ProjectHomepage(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetching screenshots") - type screenshotQuery struct { - Filename string `db:"screenshot.file"` - } - screenshotQueryResult, err := db.Query(c.Context(), c.Conn, screenshotQuery{}, + screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn, ` - SELECT $columns + SELECT screenshot.file FROM handmade_imagefile AS screenshot INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id @@ -207,10 +204,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetching project links") - type projectLinkQuery struct { - Link models.Link `db:"link"` - } - projectLinkResult, err := db.Query(c.Context(), c.Conn, projectLinkQuery{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -237,7 +231,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { Thread models.Thread `db:"thread"` Author models.User `db:"author"` } - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, + posts, err := db.Query[postQuery](c.Context(), c.Conn, ` SELECT $columns FROM @@ -318,21 +312,21 @@ func ProjectHomepage(c *RequestContext) ResponseData { } } - for _, screenshot := range screenshotQueryResult { - templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshot.(*screenshotQuery).Filename)) + for _, screenshotFilename := range screenshotFilenames { + templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshotFilename)) } - for _, link := range projectLinkResult { - templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(&link.(*projectLinkQuery).Link)) + for _, link := range projectLinks { + templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(link)) } - for _, post := range postQueryResult { + for _, post := range posts { templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem( c.UrlContext, lineageBuilder, - &post.(*postQuery).Post, - &post.(*postQuery).Thread, - &post.(*postQuery).Author, + &post.Post, + &post.Thread, + &post.Author, c.Theme, )) } @@ -498,7 +492,7 @@ func ProjectEdit(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetching project links") - projectLinkResult, err := db.Query(c.Context(), c.Conn, models.Link{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -525,7 +519,7 @@ func ProjectEdit(c *RequestContext) ResponseData { c.Theme, ) - projectSettings.LinksText = LinksToText(projectLinkResult) + projectSettings.LinksText = LinksToText(projectLinks) var res ResponseData res.MustWriteTemplate("project_edit.html", ProjectEditData{ @@ -822,14 +816,12 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P } } - type userQuery struct { - User models.User `db:"auth_user"` - } - ownerRows, err := db.Query(ctx, tx, userQuery{}, + owners, err := db.Query[models.User](ctx, tx, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = ANY ($1) `, payload.OwnerUsernames, @@ -849,7 +841,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P return oops.New(err, "Failed to delete project owners") } - for _, ownerRow := range ownerRows { + for _, owner := range owners { _, err = tx.Exec(ctx, ` INSERT INTO handmade_user_projects @@ -857,7 +849,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P VALUES ($1, $2) `, - ownerRow.(*userQuery).User.ID, + owner.ID, payload.ProjectID, ) if err != nil { diff --git a/src/website/routes.go b/src/website/routes.go index 7d89099..8b7ec30 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -548,13 +548,11 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User } } - type userQuery struct { - User models.User `db:"auth_user"` - } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user + SELECT $columns{auth_user} + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE username = $1 `, @@ -568,7 +566,6 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User return nil, nil, oops.New(err, "failed to get user for session") } } - user := &userRow.(*userQuery).User return user, session, nil } diff --git a/src/website/twitch.go b/src/website/twitch.go index e7e0ff7..eda3072 100644 --- a/src/website/twitch.go +++ b/src/website/twitch.go @@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData { } func TwitchDebugPage(c *RequestContext) ResponseData { - streams, err := db.Query(c.Context(), c.Conn, models.TwitchStream{}, + streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn, ` SELECT $columns FROM @@ -83,8 +83,7 @@ func TwitchDebugPage(c *RequestContext) ResponseData { } html := "" - for _, stream := range streams { - s := stream.(*models.TwitchStream) + for _, s := range streams { html += fmt.Sprintf(`%s%s
`, s.Login, s.Login, s.Title) } var res ResponseData diff --git a/src/website/user.go b/src/website/user.go index 395695b..fa77a7a 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -53,12 +53,9 @@ func UserProfile(c *RequestContext) ResponseData { profileUser = c.CurrentUser } else { c.Perf.StartBlock("SQL", "Fetch user") - type userQuery struct { - User models.User `db:"auth_user"` - } - userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id @@ -75,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username)) } } - profileUser = &userResult.(*userQuery).User + profileUser = user } { @@ -87,10 +84,7 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch user links") - type userLinkQuery struct { - UserLink models.Link `db:"link"` - } - userLinksSlice, err := db.Query(c.Context(), c.Conn, userLinkQuery{}, + userLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -104,9 +98,9 @@ func UserProfile(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username)) } - profileUserLinks := make([]templates.Link, 0, len(userLinksSlice)) - for _, l := range userLinksSlice { - profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(&l.(*userLinkQuery).UserLink)) + profileUserLinks := make([]templates.Link, 0, len(userLinks)) + for _, l := range userLinks { + profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(l)) } c.Perf.EndBlock() @@ -231,7 +225,7 @@ func UserSettings(c *RequestContext) ResponseData { DiscordShowcaseBacklogUrl string } - links, err := db.Query(c.Context(), c.Conn, models.Link{}, + links, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM handmade_links @@ -248,7 +242,7 @@ func UserSettings(c *RequestContext) ResponseData { var tduser *templates.DiscordUser var numUnsavedMessages int - iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{}, + duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, ` SELECT $columns FROM handmade_discorduser @@ -261,11 +255,10 @@ func UserSettings(c *RequestContext) ResponseData { } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account")) } else { - duser := iduser.(*models.DiscordUser) tmp := templates.DiscordUserToTemplate(duser) tduser = &tmp - numUnsavedMessages, err = db.QueryInt(c.Context(), c.Conn, + numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT COUNT(*) FROM