Compare commits

...

2 Commits

Author SHA1 Message Date
Ben Visness bd6c95203c Start improvements to db API
Allowing for querying of scalars with db.Query, and a choice of prefix
in $columns to remove the need for single-field structs in many queries.
2022-03-27 17:20:35 -05:00
Ben Visness e74b3273cb Start converting db stuff to generics 2022-03-15 16:09:11 -05:00
18 changed files with 345 additions and 223 deletions

41
go.mod
View File

@ -1,10 +1,8 @@
module git.handmade.network/hmn/hmn module git.handmade.network/hmn/hmn
go 1.16 go 1.18
require ( require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/Masterminds/sprig v2.22.0+incompatible github.com/Masterminds/sprig v2.22.0+incompatible
github.com/alecthomas/chroma v0.9.2 github.com/alecthomas/chroma v0.9.2
github.com/aws/aws-sdk-go-v2 v1.8.1 github.com/aws/aws-sdk-go-v2 v1.8.1
@ -16,13 +14,10 @@ require (
github.com/go-stack/stack v1.8.0 github.com/go-stack/stack v1.8.0
github.com/google/uuid v1.2.0 github.com/google/uuid v1.2.0
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgconn v1.8.0 github.com/jackc/pgconn v1.8.0
github.com/jackc/pgtype v1.6.2 github.com/jackc/pgtype v1.6.2
github.com/jackc/pgx/v4 v4.10.1 github.com/jackc/pgx/v4 v4.10.1
github.com/jpillora/backoff v1.0.0 github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/rs/zerolog v1.21.0 github.com/rs/zerolog v1.21.0
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
@ -32,9 +27,43 @@ require (
github.com/yuin/goldmark v1.4.1 github.com/yuin/goldmark v1.4.1
github.com/yuin/goldmark-highlighting v0.0.0-20210516132338-9216f9c5aa01 github.com/yuin/goldmark-highlighting v0.0.0-20210516132338-9216f9c5aa01
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
) )
require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.0.6 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/puddle v1.1.3 // indirect
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/mitchellh/reflectwalk v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
replace ( replace (
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9 github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e

7
go.sum
View File

@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
@ -372,6 +370,8 @@ golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7 h1:jynE66seADJbyWMUdeOyVTvPtBZt7L6LJHupGwxPZRM=
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d h1:RNPAfi2nHY7C2srAV8A49jpsYr0ADedCk1wq6fTMTvs= golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d h1:RNPAfi2nHY7C2srAV8A49jpsYr0ADedCk1wq6fTMTvs=
@ -439,8 +439,9 @@ golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
} }
// Fetch and return the new record // Fetch and return the new record
iasset, err := db.QueryOne(ctx, dbConn, models.Asset{}, asset, err := db.QueryOne[models.Asset](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM handmade_asset FROM handmade_asset
@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
return nil, oops.New(err, "failed to fetch newly-created asset") return nil, oops.New(err, "failed to fetch newly-created asset")
} }
return iasset.(*models.Asset), nil return asset, nil
} }

View File

@ -45,7 +45,7 @@ func makeCSRFToken() string {
var ErrNoSession = errors.New("no session found") var ErrNoSession = errors.New("no session found")
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) { func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id) sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id)
if err != nil { if err != nil {
if errors.Is(err, db.NotFound) { if errors.Is(err, db.NotFound) {
return nil, ErrNoSession return nil, ErrNoSession
@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses
return nil, oops.New(err, "failed to get session") return nil, oops.New(err, "failed to get session")
} }
} }
sess := row.(*models.Session)
return sess, nil return sess, nil
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/config"
@ -95,14 +96,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn return conn
} }
type StructQueryIterator struct { type columnName []string
fieldPaths [][]int
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type StructQueryIterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows rows pgx.Rows
destType reflect.Type destType reflect.Type
closed chan struct{} closed chan struct{}
} }
func (it *StructQueryIterator) Next() (interface{}, bool) { func (it *StructQueryIterator[T]) Next() (*T, bool) {
hasNext := it.rows.Next() hasNext := it.rows.Next()
if !hasNext { if !hasNext {
it.Close() it.Close()
@ -172,10 +179,10 @@ func (it *StructQueryIterator) Next() (interface{}, bool) {
currentValue = reflect.Value{} currentValue = reflect.Value{}
} }
return result.Interface(), true return result.Interface().(*T), true
} }
func (it *StructQueryIterator) Close() { func (it *StructQueryIterator[any]) Close() {
it.rows.Close() it.rows.Close()
select { select {
case it.closed <- struct{}{}: case it.closed <- struct{}{}:
@ -183,9 +190,9 @@ func (it *StructQueryIterator) Close() {
} }
} }
func (it *StructQueryIterator) ToSlice() []interface{} { func (it *StructQueryIterator[T]) ToSlice() []*T {
defer it.Close() defer it.Close()
var result []interface{} var result []*T
for { for {
row, ok := it.Next() row, ok := it.Next()
if !ok { if !ok {
@ -231,8 +238,8 @@ func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.V
return val, field return val, field
} }
func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) { func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) {
it, err := QueryIterator(ctx, conn, destExample, query, args...) it, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} else { } else {
@ -240,27 +247,13 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st
} }
} }
func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) { func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
var destExample T
destType := reflect.TypeOf(destExample) destType := reflect.TypeOf(destExample)
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
if err != nil {
return nil, oops.New(err, "failed to generate column names")
}
columns := make([]string, 0, len(columnNames)) compiled := compileQuery(query, destType)
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ") rows, err := conn.Query(ctx, compiled.query, args...)
query = strings.Replace(query, "$columns", columnNamesString, -1)
rows, err := conn.Query(ctx, query, args...)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
panic("query exceeded its deadline") panic("query exceeded its deadline")
@ -268,10 +261,10 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return nil, err return nil, err
} }
it := &StructQueryIterator{ it := &StructQueryIterator[T]{
fieldPaths: fieldPaths, fieldPaths: compiled.fieldPaths,
rows: rows, rows: rows,
destType: destType, destType: compiled.destType,
closed: make(chan struct{}, 1), closed: make(chan struct{}, 1),
} }
@ -292,16 +285,70 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return it, nil return it, nil
} }
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { type compiledQuery struct {
var columnNames [][]string query string
var fieldPaths [][]int destType reflect.Type
fieldPaths []fieldPath
}
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
func compileQuery(query string, destType reflect.Type) compiledQuery {
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
hasColumnsPlaceholder := columnsMatch != nil
if hasColumnsPlaceholder {
// The presence of the $columns placeholder means that the destination type
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
}
var prefix []string
prefixText := columnsMatch[2]
if prefixText != "" {
prefix = []string{prefixText}
}
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ")
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
return compiledQuery{
query: query,
destType: destType,
fieldPaths: fieldPaths,
}
} else {
return compiledQuery{
query: query,
destType: destType,
}
}
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
var columnNames []columnName
var fieldPaths []fieldPath
if destType.Kind() == reflect.Ptr { if destType.Kind() == reflect.Ptr {
destType = destType.Elem() destType = destType.Elem()
} }
if destType.Kind() != reflect.Struct { if destType.Kind() != reflect.Struct {
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
} }
type AnonPrefix struct { type AnonPrefix struct {
@ -348,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
columnNames = append(columnNames, fieldColumnNames) columnNames = append(columnNames, fieldColumnNames)
fieldPaths = append(fieldPaths, path) fieldPaths = append(fieldPaths, path)
} else if fieldType.Kind() == reflect.Struct { } else if fieldType.Kind() == reflect.Struct {
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
if err != nil {
return nil, nil, err
}
columnNames = append(columnNames, subCols...) columnNames = append(columnNames, subCols...)
fieldPaths = append(fieldPaths, subPaths...) fieldPaths = append(fieldPaths, subPaths...)
} else { } else {
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
} }
} }
} }
return columnNames, fieldPaths, nil return columnNames, fieldPaths
} }
/* /*
@ -370,8 +414,8 @@ result but find nothing.
*/ */
var NotFound = errors.New("not found") var NotFound = errors.New("not found")
func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) {
rows, err := QueryIterator(ctx, conn, destExample, query, args...) rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -10,52 +10,143 @@ import (
func TestPaths(t *testing.T) { func TestPaths(t *testing.T) {
type CustomInt int type CustomInt int
type S struct { type S2 struct {
I int `db:"I"` B bool `db:"B"` // field 0
PI *int `db:"PI"` PB *bool `db:"PB"` // field 1
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag int NoTag string // field 2
}
type S struct {
I int `db:"I"` // field 0
PI *int `db:"PI"` // field 1
CI CustomInt `db:"CI"` // field 2
PCI *CustomInt `db:"PCI"` // field 3
S2 `db:"S2"` // field 4 (embedded!)
PS2 *S2 `db:"PS2"` // field 5
NoTag int // field 6
} }
type Nested struct { type Nested struct {
S S `db:"S"` S S `db:"S"` // field 0
PS *S `db:"PS"` PS *S `db:"PS"` // field 1
NoTag S NoTag S // field 2
} }
type Embedded struct { type Embedded struct {
NoTag S NoTag S // field 0
Nested Nested // field 1
} }
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
if assert.Nil(t, err) { assert.Equal(t, []columnName{
assert.Equal(t, []string{ {"S", "I"}, {"S", "PI"},
"S.I", "S.PI", {"S", "CI"}, {"S", "PCI"},
"S.CI", "S.PCI", {"S", "S2", "B"}, {"S", "S2", "PB"},
"S.B", "S.PB", {"S", "PS2", "B"}, {"S", "PS2", "PB"},
"PS.I", "PS.PI", {"PS", "I"}, {"PS", "PI"},
"PS.CI", "PS.PCI", {"PS", "CI"}, {"PS", "PCI"},
"PS.B", "PS.PB", {"PS", "S2", "B"}, {"PS", "S2", "PB"},
}, names) {"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
assert.Equal(t, [][]int{ }, names)
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, assert.Equal(t, []fieldPath{
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, {1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
}, paths) {1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
assert.True(t, len(names) == len(paths)) {1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
} {1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
}, paths)
assert.True(t, len(names) == len(paths))
testStruct := Embedded{} testStruct := Embedded{}
for i, path := range paths { for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid()) assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i], field.Name)) assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
} }
} }
func TestCompileQuery(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
})
t.Run("complex structs", func(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag string
}
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
S2 `db:"S2"` // embedded!
PS2 *S2 `db:"PS2"`
NoTag int
}
type Nested struct {
S S `db:"S"`
PS *S `db:"PS"`
NoTag S
}
type Dest struct {
NoTag S
Nested
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
})
t.Run("int", func(t *testing.T) {
type Dest int
// There should be no error here because we do not need to extract columns from
// the destination type. There may be errors down the line in value iteration, but
// that is always the case if the Go types don't match the query.
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
})
t.Run("just one table", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
// The prefix is necessary because otherwise we would have to provide a struct with
// a db tag in order to provide the query with the `greeblies.` prefix in the
// final query. This comes up a lot when we do a JOIN to help with a condition, but
// don't actually care about any of the data we joined to.
compiled := compileQuery(
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
reflect.TypeOf(Dest{}),
)
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
})
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
type Dest int
assert.Panics(t, func() {
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
})
})
}
func TestQueryBuilder(t *testing.T) { func TestQueryBuilder(t *testing.T) {
t.Run("happy time", func(t *testing.T) { t.Run("happy time", func(t *testing.T) {
var qb QueryBuilder var qb QueryBuilder

View File

@ -93,7 +93,7 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
type profileResult struct { type profileResult struct {
HMNUser models.User `db:"auth_user"` HMNUser models.User `db:"auth_user"`
} }
ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{}, res, err := db.QueryOne[profileResult](ctx, bot.dbConn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -122,7 +122,6 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
} }
return return
} }
res := ires.(*profileResult)
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{ projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
OwnerIDs: []int{res.HMNUser.ID}, OwnerIDs: []int{res.HMNUser.ID},

View File

@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error {
// an old one or starting a new one. // an old one or starting a new one.
shouldResume := true shouldResume := true
isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`) session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`)
if err != nil { if err != nil {
if errors.Is(err, db.NotFound) { if errors.Is(err, db.NotFound) {
// No session yet! Just identify and get on with it // No session yet! Just identify and get on with it
@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error {
if shouldResume { if shouldResume {
// Reconnect to the previous session // Reconnect to the previous session
session := isession.(*models.DiscordSession)
err := bot.sendGatewayMessage(ctx, GatewayMessage{ err := bot.sendGatewayMessage(ctx, GatewayMessage{
Opcode: OpcodeResume, Opcode: OpcodeResume,
Data: Resume{ Data: Resume{
@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
} }
defer tx.Rollback(ctx) defer tx.Rollback(ctx)
msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, ` msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, `
SELECT $columns SELECT $columns
FROM discord_outgoingmessages FROM discord_outgoingmessages
ORDER BY id ASC ORDER BY id ASC
@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
return return
} }
for _, imsg := range msgs { for _, msg := range msgs {
msg := imsg.(*models.DiscordOutgoingMessage)
if time.Now().After(msg.ExpiresAt) { if time.Now().After(msg.ExpiresAt) {
continue continue
} }

View File

@ -76,7 +76,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
type query struct { type query struct {
Message models.DiscordMessage `db:"msg"` Message models.DiscordMessage `db:"msg"`
} }
imessagesWithoutContent, err := db.Query(ctx, dbConn, query{}, messagesWithoutContent, err := db.Query[query](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -95,10 +95,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
return return
} }
if len(imessagesWithoutContent) > 0 { if len(messagesWithoutContent) > 0 {
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent)) log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent))
msgloop: msgloop:
for _, imsg := range imessagesWithoutContent { for _, msgRow := range messagesWithoutContent {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Info().Msg("Scrape was canceled") log.Info().Msg("Scrape was canceled")
@ -106,7 +106,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
default: default:
} }
msg := imsg.(*query).Message msg := msgRow.Message
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID) discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
if errors.Is(err, NotFound) { if errors.Is(err, NotFound) {

View File

@ -165,7 +165,7 @@ func InternMessage(
dbConn db.ConnOrTx, dbConn db.ConnOrTx,
msg *Message, msg *Message,
) error { ) error {
_, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{}, _, err := db.QueryOne[models.DiscordMessage](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessage FROM handmade_discordmessage
@ -219,7 +219,7 @@ type InternedMessage struct {
} }
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) { func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
result, err := db.QueryOne(ctx, dbConn, InternedMessage{}, interned, err := db.QueryOne[InternedMessage](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -236,7 +236,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string)
return nil, err return nil, err
} }
interned := result.(*InternedMessage)
return interned, nil return interned, nil
} }
@ -283,7 +282,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message
} }
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error { func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{}, snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM handmade_snippet FROM handmade_snippet
@ -294,10 +293,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In
if err != nil && !errors.Is(err, db.NotFound) { if err != nil && !errors.Is(err, db.NotFound) {
return oops.New(err, "failed to fetch snippet for discord message") return oops.New(err, "failed to fetch snippet for discord message")
} }
var snippet *models.Snippet
if !errors.Is(err, db.NotFound) {
snippet = isnippet.(*models.Snippet)
}
// NOTE(asaf): Also deletes the following through a db cascade: // NOTE(asaf): Also deletes the following through a db cascade:
// * handmade_discordmessageattachment // * handmade_discordmessageattachment
@ -367,7 +362,7 @@ func SaveMessageContents(
return oops.New(err, "failed to create or update message contents") return oops.New(err, "failed to create or update message contents")
} }
icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{}, content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -380,7 +375,7 @@ func SaveMessageContents(
if err != nil { if err != nil {
return oops.New(err, "failed to fetch message contents") return oops.New(err, "failed to fetch message contents")
} }
interned.MessageContent = icontent.(*models.DiscordMessageContent) interned.MessageContent = content
} }
// Save attachments // Save attachments
@ -472,7 +467,7 @@ func saveAttachment(
hmnUserID int, hmnUserID int,
discordMessageID string, discordMessageID string,
) (*models.DiscordMessageAttachment, error) { ) (*models.DiscordMessageAttachment, error) {
iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessageattachment FROM handmade_discordmessageattachment
@ -481,7 +476,7 @@ func saveAttachment(
attachment.ID, attachment.ID,
) )
if err == nil { if err == nil {
return iexisting.(*models.DiscordMessageAttachment), nil return existing, nil
} else if errors.Is(err, db.NotFound) { } else if errors.Is(err, db.NotFound) {
// this is fine, just create it // this is fine, just create it
} else { } else {
@ -534,7 +529,7 @@ func saveAttachment(
return nil, oops.New(err, "failed to save Discord attachment data") return nil, oops.New(err, "failed to save Discord attachment data")
} }
iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{}, discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessageattachment FROM handmade_discordmessageattachment
@ -546,7 +541,7 @@ func saveAttachment(
return nil, oops.New(err, "failed to fetch new Discord attachment data") return nil, oops.New(err, "failed to fetch new Discord attachment data")
} }
return iDiscordAttachment.(*models.DiscordMessageAttachment), nil return discordAttachment, nil
} }
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it // Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
@ -636,7 +631,7 @@ func saveEmbed(
return nil, oops.New(err, "failed to insert new embed") return nil, oops.New(err, "failed to insert new embed")
} }
iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{}, discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessageembed FROM handmade_discordmessageembed
@ -648,11 +643,11 @@ func saveEmbed(
return nil, oops.New(err, "failed to fetch new Discord embed data") return nil, oops.New(err, "failed to fetch new Discord embed data")
} }
return iDiscordEmbed.(*models.DiscordMessageEmbed), nil return discordEmbed, nil
} }
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) { func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{}, snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
` `
SELECT $columns SELECT $columns
FROM handmade_snippet FROM handmade_snippet
@ -669,7 +664,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin
} }
} }
return iresult.(*models.Snippet), nil return snippet, nil
} }
/* /*
@ -808,7 +803,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
type tagsRow struct { type tagsRow struct {
Tag models.Tag `db:"tags"` Tag models.Tag `db:"tags"`
} }
iUserTags, err := db.Query(ctx, tx, tagsRow{}, userTags, err := db.Query[tagsRow](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -823,8 +818,8 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
return oops.New(err, "failed to fetch tags for user projects") return oops.New(err, "failed to fetch tags for user projects")
} }
for _, itag := range iUserTags { for _, userTag := range userTags {
tag := itag.(*tagsRow).Tag tag := userTag.Tag
allTags = append(allTags, tag.ID) allTags = append(allTags, tag.ID)
for _, messageTag := range messageTags { for _, messageTag := range messageTags {
if strings.EqualFold(tag.Text, messageTag) { if strings.EqualFold(tag.Text, messageTag) {
@ -890,7 +885,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) { func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
// Check attachments // Check attachments
attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{}, attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessageattachment FROM handmade_discordmessageattachment
@ -901,13 +896,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
if err != nil { if err != nil {
return nil, nil, oops.New(err, "failed to fetch message attachments") return nil, nil, oops.New(err, "failed to fetch message attachments")
} }
for _, iattachment := range attachments { for _, attachment := range attachments {
attachment := iattachment.(*models.DiscordMessageAttachment)
return &attachment.AssetID, nil, nil return &attachment.AssetID, nil, nil
} }
// Check embeds // Check embeds
embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{}, embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_discordmessageembed FROM handmade_discordmessageembed
@ -918,8 +912,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
if err != nil { if err != nil {
return nil, nil, oops.New(err, "failed to fetch discord embeds") return nil, nil, oops.New(err, "failed to fetch discord embeds")
} }
for _, iembed := range embeds { for _, embed := range embeds {
embed := iembed.(*models.DiscordMessageEmbed)
if embed.VideoID != nil { if embed.VideoID != nil {
return embed.VideoID, nil, nil return embed.VideoID, nil, nil
} else if embed.ImageID != nil { } else if embed.ImageID != nil {

View File

@ -140,15 +140,15 @@ func FetchProjects(
} }
// Do the query // Do the query
iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...) projects, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch projects") return nil, oops.New(err, "failed to fetch projects")
} }
// Fetch project owners to do permission checks // Fetch project owners to do permission checks
projectIds := make([]int, len(iprojects)) projectIds := make([]int, len(projects))
for i, iproject := range iprojects { for i, projectRow := range projects {
projectIds[i] = iproject.(*projectRow).Project.ID projectIds[i] = projectRow.Project.ID
} }
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds) projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
if err != nil { if err != nil {
@ -156,8 +156,7 @@ func FetchProjects(
} }
var res []ProjectAndStuff var res []ProjectAndStuff
for i, iproject := range iprojects { for i, row := range projects {
row := iproject.(*projectRow)
owners := projectOwners[i].Owners owners := projectOwners[i].Owners
/* /*
@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners(
UserID int `db:"user_id"` UserID int `db:"user_id"`
ProjectID int `db:"project_id"` ProjectID int `db:"project_id"`
} }
iuserprojects, err := db.Query(ctx, tx, userProject{}, userProjects, err := db.Query[userProject](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_user_projects FROM handmade_user_projects
@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners(
// Get the unique user IDs from this set and fetch the users from the db // Get the unique user IDs from this set and fetch the users from the db
var userIds []int var userIds []int
for _, iuserproject := range iuserprojects { for _, userProject := range userProjects {
userProject := iuserproject.(*userProject)
addUserId := true addUserId := true
for _, uid := range userIds { for _, uid := range userIds {
if uid == userProject.UserID { if uid == userProject.UserID {
@ -364,7 +361,7 @@ func FetchMultipleProjectsOwners(
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
iusers, err := db.Query(ctx, tx, userQuery{}, projectUsers, err := db.Query[userQuery](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM auth_user
@ -383,9 +380,7 @@ func FetchMultipleProjectsOwners(
for i, pid := range projectIds { for i, pid := range projectIds {
res[i] = ProjectOwners{ProjectID: pid} res[i] = ProjectOwners{ProjectID: pid}
} }
for _, iuserproject := range iuserprojects { for _, userProject := range userProjects {
userProject := iuserproject.(*userProject)
// Get a pointer to the existing record in the result // Get a pointer to the existing record in the result
var projectOwners *ProjectOwners var projectOwners *ProjectOwners
for i := range res { for i := range res {
@ -396,8 +391,8 @@ func FetchMultipleProjectsOwners(
// Get the full user record we fetched // Get the full user record we fetched
var user *models.User var user *models.User
for _, iuser := range iusers { for _, projectUser := range projectUsers {
u := iuser.(*userQuery).User u := projectUser.User
if u.ID == userProject.UserID { if u.ID == userProject.UserID {
user = &u user = &u
} }
@ -473,7 +468,7 @@ func SetProjectTag(
resultTag = p.Tag resultTag = p.Tag
} else if p.Project.TagID == nil { } else if p.Project.TagID == nil {
// Create a tag // Create a tag
itag, err := db.QueryOne(ctx, tx, models.Tag{}, tag, err := db.QueryOne[models.Tag](ctx, tx,
` `
INSERT INTO tags (text) VALUES ($1) INSERT INTO tags (text) VALUES ($1)
RETURNING $columns RETURNING $columns
@ -483,7 +478,7 @@ func SetProjectTag(
if err != nil { if err != nil {
return nil, oops.New(err, "failed to create new tag for project") return nil, oops.New(err, "failed to create new tag for project")
} }
resultTag = itag.(*models.Tag) resultTag = tag
// Attach it to the project // Attach it to the project
_, err = tx.Exec(ctx, _, err = tx.Exec(ctx,
@ -499,7 +494,7 @@ func SetProjectTag(
} }
} else { } else {
// Update the text of an existing one // Update the text of an existing one
itag, err := db.QueryOne(ctx, tx, models.Tag{}, tag, err := db.QueryOne[models.Tag](ctx, tx,
` `
UPDATE tags UPDATE tags
SET text = $1 SET text = $1
@ -511,7 +506,7 @@ func SetProjectTag(
if err != nil { if err != nil {
return nil, oops.New(err, "failed to update existing tag") return nil, oops.New(err, "failed to update existing tag")
} }
resultTag = itag.(*models.Tag) resultTag = tag
} }
err = tx.Commit(ctx) err = tx.Commit(ctx)

View File

@ -47,7 +47,7 @@ func FetchSnippets(
type snippetIDRow struct { type snippetIDRow struct {
SnippetID int `db:"snippet_id"` SnippetID int `db:"snippet_id"`
} }
iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{}, snippetIDs, err := db.Query[snippetIDRow](ctx, tx,
` `
SELECT DISTINCT snippet_id SELECT DISTINCT snippet_id
FROM FROM
@ -63,13 +63,13 @@ func FetchSnippets(
} }
// special early-out: no snippets found for these tags at all // special early-out: no snippets found for these tags at all
if len(iSnippetIDs) == 0 { if len(snippetIDs) == 0 {
return nil, nil return nil, nil
} }
q.IDs = make([]int, len(iSnippetIDs)) q.IDs = make([]int, len(snippetIDs))
for i := range iSnippetIDs { for i := range snippetIDs {
q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID q.IDs[i] = snippetIDs[i].SnippetID
} }
} }
@ -125,16 +125,14 @@ func FetchSnippets(
DiscordMessage *models.DiscordMessage `db:"discord_message"` DiscordMessage *models.DiscordMessage `db:"discord_message"`
} }
iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...) rows, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...)
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch threads") return nil, oops.New(err, "failed to fetch threads")
} }
result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not result := make([]SnippetAndStuff, len(rows)) // allocate extra space because why not
snippetIDs := make([]int, len(iresults)) snippetIDs := make([]int, len(rows))
for i, iresult := range iresults { for i, row := range rows {
row := *iresult.(*resultRow)
result[i] = SnippetAndStuff{ result[i] = SnippetAndStuff{
Snippet: row.Snippet, Snippet: row.Snippet,
Owner: row.Owner, Owner: row.Owner,
@ -150,7 +148,7 @@ func FetchSnippets(
SnippetID int `db:"snippet_tags.snippet_id"` SnippetID int `db:"snippet_tags.snippet_id"`
Tag *models.Tag `db:"tags"` Tag *models.Tag `db:"tags"`
} }
iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{}, snippetTags, err := db.Query[snippetTagRow](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -170,8 +168,7 @@ func FetchSnippets(
for i := range result { for i := range result {
resultBySnippetId[result[i].Snippet.ID] = &result[i] resultBySnippetId[result[i].Snippet.ID] = &result[i]
} }
for _, iSnippetTag := range iSnippetTags { for _, snippetTag := range snippetTags {
snippetTag := iSnippetTag.(*snippetTagRow)
item := resultBySnippetId[snippetTag.SnippetID] item := resultBySnippetId[snippetTag.SnippetID]
item.Tags = append(item.Tags, snippetTag.Tag) item.Tags = append(item.Tags, snippetTag.Tag)
} }

View File

@ -40,18 +40,12 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
} }
itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...) tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch tags") return nil, oops.New(err, "failed to fetch tags")
} }
res := make([]*models.Tag, len(itags)) return tags, nil
for i, itag := range itags {
tag := itag.(*models.Tag)
res[i] = tag
}
return res, nil
} }
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) { func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {

View File

@ -145,15 +145,13 @@ func FetchThreads(
ForumLastReadTime *time.Time `db:"slri.lastread"` ForumLastReadTime *time.Time `db:"slri.lastread"`
} }
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch threads") return nil, oops.New(err, "failed to fetch threads")
} }
result := make([]ThreadAndStuff, len(iresults)) result := make([]ThreadAndStuff, len(results))
for i, iresult := range iresults { for i, row := range results {
row := *iresult.(*resultRow)
hasRead := false hasRead := false
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) { if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
hasRead = true hasRead = true
@ -405,15 +403,13 @@ func FetchPosts(
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset) qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
} }
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...) results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch posts") return nil, oops.New(err, "failed to fetch posts")
} }
result := make([]PostAndStuff, len(iresults)) result := make([]PostAndStuff, len(results))
for i, iresult := range iresults { for i, row := range results {
row := *iresult.(*resultRow)
hasRead := false hasRead := false
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) { if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
hasRead = true hasRead = true
@ -611,7 +607,7 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
type postResult struct { type postResult struct {
AuthorID *int `db:"post.author_id"` AuthorID *int `db:"post.author_id"`
} }
iresult, err := db.QueryOne(ctx, connOrTx, postResult{}, result, err := db.QueryOne[postResult](ctx, connOrTx,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -629,7 +625,6 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
panic(oops.New(err, "failed to get author of post when checking permissions")) panic(oops.New(err, "failed to get author of post when checking permissions"))
} }
} }
result := iresult.(*postResult)
return result.AuthorID != nil && *result.AuthorID == user.ID return result.AuthorID != nil && *result.AuthorID == user.ID
} }
@ -709,7 +704,7 @@ func DeletePost(
FirstPostID int `db:"first_id"` FirstPostID int `db:"first_id"`
Deleted bool `db:"deleted"` Deleted bool `db:"deleted"`
} }
ti, err := db.QueryOne(ctx, tx, threadInfo{}, info, err := db.QueryOne[threadInfo](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -722,7 +717,6 @@ func DeletePost(
if err != nil { if err != nil {
panic(oops.New(err, "failed to fetch thread info")) panic(oops.New(err, "failed to fetch thread info"))
} }
info := ti.(*threadInfo)
if info.Deleted { if info.Deleted {
return true return true
} }
@ -851,7 +845,7 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
type assetId struct { type assetId struct {
AssetID uuid.UUID `db:"id"` AssetID uuid.UUID `db:"id"`
} }
assetResult, err := db.Query(ctx, tx, assetId{}, assets, err := db.Query[assetId](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_asset FROM handmade_asset
@ -865,8 +859,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
var values [][]interface{} var values [][]interface{}
for _, asset := range assetResult { for _, asset := range assets {
values = append(values, []interface{}{postId, asset.(*assetId).AssetID}) values = append(values, []interface{}{postId, asset.AssetID})
} }
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values)) _, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
@ -886,7 +880,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more.
You should probably mark the thread as deleted in this case. You should probably mark the thread as deleted in this case.
*/ */
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error { func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
postsIter, err := db.Query(ctx, tx, models.Post{}, posts, err := db.Query[models.Post](ctx, tx,
` `
SELECT $columns SELECT $columns
FROM handmade_post FROM handmade_post
@ -901,9 +895,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
} }
var firstPost, lastPost *models.Post var firstPost, lastPost *models.Post
for _, ipost := range postsIter { for _, post := range posts {
post := ipost.(*models.Post)
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) { if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
firstPost = post firstPost = post
} }

View File

@ -47,7 +47,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
type subforumRow struct { type subforumRow struct {
Subforum Subforum `db:"sf"` Subforum Subforum `db:"sf"`
} }
rowsSlice, err := db.Query(ctx, conn, subforumRow{}, rowsSlice, err := db.Query[subforumRow](ctx, conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -61,7 +61,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice)) sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice))
for _, row := range rowsSlice { for _, row := range rowsSlice {
sf := row.(*subforumRow).Subforum sf := row.Subforum
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf} sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
} }
@ -73,7 +73,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
for _, row := range rowsSlice { for _, row := range rowsSlice {
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order. // NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
cat := row.(*subforumRow).Subforum cat := row.Subforum
node := sfTreeMap[cat.ID] node := sfTreeMap[cat.ID]
if node.Parent != nil { if node.Parent != nil {
node.Parent.Children = append(node.Parent.Children, node) node.Parent.Children = append(node.Parent.Children, node)

View File

@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID) userIds = append(userIds, u.User.ID)
} }
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
} }
for _, ul := range userLinks { for _, link := range userLinks {
link := ul.(*models.Link)
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
} }
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE auth_user.id = $1 WHERE auth_user.id = $1
`, `,
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
} }
} }
user := u.(*userQuery).User user := u.User
whatHappened := "" whatHappened := ""
if action == ApprovalQueueActionApprove { if action == ApprovalQueueActionApprove {
@ -337,7 +337,7 @@ type UnapprovedPost struct {
} }
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch unapproved posts") return nil, oops.New(err, "failed to fetch unapproved posts")
} }
var res []*UnapprovedPost
for _, iresult := range it {
res = append(res, iresult.(*UnapprovedPost))
}
return res, nil return res, nil
} }
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
type unapprovedUser struct { type unapprovedUser struct {
ID int `db:"id"` ID int `db:"id"`
} }
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch unapproved users") return nil, oops.New(err, "failed to fetch unapproved users")
} }
ownerIDs := make([]int, 0, len(it)) ownerIDs := make([]int, 0, len(uids))
for _, uid := range it { for _, uid := range uids {
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) ownerIDs = append(ownerIDs, uid.ID)
} }
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{

View File

@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
} }
} else { } else {
canonicalUsername = userResult.(*userQuery).User.Username canonicalUsername = userResult.User.Username
} }
} }

View File

@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE LOWER(username) = LOWER($1) WHERE LOWER(username) = LOWER($1)
`, `,
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
} }
} }
user := &userRow.(*userQuery).User user := &userRow.User
success, err := tryLogin(c, user, password) success, err := tryLogin(c, user, password)
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM auth_user
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
} }
} }
if userRow != nil { if userRow != nil {
user = &userRow.(*userQuery).User user = &userRow.User
} }
if user != nil { if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token") c.Perf.StartBlock("SQL", "Fetching existing token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM handmade_onetimetoken FROM handmade_onetimetoken
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
} }
} }
var resetToken *models.OneTimeToken
if tokenRow != nil {
resetToken = tokenRow.(*models.OneTimeToken)
}
now := time.Now() now := time.Now()
if resetToken != nil { if resetToken != nil {
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil { if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token") c.Perf.StartBlock("SQL", "Creating new token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
` `
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
} }
resetToken = tokenRow.(*models.OneTimeToken)
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
if err != nil { if err != nil {
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
} }
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM auth_user
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
return result return result
} }
} }
if row != nil { if data != nil {
data := row.(*userAndTokenQuery)
result.User = &data.User result.User = &data.User
result.OneTimeToken = data.OneTimeToken result.OneTimeToken = data.OneTimeToken
if result.OneTimeToken != nil { if result.OneTimeToken != nil {