diff --git a/go.mod b/go.mod index 8c56138a..08de4878 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,8 @@ module git.handmade.network/hmn/hmn -go 1.16 +go 1.18 require ( - github.com/Masterminds/goutils v1.1.1 // indirect - github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible github.com/alecthomas/chroma v0.9.2 github.com/aws/aws-sdk-go-v2 v1.8.1 @@ -16,13 +14,10 @@ require ( github.com/go-stack/stack v1.8.0 github.com/google/uuid v1.2.0 github.com/gorilla/websocket v1.4.2 - github.com/huandu/xstrings v1.3.2 // indirect - github.com/imdario/mergo v0.3.12 // indirect github.com/jackc/pgconn v1.8.0 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.10.1 github.com/jpillora/backoff v1.0.0 - github.com/mitchellh/copystructure v1.1.1 // indirect github.com/rs/zerolog v1.21.0 github.com/spf13/cobra v1.1.3 github.com/stretchr/testify v1.7.0 @@ -35,6 +30,39 @@ require ( golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d ) +require ( + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver v1.5.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect + github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.4.0 // indirect + github.com/huandu/xstrings v1.3.2 // indirect + github.com/imdario/mergo v0.3.12 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.0.6 // indirect + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect + github.com/jackc/puddle v1.1.3 // indirect + github.com/mitchellh/copystructure v1.1.1 // indirect + github.com/mitchellh/reflectwalk v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect + golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect + golang.org/x/text v0.3.6 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) + replace ( github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9 github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e diff --git a/go.sum b/go.sum index 6e6029ee..bb9c2856 100644 --- a/go.sum +++ b/go.sum @@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= diff --git a/src/admintools/adminproject.go b/src/admintools/adminproject.go index 4f1910ef..ed806998 100644 --- a/src/admintools/adminproject.go +++ b/src/admintools/adminproject.go @@ -55,7 +55,7 @@ func addCreateProjectCommand(projectCommand *cobra.Command) { } hmn := p.Project - newProjectID, err := db.QueryInt(ctx, tx, + newProjectID, err := db.QueryOneScalar[int](ctx, tx, ` INSERT INTO handmade_project ( slug, diff --git a/src/admintools/admintools.go b/src/admintools/admintools.go index d88d1e48..23bf624a 100644 --- a/src/admintools/admintools.go +++ b/src/admintools/admintools.go @@ -210,7 +210,7 @@ func init() { } defer tx.Rollback(ctx) - projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) + projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) if err != nil { panic(err) } @@ -218,7 +218,7 @@ func init() { var parentId *int if parentSlug == "" { // Select the root subforum - id, err := db.QueryInt(ctx, tx, + id, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`, projectId, ) @@ -228,7 +228,7 @@ func init() { parentId = &id } else { // Select the parent - id, err := db.QueryInt(ctx, tx, + id, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`, parentSlug, projectId, ) @@ -238,7 +238,7 @@ func init() { parentId = &id } - newId, err := db.QueryInt(ctx, tx, + newId, err := db.QueryOneScalar[int](ctx, tx, ` INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id) VALUES ($1, $2, $3, $4, $5) @@ -289,12 +289,12 @@ func init() { } defer tx.Rollback(ctx) - projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) + projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug) if err != nil { panic(err) } - subforumId, err := db.QueryInt(ctx, tx, + subforumId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`, subforumSlug, projectId, ) diff --git a/src/assets/assets.go b/src/assets/assets.go index 3b077b5c..4ab9dffb 100644 --- a/src/assets/assets.go +++ b/src/assets/assets.go @@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As } // Fetch and return the new record - iasset, err := db.QueryOne(ctx, dbConn, models.Asset{}, + asset, err := db.QueryOne[models.Asset](ctx, dbConn, ` SELECT $columns FROM handmade_asset @@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As return nil, oops.New(err, "failed to fetch newly-created asset") } - return iasset.(*models.Asset), nil + return asset, nil } diff --git a/src/auth/session.go b/src/auth/session.go index 84c103b3..c42d7d5d 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -45,7 +45,7 @@ func makeCSRFToken() string { var ErrNoSession = errors.New("no session found") func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { - row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) + sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id) if err != nil { if errors.Is(err, db.NotFound) { return nil, ErrNoSession @@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses return nil, oops.New(err, "failed to get session") } } - sess := row.(*models.Session) return sess, nil } diff --git a/src/db/db.go b/src/db/db.go index 20aee1cc..f89a7b73 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "strings" "git.handmade.network/hmn/hmn/src/config" @@ -20,46 +21,17 @@ import ( ) /* -Values of these kinds are ok to query even if they are not directly understood by pgtype. -This is common for custom types like: - - type ThreadType int +A general error to be used when no results are found. This is the error returned +by QueryOne, and can generally be used by other database helpers that fetch a single +result but find nothing. */ -var queryableKinds = []reflect.Kind{ - reflect.Int, -} - -/* -Checks if we are able to handle a particular type in a database query. This applies only to -primitive types and not structs, since the database only returns individual primitive types -and it is our job to stitch them back together into structs later. -*/ -func typeIsQueryable(t reflect.Type) bool { - _, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags - // NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway - - if isRecognizedByPgtype { - return true - } else if t == reflect.TypeOf(uuid.UUID{}) { - return true - } - - // pgtype doesn't recognize it, but maybe it's a primitive type we can deal with - k := t.Kind() - for _, qk := range queryableKinds { - if k == qk { - return true - } - } - - return false -} +var NotFound = errors.New("not found") // This interface should match both a direct pgx connection or a pgx transaction. type ConnOrTx interface { - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row - Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) // Both raw database connections and transactions in pgx can begin/commit // transactions. For database connections it does the obvious thing; for @@ -70,6 +42,8 @@ type ConnOrTx interface { var connInfo = pgtype.NewConnInfo() +// Creates a new connection to the HMN database. +// This connection is not safe for concurrent use. func NewConn() *pgx.Conn { conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN()) if err != nil { @@ -79,6 +53,8 @@ func NewConn() *pgx.Conn { return conn } +// Creates a connection pool for the HMN database. +// The resulting pool is safe for concurrent use. func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN()) @@ -95,144 +71,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { return conn } -type StructQueryIterator struct { - fieldPaths [][]int - rows pgx.Rows - destType reflect.Type - closed chan struct{} -} +/* +Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred. -func (it *StructQueryIterator) Next() (interface{}, bool) { - hasNext := it.rows.Next() - if !hasNext { - it.Close() - return nil, false - } +Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection. - result := reflect.New(it.destType) - - vals, err := it.rows.Values() - if err != nil { - panic(err) - } - - // Better logging of panics in this confusing reflection process - var currentField reflect.StructField - var currentValue reflect.Value - var currentIdx int - defer func() { - if r := recover(); r != nil { - if currentValue.IsValid() { - logging.Error(). - Int("index", currentIdx). - Str("field name", currentField.Name). - Stringer("field type", currentField.Type). - Interface("value", currentValue.Interface()). - Stringer("value type", currentValue.Type()). - Msg("panic in iterator") - } - - if currentField.Name != "" { - panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r)) - } else { - panic(r) - } - } - }() - - for i, val := range vals { - currentIdx = i - if val == nil { - continue - } - - var field reflect.Value - field, currentField = followPathThroughStructs(result, it.fieldPaths[i]) - if field.Kind() == reflect.Ptr { - field.Set(reflect.New(field.Type().Elem())) - field = field.Elem() - } - - // Some actual values still come through as pointers (like net.IPNet). Dunno why. - // Regardless, we know it's not nil, so we can get at the contents. - valReflected := reflect.ValueOf(val) - if valReflected.Kind() == reflect.Ptr { - valReflected = valReflected.Elem() - } - currentValue = valReflected - - switch field.Kind() { - case reflect.Int: - field.SetInt(valReflected.Int()) - default: - field.Set(valReflected) - } - - currentField = reflect.StructField{} - currentValue = reflect.Value{} - } - - return result.Interface(), true -} - -func (it *StructQueryIterator) Close() { - it.rows.Close() - select { - case it.closed <- struct{}{}: - default: - } -} - -func (it *StructQueryIterator) ToSlice() []interface{} { - defer it.Close() - var result []interface{} - for { - row, ok := it.Next() - if !ok { - err := it.rows.Err() - if err != nil { - panic(oops.New(err, "error while iterating through db results")) - } - break - } - result = append(result, row) - } - return result -} - -func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) { - if len(path) < 1 { - panic(oops.New(nil, "can't follow an empty path")) - } - - if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct { - panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type())) - } - - // more informative panic recovery - var field reflect.StructField - defer func() { - if r := recover(); r != nil { - panic(oops.New(nil, "panic at field '%s': %v", field.Name, r)) - } - }() - - val := structPtrVal - for _, i := range path { - if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct { - if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - val = val.Elem() - } - field = val.Type().Field(i) - val = val.Field(i) - } - return val, field -} - -func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) { - it, err := QueryIterator(ctx, conn, destExample, query, args...) +This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar. +*/ +func Query[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) ([]*T, error) { + it, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err } else { @@ -240,27 +92,99 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st } } -func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) { - destType := reflect.TypeOf(destExample) - columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil) +/* +Identical to Query, but returns only the first result row. If there are no +rows in the result set, returns NotFound. +*/ +func QueryOne[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (*T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { - return nil, oops.New(err, "failed to generate column names") + return nil, err + } + defer rows.Close() + + result, hasRow := rows.Next() + if !hasRow { + return nil, NotFound } - columns := make([]string, 0, len(columnNames)) - for _, strSlice := range columnNames { - tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") - fullName := strSlice[len(strSlice)-1] - if tableName != "" { - fullName = tableName + "." + fullName + return result, nil +} + +/* +Identical to Query, but returns concrete values instead of pointers. More convenient +for primitive types. +*/ +func QueryScalar[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) ([]T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []T + for { + val, hasRow := rows.Next() + if !hasRow { + break } - columns = append(columns, fullName) + result = append(result, *val) } - columnNamesString := strings.Join(columns, ", ") - query = strings.Replace(query, "$columns", columnNamesString, -1) + return result, nil +} - rows, err := conn.Query(ctx, query, args...) +/* +Identical to QueryScalar, but returns only the first result value. If there are +no rows in the result set, returns NotFound. +*/ +func QueryOneScalar[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) + if err != nil { + var zero T + return zero, err + } + defer rows.Close() + + result, hasRow := rows.Next() + if !hasRow { + var zero T + return zero, NotFound + } + + return *result, nil +} + +/* +Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use. +*/ +func QueryIterator[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (*Iterator[T], error) { + var destExample T + destType := reflect.TypeOf(destExample) + + compiled := compileQuery(query, destType) + + rows, err := conn.Query(ctx, compiled.query, args...) if err != nil { if errors.Is(err, context.DeadlineExceeded) { panic("query exceeded its deadline") @@ -268,11 +192,12 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, return nil, err } - it := &StructQueryIterator{ - fieldPaths: fieldPaths, - rows: rows, - destType: destType, - closed: make(chan struct{}, 1), + it := &Iterator[T]{ + fieldPaths: compiled.fieldPaths, + rows: rows, + destType: compiled.destType, + destTypeIsScalar: typeIsQueryable(compiled.destType), + closed: make(chan struct{}, 1), } // Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold @@ -292,16 +217,72 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, return it, nil } -func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { - var columnNames [][]string - var fieldPaths [][]int +// TODO: QueryFunc? + +type compiledQuery struct { + query string + destType reflect.Type + fieldPaths []fieldPath +} + +var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`) + +func compileQuery(query string, destType reflect.Type) compiledQuery { + columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query) + hasColumnsPlaceholder := columnsMatch != nil + + if hasColumnsPlaceholder { + // The presence of the $columns placeholder means that the destination type + // must be a struct, and we will plonk that struct's fields into the query. + + if destType.Kind() != reflect.Struct { + panic("$columns can only be used when querying into a struct") + } + + var prefix []string + prefixText := columnsMatch[2] + if prefixText != "" { + prefix = []string{prefixText} + } + + columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix) + + columns := make([]string, 0, len(columnNames)) + for _, strSlice := range columnNames { + tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") + fullName := strSlice[len(strSlice)-1] + if tableName != "" { + fullName = tableName + "." + fullName + } + columns = append(columns, fullName) + } + + columnNamesString := strings.Join(columns, ", ") + query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString) + + return compiledQuery{ + query: query, + destType: destType, + fieldPaths: fieldPaths, + } + } else { + return compiledQuery{ + query: query, + destType: destType, + } + } +} + +func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) { + var columnNames []columnName + var fieldPaths []fieldPath if destType.Kind() == reflect.Ptr { destType = destType.Elem() } if destType.Kind() != reflect.Struct { - return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) + panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)) } type AnonPrefix struct { @@ -348,108 +329,214 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str columnNames = append(columnNames, fieldColumnNames) fieldPaths = append(fieldPaths, path) } else if fieldType.Kind() == reflect.Struct { - subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) - if err != nil { - return nil, nil, err - } + subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) columnNames = append(columnNames, subCols...) fieldPaths = append(fieldPaths, subPaths...) } else { - return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) + panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)) } } } - return columnNames, fieldPaths, nil + return columnNames, fieldPaths } /* -A general error to be used when no results are found. This is the error returned -by QueryOne, and can generally be used by other database helpers that fetch a single -result but find nothing. +Values of these kinds are ok to query even if they are not directly understood by pgtype. +This is common for custom types like: + + type ThreadType int */ -var NotFound = errors.New("not found") - -func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { - rows, err := QueryIterator(ctx, conn, destExample, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - result, hasRow := rows.Next() - if !hasRow { - return nil, NotFound - } - - return result, nil +var queryableKinds = []reflect.Kind{ + reflect.Int, } -func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) { - rows, err := conn.Query(ctx, query, args...) - if err != nil { - return nil, err +/* +Checks if we are able to handle a particular type in a database query. This applies only to +primitive types and not structs, since the database only returns individual primitive types +and it is our job to stitch them back together into structs later. +*/ +func typeIsQueryable(t reflect.Type) bool { + _, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags + // NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway + + if isRecognizedByPgtype { + return true + } else if t == reflect.TypeOf(uuid.UUID{}) { + return true } - defer rows.Close() - if rows.Next() { - vals, err := rows.Values() - if err != nil { - panic(err) + // pgtype doesn't recognize it, but maybe it's a primitive type we can deal with + k := t.Kind() + for _, qk := range queryableKinds { + if k == qk { + return true } + } + return false +} + +type columnName []string + +// A path to a particular field in query's destination type. Each index in the slice +// corresponds to a field index for use with Field on a reflect.Type or reflect.Value. +type fieldPath []int + +type Iterator[T any] struct { + fieldPaths []fieldPath + rows pgx.Rows + destType reflect.Type + destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function. + closed chan struct{} +} + +func (it *Iterator[T]) Next() (*T, bool) { + // TODO(ben): What happens if this panics? Does it leak resources? Do we need + // to put a recover() here and close the rows? + + hasNext := it.rows.Next() + if !hasNext { + it.Close() + return nil, false + } + + result := reflect.New(it.destType) + + vals, err := it.rows.Values() + if err != nil { + panic(err) + } + + if it.destTypeIsScalar { + // This type can be directly queried, meaning pgx recognizes it, it's + // a simple scalar thing, and we can just take the easy way out. if len(vals) != 1 { - return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals)) + panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals))) + } + setValueFromDB(result.Elem(), reflect.ValueOf(vals[0])) + return result.Interface().(*T), true + } else { + var currentField reflect.StructField + var currentValue reflect.Value + var currentIdx int + + // Better logging of panics in this confusing reflection process + defer func() { + if r := recover(); r != nil { + if currentValue.IsValid() { + logging.Error(). + Int("index", currentIdx). + Str("field name", currentField.Name). + Stringer("field type", currentField.Type). + Interface("value", currentValue.Interface()). + Stringer("value type", currentValue.Type()). + Msg("panic in iterator") + } + + if currentField.Name != "" { + panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r)) + } else { + panic(r) + } + } + }() + + for i, val := range vals { + currentIdx = i + if val == nil { + continue + } + + var field reflect.Value + field, currentField = followPathThroughStructs(result, it.fieldPaths[i]) + if field.Kind() == reflect.Ptr { + field.Set(reflect.New(field.Type().Elem())) + field = field.Elem() + } + + // Some actual values still come through as pointers (like net.IPNet). Dunno why. + // Regardless, we know it's not nil, so we can get at the contents. + valReflected := reflect.ValueOf(val) + if valReflected.Kind() == reflect.Ptr { + valReflected = valReflected.Elem() + } + currentValue = valReflected + + setValueFromDB(field, valReflected) + + currentField = reflect.StructField{} + currentValue = reflect.Value{} } - return vals[0], nil + return result.Interface().(*T), true } - - return nil, NotFound } -func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return "", err - } - - switch r := result.(type) { - case string: - return r, nil +func setValueFromDB(dest reflect.Value, value reflect.Value) { + switch dest.Kind() { + case reflect.Int: + dest.SetInt(value.Int()) default: - return "", oops.New(nil, "QueryString got a non-string result: %v", result) + dest.Set(value) } } -func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return 0, err - } - - switch r := result.(type) { - case int: - return r, nil - case int32: - return int(r), nil - case int64: - return int(r), nil +func (it *Iterator[any]) Close() { + it.rows.Close() + select { + case it.closed <- struct{}{}: default: - return 0, oops.New(nil, "QueryInt got a non-int result: %v", result) } } -func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return false, err +/* +Pulls all the remaining values into a slice, and closes the iterator. +*/ +func (it *Iterator[T]) ToSlice() []*T { + defer it.Close() + var result []*T + for { + row, ok := it.Next() + if !ok { + err := it.rows.Err() + if err != nil { + panic(oops.New(err, "error while iterating through db results")) + } + break + } + result = append(result, row) + } + return result +} + +func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) { + if len(path) < 1 { + panic(oops.New(nil, "can't follow an empty path")) } - switch r := result.(type) { - case bool: - return r, nil - default: - return false, oops.New(nil, "QueryBool got a non-bool result: %v", result) + if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct { + panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type())) } + + // more informative panic recovery + var field reflect.StructField + defer func() { + if r := recover(); r != nil { + panic(oops.New(nil, "panic at field '%s': %v", field.Name, r)) + } + }() + + val := structPtrVal + for _, i := range path { + if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + field = val.Type().Field(i) + val = val.Field(i) + } + return val, field } diff --git a/src/db/db_test.go b/src/db/db_test.go index 41b9491d..edc647ed 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -10,52 +10,143 @@ import ( func TestPaths(t *testing.T) { type CustomInt int - type S struct { - I int `db:"I"` - PI *int `db:"PI"` - CI CustomInt `db:"CI"` - PCI *CustomInt `db:"PCI"` - B bool `db:"B"` - PB *bool `db:"PB"` + type S2 struct { + B bool `db:"B"` // field 0 + PB *bool `db:"PB"` // field 1 - NoTag int + NoTag string // field 2 + } + type S struct { + I int `db:"I"` // field 0 + PI *int `db:"PI"` // field 1 + CI CustomInt `db:"CI"` // field 2 + PCI *CustomInt `db:"PCI"` // field 3 + S2 `db:"S2"` // field 4 (embedded!) + PS2 *S2 `db:"PS2"` // field 5 + + NoTag int // field 6 } type Nested struct { - S S `db:"S"` - PS *S `db:"PS"` + S S `db:"S"` // field 0 + PS *S `db:"PS"` // field 1 - NoTag S + NoTag S // field 2 } type Embedded struct { - NoTag S - Nested + NoTag S // field 0 + Nested // field 1 } - names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") - if assert.Nil(t, err) { - assert.Equal(t, []string{ - "S.I", "S.PI", - "S.CI", "S.PCI", - "S.B", "S.PB", - "PS.I", "PS.PI", - "PS.CI", "PS.PCI", - "PS.B", "PS.PB", - }, names) - assert.Equal(t, [][]int{ - {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, - {1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, - }, paths) - assert.True(t, len(names) == len(paths)) - } + names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil) + assert.Equal(t, []columnName{ + {"S", "I"}, {"S", "PI"}, + {"S", "CI"}, {"S", "PCI"}, + {"S", "S2", "B"}, {"S", "S2", "PB"}, + {"S", "PS2", "B"}, {"S", "PS2", "PB"}, + {"PS", "I"}, {"PS", "PI"}, + {"PS", "CI"}, {"PS", "PCI"}, + {"PS", "S2", "B"}, {"PS", "S2", "PB"}, + {"PS", "PS2", "B"}, {"PS", "PS2", "PB"}, + }, names) + assert.Equal(t, []fieldPath{ + {1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI + {1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI + {1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB + {1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB + {1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI + {1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI + {1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB + {1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB + }, paths) + assert.True(t, len(names) == len(paths)) testStruct := Embedded{} for i, path := range paths { val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) assert.True(t, val.IsValid()) - assert.True(t, strings.Contains(names[i], field.Name)) + assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name)) } } +func TestCompileQuery(t *testing.T) { + t.Run("simple struct", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query) + }) + t.Run("complex structs", func(t *testing.T) { + type CustomInt int + type S2 struct { + B bool `db:"B"` + PB *bool `db:"PB"` + + NoTag string + } + type S struct { + I int `db:"I"` + PI *int `db:"PI"` + CI CustomInt `db:"CI"` + PCI *CustomInt `db:"PCI"` + S2 `db:"S2"` // embedded! + PS2 *S2 `db:"PS2"` + + NoTag int + } + type Nested struct { + S S `db:"S"` + PS *S `db:"PS"` + + NoTag S + } + type Dest struct { + NoTag S + Nested + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query) + }) + t.Run("int", func(t *testing.T) { + type Dest int + + // There should be no error here because we do not need to extract columns from + // the destination type. There may be errors down the line in value iteration, but + // that is always the case if the Go types don't match the query. + compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0))) + assert.Equal(t, "SELECT id FROM greeblies", compiled.query) + }) + t.Run("just one table", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + // The prefix is necessary because otherwise we would have to provide a struct with + // a db tag in order to provide the query with the `greeblies.` prefix in the + // final query. This comes up a lot when we do a JOIN to help with a condition, but + // don't actually care about any of the data we joined to. + compiled := compileQuery( + "SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props", + reflect.TypeOf(Dest{}), + ) + assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query) + }) + + t.Run("using $columns without a struct is not allowed", func(t *testing.T) { + type Dest int + + assert.Panics(t, func() { + compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0))) + }) + }) +} + func TestQueryBuilder(t *testing.T) { t.Run("happy time", func(t *testing.T) { var qb QueryBuilder diff --git a/src/db/doc.go b/src/db/doc.go new file mode 100644 index 00000000..6d66b48c --- /dev/null +++ b/src/db/doc.go @@ -0,0 +1,59 @@ +/* +This package contains lowish-level APIs for making database queries to our Postgres database. It streamlines the process of mapping query results to Go types, while allowing you to write arbitrary SQL queries. + +The primary functions are Query and QueryIterator. See the package and function examples for detailed usage. + +Query syntax + +This package allows a few small extensions to SQL syntax to streamline the interaction between Go and Postgres. + +Arguments can be provided using placeholders like $1, $2, etc. All arguments will be safely escaped and mapped from their Go type to the correct Postgres type. (This is a direct proxy to pgx.) + + projectIDs, err := db.Query[int](ctx, conn, + ` + SELECT id + FROM handmade_project + WHERE + slug = ANY($1) + AND hidden = $2 + `, + []string{"4coder", "metadesk"}, + false, + ) + +(This also demonstrates a useful tip: if you want to use a slice in your query, use Postgres arrays instead of IN.) + +When querying individual fields, you can simply select the field like so: + + ids, err := db.Query[int](ctx, conn, `SELECT id FROM handmade_project`) + +To query multiple columns at once, you may use a struct type with `db:"column_name"` tags, and the special $columns placeholder: + + type Project struct { + ID int `db:"id"` + Slug string `db:"slug"` + DateCreated time.Time `db:"date_created"` + } + projects, err := db.Query[Project](ctx, conn, `SELECT $columns FROM ...`) + // Resulting query: + // SELECT id, slug, date_created FROM ... + +Sometimes a table name prefix is required on each column to disambiguate between column names, especially when performing a JOIN. In those situations, you can include the prefix in the $columns placeholder like $columns{prefix}: + + type Project struct { + ID int `db:"id"` + Slug string `db:"slug"` + DateCreated time.Time `db:"date_created"` + } + orphanedProjects, err := db.Query[Project](ctx, conn, ` + SELECT $columns{projects} + FROM + handmade_project AS projects + LEFT JOIN handmade_user_projects AS uproj + WHERE + uproj.user_id IS NULL + `) + // Resulting query: + // SELECT projects.id, projects.slug, projects.date_created FROM ... +*/ +package db diff --git a/src/db/query_builder.go b/src/db/query_builder.go index 80a687f0..d8dcb41d 100644 --- a/src/db/query_builder.go +++ b/src/db/query_builder.go @@ -7,7 +7,7 @@ import ( type QueryBuilder struct { sql strings.Builder - args []interface{} + args []any } /* @@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number. foo ARG1 bar ARG2 baz $? foo ARG1 bar ARG2 baz ARG3 */ -func (qb *QueryBuilder) Add(sql string, args ...interface{}) { +func (qb *QueryBuilder) Add(sql string, args ...any) { numPlaceholders := strings.Count(sql, "$?") if numPlaceholders != len(args) { panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args))) @@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string { return qb.sql.String() } -func (qb *QueryBuilder) Args() []interface{} { +func (qb *QueryBuilder) Args() []any { return qb.args } diff --git a/src/discord/commands.go b/src/discord/commands.go index 5918bd6c..eef3cba0 100644 --- a/src/discord/commands.go +++ b/src/discord/commands.go @@ -90,12 +90,9 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction return } - type profileResult struct { - HMNUser models.User `db:"auth_user"` - } - ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{}, + hmnUser, err := db.QueryOne[models.User](ctx, bot.dbConn, ` - SELECT $columns + SELECT $columns{auth_user} FROM handmade_discorduser AS duser JOIN auth_user ON duser.hmn_user_id = auth_user.id @@ -122,16 +119,15 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction } return } - res := ires.(*profileResult) projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{ - OwnerIDs: []int{res.HMNUser.ID}, + OwnerIDs: []int{hmnUser.ID}, }) if err != nil { logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects") } - url := hmnurl.BuildUserProfile(res.HMNUser.Username) + url := hmnurl.BuildUserProfile(hmnUser.Username) msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url) if len(projectsAndStuff) > 0 { projectNoun := "projects" diff --git a/src/discord/gateway.go b/src/discord/gateway.go index 55e63cd9..9fb2b963 100644 --- a/src/discord/gateway.go +++ b/src/discord/gateway.go @@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error { // an old one or starting a new one. shouldResume := true - isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) + session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`) if err != nil { if errors.Is(err, db.NotFound) { // No session yet! Just identify and get on with it @@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error { if shouldResume { // Reconnect to the previous session - session := isession.(*models.DiscordSession) - err := bot.sendGatewayMessage(ctx, GatewayMessage{ Opcode: OpcodeResume, Data: Resume{ @@ -356,7 +354,7 @@ func (bot *botInstance) doSender(ctx context.Context) { } bot.didAckHeartbeat = false - latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) + latestSequenceNumber, err := db.QueryOneScalar[int](ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`) if err != nil { log.Error().Err(err).Msg("failed to fetch latest sequence number from the db") return false @@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) { } defer tx.Rollback(ctx) - msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` + msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, ` SELECT $columns FROM discord_outgoingmessages ORDER BY id ASC @@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) { return } - for _, imsg := range msgs { - msg := imsg.(*models.DiscordOutgoingMessage) + for _, msg := range msgs { if time.Now().After(msg.ExpiresAt) { continue } diff --git a/src/discord/history.go b/src/discord/history.go index 138bae75..3a69d066 100644 --- a/src/discord/history.go +++ b/src/discord/history.go @@ -73,12 +73,9 @@ func RunHistoryWatcher(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { log := logging.ExtractLogger(ctx) - type query struct { - Message models.DiscordMessage `db:"msg"` - } - imessagesWithoutContent, err := db.Query(ctx, dbConn, query{}, + messagesWithoutContent, err := db.Query[models.DiscordMessage](ctx, dbConn, ` - SELECT $columns + SELECT $columns{msg} FROM handmade_discordmessage AS msg JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users @@ -95,10 +92,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { return } - if len(imessagesWithoutContent) > 0 { - log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent)) + if len(messagesWithoutContent) > 0 { + log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent)) msgloop: - for _, imsg := range imessagesWithoutContent { + for _, msg := range messagesWithoutContent { select { case <-ctx.Done(): log.Info().Msg("Scrape was canceled") @@ -106,8 +103,6 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { default: } - msg := imsg.(*query).Message - discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID) if errors.Is(err, NotFound) { // This message has apparently been deleted; delete it from our database diff --git a/src/discord/message_handling.go b/src/discord/message_handling.go index 29c511e0..e9172f09 100644 --- a/src/discord/message_handling.go +++ b/src/discord/message_handling.go @@ -165,7 +165,7 @@ func InternMessage( dbConn db.ConnOrTx, msg *Message, ) error { - _, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{}, + _, err := db.QueryOne[models.DiscordMessage](ctx, dbConn, ` SELECT $columns FROM handmade_discordmessage @@ -219,7 +219,7 @@ type InternedMessage struct { } func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) { - result, err := db.QueryOne(ctx, dbConn, InternedMessage{}, + interned, err := db.QueryOne[InternedMessage](ctx, dbConn, ` SELECT $columns FROM @@ -235,8 +235,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) if err != nil { return nil, err } - - interned := result.(*InternedMessage) return interned, nil } @@ -283,7 +281,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message } func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error { - isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -294,10 +292,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In if err != nil && !errors.Is(err, db.NotFound) { return oops.New(err, "failed to fetch snippet for discord message") } - var snippet *models.Snippet - if !errors.Is(err, db.NotFound) { - snippet = isnippet.(*models.Snippet) - } // NOTE(asaf): Also deletes the following through a db cascade: // * handmade_discordmessageattachment @@ -367,7 +361,7 @@ func SaveMessageContents( return oops.New(err, "failed to create or update message contents") } - icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{}, + content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn, ` SELECT $columns FROM @@ -380,7 +374,7 @@ func SaveMessageContents( if err != nil { return oops.New(err, "failed to fetch message contents") } - interned.MessageContent = icontent.(*models.DiscordMessageContent) + interned.MessageContent = content } // Save attachments @@ -395,12 +389,12 @@ func SaveMessageContents( // Save / delete embeds if msg.OriginalHasFields("embeds") { - numSavedEmbeds, err := db.QueryInt(ctx, dbConn, + numSavedEmbeds, err := db.QueryOneScalar[int](ctx, dbConn, ` - SELECT COUNT(*) - FROM handmade_discordmessageembed - WHERE message_id = $1 - `, + SELECT COUNT(*) + FROM handmade_discordmessageembed + WHERE message_id = $1 + `, msg.ID, ) if err != nil { @@ -472,7 +466,7 @@ func saveAttachment( hmnUserID int, discordMessageID string, ) (*models.DiscordMessageAttachment, error) { - iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -481,7 +475,7 @@ func saveAttachment( attachment.ID, ) if err == nil { - return iexisting.(*models.DiscordMessageAttachment), nil + return existing, nil } else if errors.Is(err, db.NotFound) { // this is fine, just create it } else { @@ -534,7 +528,7 @@ func saveAttachment( return nil, oops.New(err, "failed to save Discord attachment data") } - iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, + discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -546,7 +540,7 @@ func saveAttachment( return nil, oops.New(err, "failed to fetch new Discord attachment data") } - return iDiscordAttachment.(*models.DiscordMessageAttachment), nil + return discordAttachment, nil } // Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it @@ -636,7 +630,7 @@ func saveEmbed( return nil, oops.New(err, "failed to insert new embed") } - iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{}, + discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -648,11 +642,11 @@ func saveEmbed( return nil, oops.New(err, "failed to fetch new Discord embed data") } - return iDiscordEmbed.(*models.DiscordMessageEmbed), nil + return discordEmbed, nil } func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) { - iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{}, + snippet, err := db.QueryOne[models.Snippet](ctx, dbConn, ` SELECT $columns FROM handmade_snippet @@ -669,7 +663,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin } } - return iresult.(*models.Snippet), nil + return snippet, nil } /* @@ -805,12 +799,9 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in projectIDs[i] = p.Project.ID } - type tagsRow struct { - Tag models.Tag `db:"tags"` - } - iUserTags, err := db.Query(ctx, tx, tagsRow{}, + userTags, err := db.Query[models.Tag](ctx, tx, ` - SELECT $columns + SELECT $columns{tags} FROM tags JOIN handmade_project AS project ON project.tag = tags.id @@ -823,8 +814,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in return oops.New(err, "failed to fetch tags for user projects") } - for _, itag := range iUserTags { - tag := itag.(*tagsRow).Tag + for _, tag := range userTags { allTags = append(allTags, tag.ID) for _, messageTag := range messageTags { if strings.EqualFold(tag.Text, messageTag) { @@ -890,7 +880,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) { // Check attachments - attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{}, + attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx, ` SELECT $columns FROM handmade_discordmessageattachment @@ -901,13 +891,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch message attachments") } - for _, iattachment := range attachments { - attachment := iattachment.(*models.DiscordMessageAttachment) + for _, attachment := range attachments { return &attachment.AssetID, nil, nil } // Check embeds - embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{}, + embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx, ` SELECT $columns FROM handmade_discordmessageembed @@ -918,8 +907,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco if err != nil { return nil, nil, oops.New(err, "failed to fetch discord embeds") } - for _, iembed := range embeds { - embed := iembed.(*models.DiscordMessageEmbed) + for _, embed := range embeds { if embed.VideoID != nil { return embed.VideoID, nil, nil } else if embed.ImageID != nil { diff --git a/src/hmndata/project_helper.go b/src/hmndata/project_helper.go index 7341ae7d..128665ac 100644 --- a/src/hmndata/project_helper.go +++ b/src/hmndata/project_helper.go @@ -140,15 +140,15 @@ func FetchProjects( } // Do the query - iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...) + projectRows, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch projects") } // Fetch project owners to do permission checks - projectIds := make([]int, len(iprojects)) - for i, iproject := range iprojects { - projectIds[i] = iproject.(*projectRow).Project.ID + projectIds := make([]int, len(projectRows)) + for i, p := range projectRows { + projectIds[i] = p.Project.ID } projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds) if err != nil { @@ -156,8 +156,7 @@ func FetchProjects( } var res []ProjectAndStuff - for i, iproject := range iprojects { - row := iproject.(*projectRow) + for i, p := range projectRows { owners := projectOwners[i].Owners /* @@ -191,10 +190,10 @@ func FetchProjects( } projectGenerallyVisible := true && - row.Project.Lifecycle.In(models.VisibleProjectLifecycles) && - !row.Project.Hidden && - (!row.Project.Personal || allOwnersApproved || row.Project.IsHMN()) - if row.Project.IsHMN() { + p.Project.Lifecycle.In(models.VisibleProjectLifecycles) && + !p.Project.Hidden && + (!p.Project.Personal || allOwnersApproved || p.Project.IsHMN()) + if p.Project.IsHMN() { projectGenerallyVisible = true // hard override } @@ -205,11 +204,11 @@ func FetchProjects( if projectVisible { res = append(res, ProjectAndStuff{ - Project: row.Project, - LogoLightAsset: row.LogoLightAsset, - LogoDarkAsset: row.LogoDarkAsset, + Project: p.Project, + LogoLightAsset: p.LogoLightAsset, + LogoDarkAsset: p.LogoDarkAsset, Owners: owners, - Tag: row.Tag, + Tag: p.Tag, }) } } @@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners( UserID int `db:"user_id"` ProjectID int `db:"project_id"` } - iuserprojects, err := db.Query(ctx, tx, userProject{}, + userProjects, err := db.Query[userProject](ctx, tx, ` SELECT $columns FROM handmade_user_projects @@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners( // Get the unique user IDs from this set and fetch the users from the db var userIds []int - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { addUserId := true for _, uid := range userIds { if uid == userProject.UserID { @@ -361,14 +358,12 @@ func FetchMultipleProjectsOwners( userIds = append(userIds, userProject.UserID) } } - type userQuery struct { - User models.User `db:"auth_user"` - } - iusers, err := db.Query(ctx, tx, userQuery{}, + users, err := db.Query[models.User](ctx, tx, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE auth_user.id = ANY($1) `, @@ -383,9 +378,7 @@ func FetchMultipleProjectsOwners( for i, pid := range projectIds { res[i] = ProjectOwners{ProjectID: pid} } - for _, iuserproject := range iuserprojects { - userProject := iuserproject.(*userProject) - + for _, userProject := range userProjects { // Get a pointer to the existing record in the result var projectOwners *ProjectOwners for i := range res { @@ -396,10 +389,9 @@ func FetchMultipleProjectsOwners( // Get the full user record we fetched var user *models.User - for _, iuser := range iusers { - u := iuser.(*userQuery).User + for _, u := range users { if u.ID == userProject.UserID { - user = &u + user = u } } if user == nil { @@ -473,7 +465,7 @@ func SetProjectTag( resultTag = p.Tag } else if p.Project.TagID == nil { // Create a tag - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` INSERT INTO tags (text) VALUES ($1) RETURNING $columns @@ -483,7 +475,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to create new tag for project") } - resultTag = itag.(*models.Tag) + resultTag = tag // Attach it to the project _, err = tx.Exec(ctx, @@ -499,7 +491,7 @@ func SetProjectTag( } } else { // Update the text of an existing one - itag, err := db.QueryOne(ctx, tx, models.Tag{}, + tag, err := db.QueryOne[models.Tag](ctx, tx, ` UPDATE tags SET text = $1 @@ -511,7 +503,7 @@ func SetProjectTag( if err != nil { return nil, oops.New(err, "failed to update existing tag") } - resultTag = itag.(*models.Tag) + resultTag = tag } err = tx.Commit(ctx) diff --git a/src/hmndata/snippet_helper.go b/src/hmndata/snippet_helper.go index 99356dea..ecdd3a8f 100644 --- a/src/hmndata/snippet_helper.go +++ b/src/hmndata/snippet_helper.go @@ -44,10 +44,7 @@ func FetchSnippets( if len(q.Tags) > 0 { // Get snippet IDs with this tag, then use that in the main query - type snippetIDRow struct { - SnippetID int `db:"snippet_id"` - } - iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{}, + snippetIDs, err := db.QueryScalar[int](ctx, tx, ` SELECT DISTINCT snippet_id FROM @@ -63,14 +60,11 @@ func FetchSnippets( } // special early-out: no snippets found for these tags at all - if len(iSnippetIDs) == 0 { + if len(snippetIDs) == 0 { return nil, nil } - q.IDs = make([]int, len(iSnippetIDs)) - for i := range iSnippetIDs { - q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID - } + q.IDs = snippetIDs } var qb db.QueryBuilder @@ -125,16 +119,14 @@ func FetchSnippets( DiscordMessage *models.DiscordMessage `db:"discord_message"` } - iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...) + results, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not - snippetIDs := make([]int, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]SnippetAndStuff, len(results)) // allocate extra space because why not + snippetIDs := make([]int, len(results)) + for i, row := range results { result[i] = SnippetAndStuff{ Snippet: row.Snippet, Owner: row.Owner, @@ -150,7 +142,7 @@ func FetchSnippets( SnippetID int `db:"snippet_tags.snippet_id"` Tag *models.Tag `db:"tags"` } - iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{}, + snippetTags, err := db.Query[snippetTagRow](ctx, tx, ` SELECT $columns FROM @@ -170,8 +162,7 @@ func FetchSnippets( for i := range result { resultBySnippetId[result[i].Snippet.ID] = &result[i] } - for _, iSnippetTag := range iSnippetTags { - snippetTag := iSnippetTag.(*snippetTagRow) + for _, snippetTag := range snippetTags { item := resultBySnippetId[snippetTag.SnippetID] item.Tags = append(item.Tags, snippetTag.Tag) } diff --git a/src/hmndata/tag_helper.go b/src/hmndata/tag_helper.go index fab5b948..8c0fe63b 100644 --- a/src/hmndata/tag_helper.go +++ b/src/hmndata/tag_helper.go @@ -40,18 +40,11 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...) + tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch tags") } - - res := make([]*models.Tag, len(itags)) - for i, itag := range itags { - tag := itag.(*models.Tag) - res[i] = tag - } - - return res, nil + return tags, nil } func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) { diff --git a/src/hmndata/threads_and_posts_helper.go b/src/hmndata/threads_and_posts_helper.go index a51cbc09..6525e4ed 100644 --- a/src/hmndata/threads_and_posts_helper.go +++ b/src/hmndata/threads_and_posts_helper.go @@ -145,15 +145,13 @@ func FetchThreads( ForumLastReadTime *time.Time `db:"slri.lastread"` } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch threads") } - result := make([]ThreadAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]ThreadAndStuff, len(rows)) + for i, row := range rows { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) { hasRead = true @@ -263,7 +261,7 @@ func CountThreads( ) } - count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return 0, oops.New(err, "failed to fetch count of threads") } @@ -405,15 +403,13 @@ func FetchPosts( qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) } - iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) + rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return nil, oops.New(err, "failed to fetch posts") } - result := make([]PostAndStuff, len(iresults)) - for i, iresult := range iresults { - row := *iresult.(*resultRow) - + result := make([]PostAndStuff, len(rows)) + for i, row := range rows { hasRead := false if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) { hasRead = true @@ -595,7 +591,7 @@ func CountPosts( ) } - count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...) + count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...) if err != nil { return 0, oops.New(err, "failed to count posts") } @@ -608,12 +604,9 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User return true } - type postResult struct { - AuthorID *int `db:"post.author_id"` - } - iresult, err := db.QueryOne(ctx, connOrTx, postResult{}, + authorID, err := db.QueryOneScalar[*int](ctx, connOrTx, ` - SELECT $columns + SELECT post.author_id FROM handmade_post AS post WHERE @@ -629,9 +622,8 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User panic(oops.New(err, "failed to get author of post when checking permissions")) } } - result := iresult.(*postResult) - return result.AuthorID != nil && *result.AuthorID == user.ID + return authorID != nil && *authorID == user.ID } func CreateNewPost( @@ -709,7 +701,7 @@ func DeletePost( FirstPostID int `db:"first_id"` Deleted bool `db:"deleted"` } - ti, err := db.QueryOne(ctx, tx, threadInfo{}, + info, err := db.QueryOne[threadInfo](ctx, tx, ` SELECT $columns FROM @@ -722,7 +714,6 @@ func DeletePost( if err != nil { panic(oops.New(err, "failed to fetch thread info")) } - info := ti.(*threadInfo) if info.Deleted { return true } @@ -848,12 +839,9 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte keys = append(keys, key) } - type assetId struct { - AssetID uuid.UUID `db:"id"` - } - assetResult, err := db.Query(ctx, tx, assetId{}, + assetIDs, err := db.QueryScalar[uuid.UUID](ctx, tx, ` - SELECT $columns + SELECT id FROM handmade_asset WHERE s3_key = ANY($1) `, @@ -865,8 +853,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte var values [][]interface{} - for _, asset := range assetResult { - values = append(values, []interface{}{postId, asset.(*assetId).AssetID}) + for _, assetID := range assetIDs { + values = append(values, []interface{}{postId, assetID}) } _, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values)) @@ -886,7 +874,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more. You should probably mark the thread as deleted in this case. */ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { - postsIter, err := db.Query(ctx, tx, models.Post{}, + posts, err := db.Query[models.Post](ctx, tx, ` SELECT $columns FROM handmade_post @@ -901,9 +889,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { } var firstPost, lastPost *models.Post - for _, ipost := range postsIter { - post := ipost.(*models.Post) - + for _, post := range posts { if firstPost == nil || post.PostDate.Before(firstPost.PostDate) { firstPost = post } diff --git a/src/hmndata/twitch.go b/src/hmndata/twitch.go index 3ad4d969..f52bc5f7 100644 --- a/src/hmndata/twitch.go +++ b/src/hmndata/twitch.go @@ -22,12 +22,9 @@ type TwitchStreamer struct { var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P[^/]+)$`) func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) { - type linkResult struct { - Link models.Link `db:"link"` - } - streamers, err := db.Query(ctx, dbConn, linkResult{}, + dbStreamers, err := db.Query[models.Link](ctx, dbConn, ` - SELECT $columns + SELECT $columns{link} FROM handmade_links AS link LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id @@ -49,10 +46,8 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre return nil, oops.New(err, "failed to fetch twitch links") } - result := make([]TwitchStreamer, 0, len(streamers)) - for _, s := range streamers { - dbStreamer := s.(*linkResult).Link - + result := make([]TwitchStreamer, 0, len(dbStreamers)) + for _, dbStreamer := range dbStreamers { streamer := TwitchStreamer{ UserID: dbStreamer.UserID, ProjectID: dbStreamer.ProjectID, @@ -81,7 +76,7 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre } func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) { - links, err := db.Query(ctx, dbConn, models.Link{}, + links, err := db.Query[models.Link](ctx, dbConn, ` SELECT $columns FROM @@ -100,8 +95,7 @@ func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, result := make([]string, 0, len(links)) for _, l := range links { - url := l.(*models.Link).URL - match := twitchRegex.FindStringSubmatch(url) + match := twitchRegex.FindStringSubmatch(l.URL) if match != nil { login := strings.ToLower(match[twitchRegex.SubexpIndex("login")]) result = append(result, login) diff --git a/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go b/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go index 552319b0..29168d00 100644 --- a/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go +++ b/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go @@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error { // Port "jam snippets" to use a tag // - jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`) + jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`) if err != nil { return oops.New(err, "failed to create jam tag") } diff --git a/src/models/subforum.go b/src/models/subforum.go index d65346a7..1bed9e2a 100644 --- a/src/models/subforum.go +++ b/src/models/subforum.go @@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum { } func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { - type subforumRow struct { - Subforum Subforum `db:"sf"` - } - rowsSlice, err := db.Query(ctx, conn, subforumRow{}, + subforums, err := db.Query[Subforum](ctx, conn, ` SELECT $columns - FROM - handmade_subforum as sf + FROM handmade_subforum ORDER BY sort, id ASC `, ) @@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { panic(oops.New(err, "failed to fetch subforum tree")) } - sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice)) - for _, row := range rowsSlice { - sf := row.(*subforumRow).Subforum - sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf} + sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums)) + for _, sf := range subforums { + sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf} } for _, node := range sfTreeMap { @@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { } } - for _, row := range rowsSlice { + for _, cat := range subforums { // NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order. - cat := row.(*subforumRow).Subforum node := sfTreeMap[cat.ID] if node.Parent != nil { node.Parent.Children = append(node.Parent.Children, node) diff --git a/src/twitch/twitch.go b/src/twitch/twitch.go index 99279cb3..62eaf967 100644 --- a/src/twitch/twitch.go +++ b/src/twitch/twitch.go @@ -440,7 +440,7 @@ func updateStreamStatusInDB(ctx context.Context, conn db.ConnOrTx, status *strea inserted := false if isStatusRelevant(status) { log.Debug().Msg("Status relevant") - _, err := db.QueryOne(ctx, conn, models.TwitchStream{}, + _, err := db.QueryOne[models.TwitchStream](ctx, conn, ` SELECT $columns FROM twitch_streams diff --git a/src/website/admin.go b/src/website/admin.go index 0bdacd62..7133dd48 100644 --- a/src/website/admin.go +++ b/src/website/admin.go @@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { userIds = append(userIds, u.User.ID) } - userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, + userLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) } - for _, ul := range userLinks { - link := ul.(*models.Link) + for _, link := range userLinks { userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) } @@ -257,12 +256,9 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { return RejectRequest(c, "User id can't be parsed") } - type userQuery struct { - User models.User `db:"auth_user"` - } - u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE auth_user.id = $1 @@ -276,7 +272,6 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) } } - user := u.(*userQuery).User whatHappened := "" if action == ApprovalQueueActionApprove { @@ -337,7 +332,7 @@ type UnapprovedPost struct { } func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { - it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, + posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn, ` SELECT $columns FROM @@ -358,11 +353,7 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved posts") } - var res []*UnapprovedPost - for _, iresult := range it { - res = append(res, iresult.(*UnapprovedPost)) - } - return res, nil + return posts, nil } type UnapprovedProject struct { @@ -372,12 +363,9 @@ type UnapprovedProject struct { } func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { - type unapprovedUser struct { - ID int `db:"id"` - } - it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, + ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn, ` - SELECT $columns + SELECT id FROM auth_user AS u WHERE @@ -388,10 +376,6 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved users") } - ownerIDs := make([]int, 0, len(it)) - for _, uid := range it { - ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) - } projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ OwnerIDs: ownerIDs, @@ -406,7 +390,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { projectIDs = append(projectIDs, p.Project.ID) } - projectLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -425,8 +409,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { for idx, proj := range projects { links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects. - for _, l := range projectLinks { - link := l.(*models.Link) + for _, link := range projectLinks { if *link.ProjectID == proj.Project.ID { links = append(links, link) } @@ -455,7 +438,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int) ThreadID int `db:"thread.id"` PostID int `db:"post.id"` } - it, err := db.Query(ctx, tx, toDelete{}, + rows, err := db.Query[toDelete](ctx, tx, ` SELECT $columns FROM @@ -471,8 +454,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int) return oops.New(err, "failed to fetch posts to delete for user") } - for _, iResult := range it { - row := iResult.(*toDelete) + for _, row := range rows { hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID) } err = tx.Commit(ctx) @@ -489,9 +471,9 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in } defer tx.Rollback(ctx) - toDelete, err := db.Query(ctx, tx, models.Project{}, + projectIDsToDelete, err := db.QueryScalar[int](ctx, tx, ` - SELECT $columns + SELECT project.id FROM handmade_project AS project JOIN handmade_user_projects AS up ON up.project_id = project.id @@ -504,17 +486,12 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in return oops.New(err, "failed to fetch user's projects") } - var projectIds []int - for _, p := range toDelete { - projectIds = append(projectIds, p.(*models.Project).ID) - } - - if len(projectIds) > 0 { + if len(projectIDsToDelete) > 0 { _, err = tx.Exec(ctx, ` DELETE FROM handmade_project WHERE id = ANY($1) `, - projectIds, + projectIDsToDelete, ) if err != nil { return oops.New(err, "failed to delete user's projects") diff --git a/src/website/api.go b/src/website/api.go index f1c5b33d..24627824 100644 --- a/src/website/api.go +++ b/src/website/api.go @@ -19,12 +19,9 @@ func APICheckUsername(c *RequestContext) ResponseData { requestedUsername := usernameArgs[0] found = true c.Perf.StartBlock("SQL", "Fetch user") - type userQuery struct { - User models.User `db:"auth_user"` - } - userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id @@ -43,7 +40,7 @@ func APICheckUsername(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) } } else { - canonicalUsername = userResult.(*userQuery).User.Username + canonicalUsername = user.Username } } diff --git a/src/website/auth.go b/src/website/auth.go index 991f4034..4dedbbaf 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -75,13 +75,11 @@ func Login(c *RequestContext) ResponseData { return res } - type userQuery struct { - User models.User `db:"auth_user"` - } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user + SELECT $columns{auth_user} + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = LOWER($1) `, @@ -94,7 +92,6 @@ func Login(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } - user := &userRow.(*userQuery).User success, err := tryLogin(c, user, password) @@ -174,7 +171,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { c.Perf.StartBlock("SQL", "Check for existing usernames and emails") userAlreadyExists := true - _, err := db.QueryInt(c.Context(), c.Conn, + _, err := db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT id FROM auth_user @@ -195,7 +192,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData { } emailAlreadyExists := true - _, err = db.QueryInt(c.Context(), c.Conn, + _, err = db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT id FROM auth_user @@ -454,17 +451,16 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return RejectRequest(c, "You must provide a username and an email address.") } - var user *models.User - c.Perf.StartBlock("SQL", "Fetching user") type userQuery struct { User models.User `db:"auth_user"` } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = LOWER($1) AND LOWER(email) = LOWER($2) @@ -478,13 +474,10 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } - if userRow != nil { - user = &userRow.(*userQuery).User - } if user != nil { c.Perf.StartBlock("SQL", "Fetching existing token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` SELECT $columns FROM handmade_onetimetoken @@ -501,10 +494,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) } } - var resetToken *models.OneTimeToken - if tokenRow != nil { - resetToken = tokenRow.(*models.OneTimeToken) - } now := time.Now() if resetToken != nil { @@ -527,7 +516,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if resetToken == nil { c.Perf.StartBlock("SQL", "Creating new token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) VALUES ($1, $2, $3, $4, $5) @@ -543,7 +532,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) } - resetToken = tokenRow.(*models.OneTimeToken) + resetToken = newToken err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) if err != nil { @@ -787,7 +776,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, User models.User `db:"auth_user"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"` } - row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, + data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, ` SELECT $columns FROM auth_user @@ -807,8 +796,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, return result } } - if row != nil { - data := row.(*userAndTokenQuery) + if data != nil { result.User = &data.User result.OneTimeToken = data.OneTimeToken if result.OneTimeToken != nil { diff --git a/src/website/blogs.go b/src/website/blogs.go index 70d0af31..0ff0a9e2 100644 --- a/src/website/blogs.go +++ b/src/website/blogs.go @@ -558,7 +558,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.ThreadID = threadId c.Perf.StartBlock("SQL", "Verify that the thread exists") - threadExists, err := db.QueryBool(c.Context(), c.Conn, + threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, ` SELECT COUNT(*) > 0 FROM handmade_thread @@ -586,7 +586,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) { res.PostID = postId c.Perf.StartBlock("SQL", "Verify that the post exists") - postExists, err := db.QueryBool(c.Context(), c.Conn, + postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn, ` SELECT COUNT(*) > 0 FROM handmade_post diff --git a/src/website/discord.go b/src/website/discord.go index 8554da9c..ae0911f1 100644 --- a/src/website/discord.go +++ b/src/website/discord.go @@ -104,7 +104,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } defer tx.Rollback(c.Context()) - iDiscordUser, err := db.QueryOne(c.Context(), tx, models.DiscordUser{}, + discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx, ` SELECT $columns FROM handmade_discorduser @@ -119,7 +119,6 @@ func DiscordUnlink(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink")) } } - discordUser := iDiscordUser.(*models.DiscordUser) _, err = tx.Exec(c.Context(), ` @@ -146,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData { } func DiscordShowcaseBacklog(c *RequestContext) ResponseData { - iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{}, + duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, `SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`, c.CurrentUser.ID, ) @@ -157,14 +156,10 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user")) } - duser := iduser.(*models.DiscordUser) - type messageIdQuery struct { - MessageID string `db:"msg.id"` - } - iMsgIDs, err := db.Query(c.Context(), c.Conn, messageIdQuery{}, + msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn, ` - SELECT $columns + SELECT msg.id FROM handmade_discordmessage AS msg WHERE @@ -178,10 +173,6 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, err) } - var msgIDs []string - for _, imsgId := range iMsgIDs { - msgIDs = append(msgIDs, imsgId.(*messageIdQuery).MessageID) - } for _, msgID := range msgIDs { interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID) if err != nil && !errors.Is(err, db.NotFound) { diff --git a/src/website/forums.go b/src/website/forums.go index ea6b8bbe..6b603901 100644 --- a/src/website/forums.go +++ b/src/website/forums.go @@ -936,7 +936,7 @@ func addForumUrlsToPost(urlContext *hmnurl.UrlContext, p *templates.Post, subfor // Takes a template post and adds information about how many posts the user has made // on the site. func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) { - numPosts, err := db.QueryInt(ctx, conn, + numPosts, err := db.QueryOneScalar[int](ctx, conn, ` SELECT COUNT(*) FROM @@ -956,7 +956,7 @@ func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.P p.AuthorNumPosts = numPosts } - numProjects, err := db.QueryInt(ctx, conn, + numProjects, err := db.QueryOneScalar[int](ctx, conn, ` SELECT COUNT(*) FROM diff --git a/src/website/imagefile_helper.go b/src/website/imagefile_helper.go index 7d44bf34..aaf2ae71 100644 --- a/src/website/imagefile_helper.go +++ b/src/website/imagefile_helper.go @@ -89,8 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string, img.Seek(0, io.SeekStart) io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs sha1sum := hasher.Sum(nil) - // TODO(db): Should use insert helper - imageFile, err := db.QueryOne(c.Context(), dbConn, models.ImageFile{}, + imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn, ` INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height) VALUES ($1, $2, $3, $4, $5, $6) @@ -105,7 +104,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string, } return SaveImageFileResult{ - ImageFile: imageFile.(*models.ImageFile), + ImageFile: imageFile, } } diff --git a/src/website/links_helper.go b/src/website/links_helper.go index 93eeef16..156058e7 100644 --- a/src/website/links_helper.go +++ b/src/website/links_helper.go @@ -30,10 +30,9 @@ func ParseLinks(text string) []ParsedLink { return res } -func LinksToText(links []interface{}) string { +func LinksToText(links []*models.Link) string { linksText := "" - for _, l := range links { - link := l.(*models.Link) + for _, link := range links { linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name) } return linksText diff --git a/src/website/podcast.go b/src/website/podcast.go index e6e800e5..c542bcb8 100644 --- a/src/website/podcast.go +++ b/src/website/podcast.go @@ -532,11 +532,12 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG Podcast models.Podcast `db:"podcast"` ImageFilename string `db:"imagefile.file"` } - podcastQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastQuery{}, + podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn, ` SELECT $columns - FROM handmade_podcast AS podcast - LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id + FROM + handmade_podcast AS podcast + LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id WHERE podcast.project_id = $1 `, projectId, @@ -549,18 +550,15 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG return result, oops.New(err, "failed to fetch podcast") } } - podcast := podcastQueryResult.(*podcastQuery).Podcast - podcastImageFilename := podcastQueryResult.(*podcastQuery).ImageFilename + podcast := podcastQueryResult.Podcast + podcastImageFilename := podcastQueryResult.ImageFilename result.Podcast = &podcast result.ImageFile = podcastImageFilename if fetchEpisodes { - type podcastEpisodeQuery struct { - Episode models.PodcastEpisode `db:"episode"` - } if episodeGUID == "" { c.Perf.StartBlock("SQL", "Fetch podcast episodes") - podcastEpisodeQueryResult, err := db.Query(c.Context(), c.Conn, podcastEpisodeQuery{}, + episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn, ` SELECT $columns FROM handmade_podcastepisode AS episode @@ -573,16 +571,14 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG if err != nil { return result, oops.New(err, "failed to fetch podcast episodes") } - for _, episodeRow := range podcastEpisodeQueryResult { - result.Episodes = append(result.Episodes, &episodeRow.(*podcastEpisodeQuery).Episode) - } + result.Episodes = episodes } else { guid, err := uuid.Parse(episodeGUID) if err != nil { return result, err } c.Perf.StartBlock("SQL", "Fetch podcast episode") - podcastEpisodeQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastEpisodeQuery{}, + episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn, ` SELECT $columns FROM handmade_podcastepisode AS episode @@ -599,8 +595,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG return result, oops.New(err, "failed to fetch podcast episode") } } - episode := podcastEpisodeQueryResult.(*podcastEpisodeQuery).Episode - result.Episodes = append(result.Episodes, &episode) + result.Episodes = append(result.Episodes, episode) } } diff --git a/src/website/projects.go b/src/website/projects.go index 84f113c6..bb83508b 100644 --- a/src/website/projects.go +++ b/src/website/projects.go @@ -187,12 +187,9 @@ func ProjectHomepage(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetching screenshots") - type screenshotQuery struct { - Filename string `db:"screenshot.file"` - } - screenshotQueryResult, err := db.Query(c.Context(), c.Conn, screenshotQuery{}, + screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn, ` - SELECT $columns + SELECT screenshot.file FROM handmade_imagefile AS screenshot INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id @@ -207,10 +204,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { c.Perf.EndBlock() c.Perf.StartBlock("SQL", "Fetching project links") - type projectLinkQuery struct { - Link models.Link `db:"link"` - } - projectLinkResult, err := db.Query(c.Context(), c.Conn, projectLinkQuery{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -237,7 +231,7 @@ func ProjectHomepage(c *RequestContext) ResponseData { Thread models.Thread `db:"thread"` Author models.User `db:"author"` } - postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{}, + posts, err := db.Query[postQuery](c.Context(), c.Conn, ` SELECT $columns FROM @@ -318,21 +312,21 @@ func ProjectHomepage(c *RequestContext) ResponseData { } } - for _, screenshot := range screenshotQueryResult { - templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshot.(*screenshotQuery).Filename)) + for _, screenshotFilename := range screenshotFilenames { + templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshotFilename)) } - for _, link := range projectLinkResult { - templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(&link.(*projectLinkQuery).Link)) + for _, link := range projectLinks { + templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(link)) } - for _, post := range postQueryResult { + for _, post := range posts { templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem( c.UrlContext, lineageBuilder, - &post.(*postQuery).Post, - &post.(*postQuery).Thread, - &post.(*postQuery).Author, + &post.Post, + &post.Thread, + &post.Author, c.Theme, )) } @@ -498,7 +492,7 @@ func ProjectEdit(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetching project links") - projectLinkResult, err := db.Query(c.Context(), c.Conn, models.Link{}, + projectLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -525,7 +519,7 @@ func ProjectEdit(c *RequestContext) ResponseData { c.Theme, ) - projectSettings.LinksText = LinksToText(projectLinkResult) + projectSettings.LinksText = LinksToText(projectLinks) var res ResponseData res.MustWriteTemplate("project_edit.html", ProjectEditData{ @@ -822,14 +816,12 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P } } - type userQuery struct { - User models.User `db:"auth_user"` - } - ownerRows, err := db.Query(ctx, tx, userQuery{}, + owners, err := db.Query[models.User](ctx, tx, ` - SELECT $columns - FROM auth_user - LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id + SELECT $columns{auth_user} + FROM + auth_user + LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = ANY ($1) `, payload.OwnerUsernames, @@ -849,7 +841,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P return oops.New(err, "Failed to delete project owners") } - for _, ownerRow := range ownerRows { + for _, owner := range owners { _, err = tx.Exec(ctx, ` INSERT INTO handmade_user_projects @@ -857,7 +849,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P VALUES ($1, $2) `, - ownerRow.(*userQuery).User.ID, + owner.ID, payload.ProjectID, ) if err != nil { diff --git a/src/website/routes.go b/src/website/routes.go index 7d890991..8b7ec30f 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -548,13 +548,11 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User } } - type userQuery struct { - User models.User `db:"auth_user"` - } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns - FROM auth_user + SELECT $columns{auth_user} + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE username = $1 `, @@ -568,7 +566,6 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User return nil, nil, oops.New(err, "failed to get user for session") } } - user := &userRow.(*userQuery).User return user, session, nil } diff --git a/src/website/twitch.go b/src/website/twitch.go index e7e0ff77..eda30723 100644 --- a/src/website/twitch.go +++ b/src/website/twitch.go @@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData { } func TwitchDebugPage(c *RequestContext) ResponseData { - streams, err := db.Query(c.Context(), c.Conn, models.TwitchStream{}, + streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn, ` SELECT $columns FROM @@ -83,8 +83,7 @@ func TwitchDebugPage(c *RequestContext) ResponseData { } html := "" - for _, stream := range streams { - s := stream.(*models.TwitchStream) + for _, s := range streams { html += fmt.Sprintf(`%s%s
`, s.Login, s.Login, s.Title) } var res ResponseData diff --git a/src/website/user.go b/src/website/user.go index 395695b7..fa77a7a9 100644 --- a/src/website/user.go +++ b/src/website/user.go @@ -53,12 +53,9 @@ func UserProfile(c *RequestContext) ResponseData { profileUser = c.CurrentUser } else { c.Perf.StartBlock("SQL", "Fetch user") - type userQuery struct { - User models.User `db:"auth_user"` - } - userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + user, err := db.QueryOne[models.User](c.Context(), c.Conn, ` - SELECT $columns + SELECT $columns{auth_user} FROM auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id @@ -75,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username)) } } - profileUser = &userResult.(*userQuery).User + profileUser = user } { @@ -87,10 +84,7 @@ func UserProfile(c *RequestContext) ResponseData { } c.Perf.StartBlock("SQL", "Fetch user links") - type userLinkQuery struct { - UserLink models.Link `db:"link"` - } - userLinksSlice, err := db.Query(c.Context(), c.Conn, userLinkQuery{}, + userLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -104,9 +98,9 @@ func UserProfile(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username)) } - profileUserLinks := make([]templates.Link, 0, len(userLinksSlice)) - for _, l := range userLinksSlice { - profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(&l.(*userLinkQuery).UserLink)) + profileUserLinks := make([]templates.Link, 0, len(userLinks)) + for _, l := range userLinks { + profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(l)) } c.Perf.EndBlock() @@ -231,7 +225,7 @@ func UserSettings(c *RequestContext) ResponseData { DiscordShowcaseBacklogUrl string } - links, err := db.Query(c.Context(), c.Conn, models.Link{}, + links, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM handmade_links @@ -248,7 +242,7 @@ func UserSettings(c *RequestContext) ResponseData { var tduser *templates.DiscordUser var numUnsavedMessages int - iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{}, + duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn, ` SELECT $columns FROM handmade_discorduser @@ -261,11 +255,10 @@ func UserSettings(c *RequestContext) ResponseData { } else if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account")) } else { - duser := iduser.(*models.DiscordUser) tmp := templates.DiscordUserToTemplate(duser) tduser = &tmp - numUnsavedMessages, err = db.QueryInt(c.Context(), c.Conn, + numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn, ` SELECT COUNT(*) FROM