From e74b3273cb832e52aae3e5adf458c3cab2bee68a Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Tue, 15 Mar 2022 16:09:11 -0500 Subject: [PATCH] Start converting db stuff to generics --- go.mod | 41 +++++++++++++++++---- go.sum | 7 ++-- src/assets/assets.go | 4 +-- src/auth/session.go | 3 +- src/db/db.go | 25 ++++++------- src/discord/commands.go | 3 +- src/discord/gateway.go | 9 ++--- src/discord/history.go | 10 +++--- src/discord/message_handling.go | 47 +++++++++++-------------- src/hmndata/project_helper.go | 35 ++++++++---------- src/hmndata/snippet_helper.go | 25 ++++++------- src/hmndata/tag_helper.go | 10 ++---- src/hmndata/threads_and_posts_helper.go | 34 +++++++----------- src/models/subforum.go | 6 ++-- 14 files changed, 128 insertions(+), 131 deletions(-) diff --git a/go.mod b/go.mod index 8c56138..453a082 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,8 @@ module git.handmade.network/hmn/hmn -go 1.16 +go 1.18 require ( - github.com/Masterminds/goutils v1.1.1 // indirect - github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible github.com/alecthomas/chroma v0.9.2 github.com/aws/aws-sdk-go-v2 v1.8.1 @@ -16,13 +14,10 @@ require ( github.com/go-stack/stack v1.8.0 github.com/google/uuid v1.2.0 github.com/gorilla/websocket v1.4.2 - github.com/huandu/xstrings v1.3.2 // indirect - github.com/imdario/mergo v0.3.12 // indirect github.com/jackc/pgconn v1.8.0 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.10.1 github.com/jpillora/backoff v1.0.0 - github.com/mitchellh/copystructure v1.1.1 // indirect github.com/rs/zerolog v1.21.0 github.com/spf13/cobra v1.1.3 github.com/stretchr/testify v1.7.0 @@ -32,9 +27,43 @@ require ( github.com/yuin/goldmark v1.4.1 github.com/yuin/goldmark-highlighting v0.0.0-20210516132338-9216f9c5aa01 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7 golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d ) +require ( + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver v1.5.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect + github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.4.0 // indirect + github.com/huandu/xstrings v1.3.2 // indirect + github.com/imdario/mergo v0.3.12 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.0.6 // indirect + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect + github.com/jackc/puddle v1.1.3 // indirect + github.com/mitchellh/copystructure v1.1.1 // indirect + github.com/mitchellh/reflectwalk v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect + golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect + golang.org/x/text v0.3.6 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) + replace ( github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9 github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e diff --git a/go.sum b/go.sum index 6e6029e..2b8a521 100644 --- a/go.sum +++ b/go.sum @@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= @@ -372,6 +370,8 @@ golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7 h1:jynE66seADJbyWMUdeOyVTvPtBZt7L6LJHupGwxPZRM= +golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d h1:RNPAfi2nHY7C2srAV8A49jpsYr0ADedCk1wq6fTMTvs= @@ -439,8 +439,9 @@ golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/src/assets/assets.go b/src/assets/assets.go index 3b077b5..4ab9dff 100644 --- a/src/assets/assets.go +++ b/src/assets/assets.go @@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As } // Fetch and return the new record - iasset, err := db.QueryOne(ctx, dbConn, models.Asset{}, + asset, err := db.QueryOne[models.Asset](ctx, dbConn, ` SELECT $columns FROM handmade_asset @@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As return nil, oops.New(err, "failed to fetch newly-created asset") } - return iasset.(*models.Asset), nil + return asset, nil } diff --git a/src/auth/session.go b/src/auth/session.go index 84c103b..c42d7d5 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -45,7 +45,7 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses return nil, oops.New(err, "failed to get session") } } - sess := row.(*models.Session) return sess, nil } diff --git a/src/db/db.go b/src/db/db.go index 20aee1c..55b54b9 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -95,14 +95,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { return conn } -type StructQueryIterator struct { +type StructQueryIterator[T any] struct { fieldPaths [][]int rows pgx.Rows destType reflect.Type closed chan struct{} } -func (it *StructQueryIterator) Next() (interface{}, bool) { +func (it *StructQueryIterator[T]) Next() (*T, bool) { hasNext := it.rows.Next() if !hasNext { it.Close() @@ -172,10 +172,10 @@ func (it *StructQueryIterator) Next() (interface{}, bool) { currentValue = reflect.Value{} } - return result.Interface(), true + return result.Interface().(*T), true } -func (it *StructQueryIterator) Close() { +func (it *StructQueryIterator[any]) Close() { it.rows.Close() select { case it.closed <- struct{}{}: @@ -183,9 +183,9 @@ func (it *StructQueryIterator) Close() { } } -func (it *StructQueryIterator) ToSlice() []interface{} { +func (it *StructQueryIterator[T]) ToSlice() []*T { defer it.Close() - var result []interface{} + var result []*T for { row, ok := it.Next() if !ok { @@ -231,8 +231,8 @@ func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.V return val, field } -func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) { - it, err := QueryIterator(ctx, conn, destExample, query, args...) +func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) { + it, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err } else { @@ -240,7 +240,8 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st } } -func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) { +func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { + var destExample T destType := reflect.TypeOf(destExample) columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil) if err != nil { @@ -268,7 +269,7 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, return nil, err } - it := &StructQueryIterator{ + it := &StructQueryIterator[T]{ fieldPaths: fieldPaths, rows: rows, destType: destType, @@ -370,8 +371,8 @@ result but find nothing. */ var NotFound = errors.New("not found") -func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { - rows, err := QueryIterator(ctx, conn, destExample, query, args...) +func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err } diff --git a/src/discord/commands.go b/src/discord/commands.go index 5918bd6..bb17b80 100644 --- a/src/discord/commands.go +++ b/src/discord/commands.go @@ -93,7 +93,7 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction type profileResult struct { HMNUser models.User `db:"auth_user"` } - ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{}, + res, err := db.QueryOne[profileResult](ctx, bot.dbConn, ` SELECT $columns FROM @@ -122,7 +122,6 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction } return } - res := ires.(*profileResult) projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{ OwnerIDs: []int{res.HMNUser.ID}, diff --git a/src/discord/gateway.go b/src/discord/gateway.go index 55e63cd..285d67e 100644 --- a/src/discord/gateway.go +++ b/src/discord/gateway.go @@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error { // an old one or starting a new one. shouldResume := true - isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) + session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`) if err != nil { if errors.Is(err, db.NotFound) { // No session yet! Just identify and get on with it @@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error { if shouldResume { // Reconnect to the previous session - session := isession.(*models.DiscordSession) - err := bot.sendGatewayMessage(ctx, GatewayMessage{ Opcode: OpcodeResume, Data: Resume{ @@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) { } defer tx.Rollback(ctx) - msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` + msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, ` SELECT $columns FROM discord_outgoingmessages ORDER BY id ASC @@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) { return } - for _, imsg := range msgs { - msg := imsg.(*models.DiscordOutgoingMessage) + for _, msg := range msgs { if time.Now().After(msg.ExpiresAt) { continue } diff --git a/src/discord/history.go b/src/discord/history.go index 138bae7..cca18f3 100644 --- a/src/discord/history.go +++ b/src/discord/history.go @@ -76,7 +76,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { type query struct { Message models.DiscordMessage `db:"msg"` } - imessagesWithoutContent, err := db.Query(ctx, dbConn, query{}, + messagesWithoutContent, err := db.Query[query](ctx, dbConn, ` SELECT $columns FROM @@ -95,10 +95,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { return } - if len(imessagesWithoutContent) > 0 { - log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent)) + if len(messagesWithoutContent) > 0 { + log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent)) msgloop: - for _, imsg := range imessagesWithoutContent { + for _, msgRow := range messagesWithoutContent { select { case <-ctx.Done(): log.Info().Msg("Scrape was canceled") @@ -106,7 +106,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { default: } - msg := imsg.(*query).Message + msg := msgRow.Message discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID) if errors.Is(err, NotFound) { diff --git a/src/discord/message_handling.go b/src/discord/message_handling.go index 29c511e..edd2603 100644 --- a/src/discord/message_handling.go +++ b/src/discord/message_handling.go @@ -165,7 +165,7 @@ func InternMessage( dbConn db.ConnOrTx, msg *Message, ) error { - _, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{}, + _, err := db.QueryOne[models.DiscordMessage](ctx, dbConn, ` SELECT $columns FROM handmade_discordmessage @@ -219,7 +219,7 @@ type InternedMessage struct { } func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) { - result, err := db.QueryOne(ctx, dbConn, InternedMessage{}, + interned, err := db.QueryOne[InternedMessage](ctx, dbConn, ` SELECT $columns FROM @@ -236,7 +236,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) return nil, err } - interned := result.(*InternedMessage) return interned, nil } @@ -283,7 +282,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message } func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error { - isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -294,10 +293,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In if err != nil && !errors.Is(err, db.NotFound) { return oops.New(err, "failed to fetch snippet for discord message") } - var snippet *models.Snippet - if !errors.Is(err, db.NotFound) { - snippet = isnippet.(*models.Snippet) - } // NOTE(asaf): Also deletes the following through a db cascade: // * handmade_discordmessageattachment @@ -367,7 +362,7 @@ func SaveMessageContents( return oops.New(err, "failed to create or update message contents") } - icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{}, + content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn, ` SELECT $columns FROM @@ -380,7 +375,7 @@ func SaveMessageContents( if err != nil { return oops.New(err, "failed to fetch message contents") } - interned.MessageContent = icontent.(*models.DiscordMessageContent) + interned.MessageContent = content } // Save attachments @@ -472,7 +467,7 @@ func saveAttachment( hmnUserID int, discordMessageID string, ) (*models.DiscordMessageAttachment, error) { - iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -481,7 +476,7 @@ func saveAttachment( attachment.ID, ) if err == nil { - return iexisting.(*models.DiscordMessageAttachment), nil + return existing, nil } else if errors.Is(err, db.NotFound) { // this is fine, just create it } else { @@ -534,7 +529,7 @@ func saveAttachment( return nil, oops.New(err, "failed to save Discord attachment data") } - iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -546,7 +541,7 @@ func saveAttachment( return nil, oops.New(err, "failed to fetch new Discord attachment data") } - return iDiscordAttachment.(*models.DiscordMessageAttachment), nil + return discordAttachment, nil } // Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it @@ -636,7 +631,7 @@ func saveEmbed( return nil, oops.New(err, "failed to insert new embed") } - iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{}, + discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -648,11 +643,11 @@ func saveEmbed( return nil, oops.New(err, "failed to fetch new Discord embed data") } - return iDiscordEmbed.(*models.DiscordMessageEmbed), nil + return discordEmbed, nil } func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) { - iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -669,7 +664,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin } } - return iresult.(*models.Snippet), nil + return snippet, nil } /* @@ -808,7 +803,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in type tagsRow struct { Tag models.Tag `db:"tags"` } - iUserTags, err := db.Query(ctx, tx, tagsRow{}, + userTags, err := db.Query[tagsRow](ctx, tx, ` SELECT $columns FROM @@ -823,8 +818,8 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in return oops.New(err, "failed to fetch tags for user projects") } - for _, itag := range iUserTags { - tag := itag.(*tagsRow).Tag + for _, userTag := range userTags { + tag := userTag.Tag allTags = append(allTags, tag.ID) for _, messageTag := range messageTags { if strings.EqualFold(tag.Text, messageTag) { @@ -890,7 +885,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) { // Check attachments - attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{}, + attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -901,13 +896,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch message attachments") } - for _, iattachment := range attachments { - attachment := iattachment.(*models.DiscordMessageAttachment) + for _, attachment := range attachments { return &attachment.AssetID, nil, nil } // Check embeds - embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{}, + embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -918,8 +912,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch discord embeds") } - for _, iembed := range embeds { - embed := iembed.(*models.DiscordMessageEmbed) + for _, embed := range embeds { if embed.VideoID != nil { return embed.VideoID, nil, nil } else if embed.ImageID != nil { diff --git a/src/hmndata/project_helper.go b/src/hmndata/project_helper.go index 7341ae7..cae1ee6 100644 --- a/src/hmndata/project_helper.go +++ b/src/hmndata/project_helper.go @@ -140,15 +140,15 @@ func FetchProjects( } // Do the query - iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...) + projects, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch projects") } // Fetch project owners to do permission checks - projectIds := make([]int, len(iprojects)) - for i, iproject := range iprojects { - projectIds[i] = iproject.(*projectRow).Project.ID + projectIds := make([]int, len(projects)) + for i, projectRow := range projects { + projectIds[i] = projectRow.Project.ID } projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds) if err != nil { @@ -156,8 +156,7 @@ func FetchProjects( } var res []ProjectAndStuff - for i, iproject := range iprojects { - row := iproject.(*projectRow) + for i, row := range projects { owners := projectOwners[i].Owners /* @@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners( UserID int `db:"user_id"` ProjectID int `db:"project_id"` } - iuserprojects, err := db.Query(ctx, tx, userProject{}, + userProjects, err := db.Query[userProject](ctx, tx, ` SELECT $columns FROM handmade_user_projects @@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners( // Get the unique user IDs from this set and fetch the users from the db var userIds []int - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { addUserId := true for _, uid := range userIds { if uid == userProject.UserID { @@ -364,7 +361,7 @@ func FetchMultipleProjectsOwners( type userQuery struct { User models.User `db:"auth_user"` } - iusers, err := db.Query(ctx, tx, userQuery{}, + projectUsers, err := db.Query[userQuery](ctx, tx, ` SELECT $columns FROM auth_user @@ -383,9 +380,7 @@ func FetchMultipleProjectsOwners( for i, pid := range projectIds { res[i] = ProjectOwners{ProjectID: pid} } - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { // Get a pointer to the existing record in the result var projectOwners *ProjectOwners for i := range res { @@ -396,8 +391,8 @@ func FetchMultipleProjectsOwners( // Get the full user record we fetched var user *models.User - for _, iuser := range iusers { - u := iuser.(*userQuery).User + for _, projectUser := range projectUsers { + u := projectUser.User if u.ID == userProject.UserID { user = &u } @@ -473,7 +468,7 @@ func SetProjectTag( resultTag = p.Tag } else if p.Project.TagID == nil { // Create a tag - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` INSERT INTO tags (text) VALUES ($1) RETURNING $columns @@ -483,7 +478,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to create new tag for project") } - resultTag = itag.(*models.Tag) + resultTag = tag // Attach it to the project _, err = tx.Exec(ctx, @@ -499,7 +494,7 @@ func SetProjectTag( } } else { // Update the text of an existing one - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` UPDATE tags SET text = $1 @@ -511,7 +506,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to update existing tag") } - resultTag = itag.(*models.Tag) + resultTag = tag } err = tx.Commit(ctx) diff --git a/src/hmndata/snippet_helper.go b/src/hmndata/snippet_helper.go index 99356de..9518d19 100644 --- a/src/hmndata/snippet_helper.go +++ b/src/hmndata/snippet_helper.go @@ -47,7 +47,7 @@ func FetchSnippets( type snippetIDRow struct { SnippetID int `db:"snippet_id"` } - iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{}, + snippetIDs, err := db.Query[snippetIDRow](ctx, tx, ` SELECT DISTINCT snippet_id FROM @@ -63,13 +63,13 @@ func FetchSnippets( } // special early-out: no snippets found for these tags at all - if len(iSnippetIDs) == 0 { + if len(snippetIDs) == 0 { return nil, nil } - q.IDs = make([]int, len(iSnippetIDs)) - for i := range iSnippetIDs { - q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID + q.IDs = make([]int, len(snippetIDs)) + for i := range snippetIDs { + q.IDs[i] = snippetIDs[i].SnippetID } } @@ -125,16 +125,14 @@ func FetchSnippets( DiscordMessage *models.DiscordMessage `db:"discord_message"` } - iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...) + rows, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not - snippetIDs := make([]int, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]SnippetAndStuff, len(rows)) // allocate extra space because why not + snippetIDs := make([]int, len(rows)) + for i, row := range rows { result[i] = SnippetAndStuff{ Snippet: row.Snippet, Owner: row.Owner, @@ -150,7 +148,7 @@ func FetchSnippets( SnippetID int `db:"snippet_tags.snippet_id"` Tag *models.Tag `db:"tags"` } - iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{}, + snippetTags, err := db.Query[snippetTagRow](ctx, tx, ` SELECT $columns FROM @@ -170,8 +168,7 @@ func FetchSnippets( for i := range result { resultBySnippetId[result[i].Snippet.ID] = &result[i] } - for _, iSnippetTag := range iSnippetTags { - snippetTag := iSnippetTag.(*snippetTagRow) + for _, snippetTag := range snippetTags { item := resultBySnippetId[snippetTag.SnippetID] item.Tags = append(item.Tags, snippetTag.Tag) } diff --git a/src/hmndata/tag_helper.go b/src/hmndata/tag_helper.go index fab5b94..debdd14 100644 --- a/src/hmndata/tag_helper.go +++ b/src/hmndata/tag_helper.go @@ -40,18 +40,12 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...) + tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch tags") } - res := make([]*models.Tag, len(itags)) - for i, itag := range itags { - tag := itag.(*models.Tag) - res[i] = tag - } - - return res, nil + return tags, nil } func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) { diff --git a/src/hmndata/threads_and_posts_helper.go b/src/hmndata/threads_and_posts_helper.go index a51cbc0..7adf00a 100644 --- a/src/hmndata/threads_and_posts_helper.go +++ b/src/hmndata/threads_and_posts_helper.go @@ -145,15 +145,13 @@ func FetchThreads( ForumLastReadTime *time.Time `db:"slri.lastread"` } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]ThreadAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]ThreadAndStuff, len(results)) + for i, row := range results { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) { hasRead = true @@ -405,15 +403,13 @@ func FetchPosts( qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch posts") } - result := make([]PostAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]PostAndStuff, len(results)) + for i, row := range results { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) { hasRead = true @@ -611,7 +607,7 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User type postResult struct { AuthorID *int `db:"post.author_id"` } - iresult, err := db.QueryOne(ctx, connOrTx, postResult{}, + result, err := db.QueryOne[postResult](ctx, connOrTx, ` SELECT $columns FROM @@ -629,7 +625,6 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User panic(oops.New(err, "failed to get author of post when checking permissions")) } } - result := iresult.(*postResult) return result.AuthorID != nil && *result.AuthorID == user.ID } @@ -709,7 +704,7 @@ func DeletePost( FirstPostID int `db:"first_id"` Deleted bool `db:"deleted"` } - ti, err := db.QueryOne(ctx, tx, threadInfo{}, + info, err := db.QueryOne[threadInfo](ctx, tx, ` SELECT $columns FROM @@ -722,7 +717,6 @@ func DeletePost( if err != nil { panic(oops.New(err, "failed to fetch thread info")) } - info := ti.(*threadInfo) if info.Deleted { return true } @@ -851,7 +845,7 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte type assetId struct { AssetID uuid.UUID `db:"id"` } - assetResult, err := db.Query(ctx, tx, assetId{}, + assets, err := db.Query[assetId](ctx, tx, ` SELECT $columns FROM handmade_asset @@ -865,8 +859,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte var values [][]interface{} - for _, asset := range assetResult { - values = append(values, []interface{}{postId, asset.(*assetId).AssetID}) + for _, asset := range assets { + values = append(values, []interface{}{postId, asset.AssetID}) } _, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values)) @@ -886,7 +880,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more. You should probably mark the thread as deleted in this case. */ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { - postsIter, err := db.Query(ctx, tx, models.Post{}, + posts, err := db.Query[models.Post](ctx, tx, ` SELECT $columns FROM handmade_post @@ -901,9 +895,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { } var firstPost, lastPost *models.Post - for _, ipost := range postsIter { - post := ipost.(*models.Post) - + for _, post := range posts { if firstPost == nil || post.PostDate.Before(firstPost.PostDate) { firstPost = post } diff --git a/src/models/subforum.go b/src/models/subforum.go index d65346a..2bb416d 100644 --- a/src/models/subforum.go +++ b/src/models/subforum.go @@ -47,7 +47,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { type subforumRow struct { Subforum Subforum `db:"sf"` } - rowsSlice, err := db.Query(ctx, conn, subforumRow{}, + rowsSlice, err := db.Query[subforumRow](ctx, conn, ` SELECT $columns FROM @@ -61,7 +61,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice)) for _, row := range rowsSlice { - sf := row.(*subforumRow).Subforum + sf := row.Subforum sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf} } @@ -73,7 +73,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { for _, row := range rowsSlice { // NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order. - cat := row.(*subforumRow).Subforum + cat := row.Subforum node := sfTreeMap[cat.ID] if node.Parent != nil { node.Parent.Children = append(node.Parent.Children, node)