Compare commits

...

2 Commits

Author SHA1 Message Date
Ben Visness bd6c95203c Start improvements to db API
Allowing for querying of scalars with db.Query, and a choice of prefix
in $columns to remove the need for single-field structs in many queries.
2022-03-27 17:20:35 -05:00
Ben Visness e74b3273cb Start converting db stuff to generics 2022-03-15 16:09:11 -05:00
18 changed files with 345 additions and 223 deletions

41
go.mod
View File

@ -1,10 +1,8 @@
module git.handmade.network/hmn/hmn
go 1.16
go 1.18
require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/Masterminds/sprig v2.22.0+incompatible
github.com/alecthomas/chroma v0.9.2
github.com/aws/aws-sdk-go-v2 v1.8.1
@ -16,13 +14,10 @@ require (
github.com/go-stack/stack v1.8.0
github.com/google/uuid v1.2.0
github.com/gorilla/websocket v1.4.2
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgconn v1.8.0
github.com/jackc/pgtype v1.6.2
github.com/jackc/pgx/v4 v4.10.1
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/rs/zerolog v1.21.0
github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.7.0
@ -32,9 +27,43 @@ require (
github.com/yuin/goldmark v1.4.1
github.com/yuin/goldmark-highlighting v0.0.0-20210516132338-9216f9c5aa01
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
)
require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.0.6 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/puddle v1.1.3 // indirect
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/mitchellh/reflectwalk v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
replace (
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e

7
go.sum
View File

@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
@ -372,6 +370,8 @@ golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY=
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7 h1:jynE66seADJbyWMUdeOyVTvPtBZt7L6LJHupGwxPZRM=
golang.org/x/exp v0.0.0-20220314205449-43aec2f8a4e7/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d h1:RNPAfi2nHY7C2srAV8A49jpsYr0ADedCk1wq6fTMTvs=
@ -439,8 +439,9 @@ golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
}
// Fetch and return the new record
iasset, err := db.QueryOne(ctx, dbConn, models.Asset{},
asset, err := db.QueryOne[models.Asset](ctx, dbConn,
`
SELECT $columns
FROM handmade_asset
@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
return nil, oops.New(err, "failed to fetch newly-created asset")
}
return iasset.(*models.Asset), nil
return asset, nil
}

View File

@ -45,7 +45,7 @@ func makeCSRFToken() string {
var ErrNoSession = errors.New("no session found")
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id)
sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id)
if err != nil {
if errors.Is(err, db.NotFound) {
return nil, ErrNoSession
@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses
return nil, oops.New(err, "failed to get session")
}
}
sess := row.(*models.Session)
return sess, nil
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"git.handmade.network/hmn/hmn/src/config"
@ -95,14 +96,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn
}
type StructQueryIterator struct {
fieldPaths [][]int
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type StructQueryIterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
closed chan struct{}
}
func (it *StructQueryIterator) Next() (interface{}, bool) {
func (it *StructQueryIterator[T]) Next() (*T, bool) {
hasNext := it.rows.Next()
if !hasNext {
it.Close()
@ -172,10 +179,10 @@ func (it *StructQueryIterator) Next() (interface{}, bool) {
currentValue = reflect.Value{}
}
return result.Interface(), true
return result.Interface().(*T), true
}
func (it *StructQueryIterator) Close() {
func (it *StructQueryIterator[any]) Close() {
it.rows.Close()
select {
case it.closed <- struct{}{}:
@ -183,9 +190,9 @@ func (it *StructQueryIterator) Close() {
}
}
func (it *StructQueryIterator) ToSlice() []interface{} {
func (it *StructQueryIterator[T]) ToSlice() []*T {
defer it.Close()
var result []interface{}
var result []*T
for {
row, ok := it.Next()
if !ok {
@ -231,8 +238,8 @@ func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.V
return val, field
}
func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) {
it, err := QueryIterator(ctx, conn, destExample, query, args...)
func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) {
it, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
} else {
@ -240,27 +247,13 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st
}
}
func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) {
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
var destExample T
destType := reflect.TypeOf(destExample)
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
if err != nil {
return nil, oops.New(err, "failed to generate column names")
}
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
compiled := compileQuery(query, destType)
columnNamesString := strings.Join(columns, ", ")
query = strings.Replace(query, "$columns", columnNamesString, -1)
rows, err := conn.Query(ctx, query, args...)
rows, err := conn.Query(ctx, compiled.query, args...)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
panic("query exceeded its deadline")
@ -268,10 +261,10 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return nil, err
}
it := &StructQueryIterator{
fieldPaths: fieldPaths,
it := &StructQueryIterator[T]{
fieldPaths: compiled.fieldPaths,
rows: rows,
destType: destType,
destType: compiled.destType,
closed: make(chan struct{}, 1),
}
@ -292,16 +285,70 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return it, nil
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
var columnNames [][]string
var fieldPaths [][]int
type compiledQuery struct {
query string
destType reflect.Type
fieldPaths []fieldPath
}
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
func compileQuery(query string, destType reflect.Type) compiledQuery {
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
hasColumnsPlaceholder := columnsMatch != nil
if hasColumnsPlaceholder {
// The presence of the $columns placeholder means that the destination type
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
}
var prefix []string
prefixText := columnsMatch[2]
if prefixText != "" {
prefix = []string{prefixText}
}
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ")
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
return compiledQuery{
query: query,
destType: destType,
fieldPaths: fieldPaths,
}
} else {
return compiledQuery{
query: query,
destType: destType,
}
}
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
var columnNames []columnName
var fieldPaths []fieldPath
if destType.Kind() == reflect.Ptr {
destType = destType.Elem()
}
if destType.Kind() != reflect.Struct {
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
}
type AnonPrefix struct {
@ -348,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
columnNames = append(columnNames, fieldColumnNames)
fieldPaths = append(fieldPaths, path)
} else if fieldType.Kind() == reflect.Struct {
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
if err != nil {
return nil, nil, err
}
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
columnNames = append(columnNames, subCols...)
fieldPaths = append(fieldPaths, subPaths...)
} else {
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
}
}
}
return columnNames, fieldPaths, nil
return columnNames, fieldPaths
}
/*
@ -370,8 +414,8 @@ result but find nothing.
*/
var NotFound = errors.New("not found")
func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) {
rows, err := QueryIterator(ctx, conn, destExample, query, args...)
func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}

View File

@ -10,13 +10,90 @@ import (
func TestPaths(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"` // field 0
PB *bool `db:"PB"` // field 1
NoTag string // field 2
}
type S struct {
I int `db:"I"` // field 0
PI *int `db:"PI"` // field 1
CI CustomInt `db:"CI"` // field 2
PCI *CustomInt `db:"PCI"` // field 3
S2 `db:"S2"` // field 4 (embedded!)
PS2 *S2 `db:"PS2"` // field 5
NoTag int // field 6
}
type Nested struct {
S S `db:"S"` // field 0
PS *S `db:"PS"` // field 1
NoTag S // field 2
}
type Embedded struct {
NoTag S // field 0
Nested // field 1
}
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
assert.Equal(t, []columnName{
{"S", "I"}, {"S", "PI"},
{"S", "CI"}, {"S", "PCI"},
{"S", "S2", "B"}, {"S", "S2", "PB"},
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
{"PS", "I"}, {"PS", "PI"},
{"PS", "CI"}, {"PS", "PCI"},
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
}, names)
assert.Equal(t, []fieldPath{
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
}, paths)
assert.True(t, len(names) == len(paths))
testStruct := Embedded{}
for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
}
}
func TestCompileQuery(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
})
t.Run("complex structs", func(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag string
}
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
B bool `db:"B"`
PB *bool `db:"PB"`
S2 `db:"S2"` // embedded!
PS2 *S2 `db:"PS2"`
NoTag int
}
@ -26,34 +103,48 @@ func TestPaths(t *testing.T) {
NoTag S
}
type Embedded struct {
type Dest struct {
NoTag S
Nested
}
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
if assert.Nil(t, err) {
assert.Equal(t, []string{
"S.I", "S.PI",
"S.CI", "S.PCI",
"S.B", "S.PB",
"PS.I", "PS.PI",
"PS.CI", "PS.PCI",
"PS.B", "PS.PB",
}, names)
assert.Equal(t, [][]int{
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
}, paths)
assert.True(t, len(names) == len(paths))
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
})
t.Run("int", func(t *testing.T) {
type Dest int
// There should be no error here because we do not need to extract columns from
// the destination type. There may be errors down the line in value iteration, but
// that is always the case if the Go types don't match the query.
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
})
t.Run("just one table", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
testStruct := Embedded{}
for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i], field.Name))
}
// The prefix is necessary because otherwise we would have to provide a struct with
// a db tag in order to provide the query with the `greeblies.` prefix in the
// final query. This comes up a lot when we do a JOIN to help with a condition, but
// don't actually care about any of the data we joined to.
compiled := compileQuery(
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
reflect.TypeOf(Dest{}),
)
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
})
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
type Dest int
assert.Panics(t, func() {
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
})
})
}
func TestQueryBuilder(t *testing.T) {

View File

@ -93,7 +93,7 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
type profileResult struct {
HMNUser models.User `db:"auth_user"`
}
ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{},
res, err := db.QueryOne[profileResult](ctx, bot.dbConn,
`
SELECT $columns
FROM
@ -122,7 +122,6 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
}
return
}
res := ires.(*profileResult)
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
OwnerIDs: []int{res.HMNUser.ID},

View File

@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error {
// an old one or starting a new one.
shouldResume := true
isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`)
session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`)
if err != nil {
if errors.Is(err, db.NotFound) {
// No session yet! Just identify and get on with it
@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error {
if shouldResume {
// Reconnect to the previous session
session := isession.(*models.DiscordSession)
err := bot.sendGatewayMessage(ctx, GatewayMessage{
Opcode: OpcodeResume,
Data: Resume{
@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
}
defer tx.Rollback(ctx)
msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, `
msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, `
SELECT $columns
FROM discord_outgoingmessages
ORDER BY id ASC
@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
return
}
for _, imsg := range msgs {
msg := imsg.(*models.DiscordOutgoingMessage)
for _, msg := range msgs {
if time.Now().After(msg.ExpiresAt) {
continue
}

View File

@ -76,7 +76,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
type query struct {
Message models.DiscordMessage `db:"msg"`
}
imessagesWithoutContent, err := db.Query(ctx, dbConn, query{},
messagesWithoutContent, err := db.Query[query](ctx, dbConn,
`
SELECT $columns
FROM
@ -95,10 +95,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
return
}
if len(imessagesWithoutContent) > 0 {
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent))
if len(messagesWithoutContent) > 0 {
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent))
msgloop:
for _, imsg := range imessagesWithoutContent {
for _, msgRow := range messagesWithoutContent {
select {
case <-ctx.Done():
log.Info().Msg("Scrape was canceled")
@ -106,7 +106,7 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
default:
}
msg := imsg.(*query).Message
msg := msgRow.Message
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
if errors.Is(err, NotFound) {

View File

@ -165,7 +165,7 @@ func InternMessage(
dbConn db.ConnOrTx,
msg *Message,
) error {
_, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{},
_, err := db.QueryOne[models.DiscordMessage](ctx, dbConn,
`
SELECT $columns
FROM handmade_discordmessage
@ -219,7 +219,7 @@ type InternedMessage struct {
}
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
result, err := db.QueryOne(ctx, dbConn, InternedMessage{},
interned, err := db.QueryOne[InternedMessage](ctx, dbConn,
`
SELECT $columns
FROM
@ -236,7 +236,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string)
return nil, err
}
interned := result.(*InternedMessage)
return interned, nil
}
@ -283,7 +282,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message
}
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{},
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
`
SELECT $columns
FROM handmade_snippet
@ -294,10 +293,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In
if err != nil && !errors.Is(err, db.NotFound) {
return oops.New(err, "failed to fetch snippet for discord message")
}
var snippet *models.Snippet
if !errors.Is(err, db.NotFound) {
snippet = isnippet.(*models.Snippet)
}
// NOTE(asaf): Also deletes the following through a db cascade:
// * handmade_discordmessageattachment
@ -367,7 +362,7 @@ func SaveMessageContents(
return oops.New(err, "failed to create or update message contents")
}
icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{},
content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn,
`
SELECT $columns
FROM
@ -380,7 +375,7 @@ func SaveMessageContents(
if err != nil {
return oops.New(err, "failed to fetch message contents")
}
interned.MessageContent = icontent.(*models.DiscordMessageContent)
interned.MessageContent = content
}
// Save attachments
@ -472,7 +467,7 @@ func saveAttachment(
hmnUserID int,
discordMessageID string,
) (*models.DiscordMessageAttachment, error) {
iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
`
SELECT $columns
FROM handmade_discordmessageattachment
@ -481,7 +476,7 @@ func saveAttachment(
attachment.ID,
)
if err == nil {
return iexisting.(*models.DiscordMessageAttachment), nil
return existing, nil
} else if errors.Is(err, db.NotFound) {
// this is fine, just create it
} else {
@ -534,7 +529,7 @@ func saveAttachment(
return nil, oops.New(err, "failed to save Discord attachment data")
}
iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
`
SELECT $columns
FROM handmade_discordmessageattachment
@ -546,7 +541,7 @@ func saveAttachment(
return nil, oops.New(err, "failed to fetch new Discord attachment data")
}
return iDiscordAttachment.(*models.DiscordMessageAttachment), nil
return discordAttachment, nil
}
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
@ -636,7 +631,7 @@ func saveEmbed(
return nil, oops.New(err, "failed to insert new embed")
}
iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{},
discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx,
`
SELECT $columns
FROM handmade_discordmessageembed
@ -648,11 +643,11 @@ func saveEmbed(
return nil, oops.New(err, "failed to fetch new Discord embed data")
}
return iDiscordEmbed.(*models.DiscordMessageEmbed), nil
return discordEmbed, nil
}
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{},
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
`
SELECT $columns
FROM handmade_snippet
@ -669,7 +664,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin
}
}
return iresult.(*models.Snippet), nil
return snippet, nil
}
/*
@ -808,7 +803,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
type tagsRow struct {
Tag models.Tag `db:"tags"`
}
iUserTags, err := db.Query(ctx, tx, tagsRow{},
userTags, err := db.Query[tagsRow](ctx, tx,
`
SELECT $columns
FROM
@ -823,8 +818,8 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
return oops.New(err, "failed to fetch tags for user projects")
}
for _, itag := range iUserTags {
tag := itag.(*tagsRow).Tag
for _, userTag := range userTags {
tag := userTag.Tag
allTags = append(allTags, tag.ID)
for _, messageTag := range messageTags {
if strings.EqualFold(tag.Text, messageTag) {
@ -890,7 +885,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
// Check attachments
attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{},
attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx,
`
SELECT $columns
FROM handmade_discordmessageattachment
@ -901,13 +896,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
if err != nil {
return nil, nil, oops.New(err, "failed to fetch message attachments")
}
for _, iattachment := range attachments {
attachment := iattachment.(*models.DiscordMessageAttachment)
for _, attachment := range attachments {
return &attachment.AssetID, nil, nil
}
// Check embeds
embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{},
embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx,
`
SELECT $columns
FROM handmade_discordmessageembed
@ -918,8 +912,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
if err != nil {
return nil, nil, oops.New(err, "failed to fetch discord embeds")
}
for _, iembed := range embeds {
embed := iembed.(*models.DiscordMessageEmbed)
for _, embed := range embeds {
if embed.VideoID != nil {
return embed.VideoID, nil, nil
} else if embed.ImageID != nil {

View File

@ -140,15 +140,15 @@ func FetchProjects(
}
// Do the query
iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...)
projects, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil {
return nil, oops.New(err, "failed to fetch projects")
}
// Fetch project owners to do permission checks
projectIds := make([]int, len(iprojects))
for i, iproject := range iprojects {
projectIds[i] = iproject.(*projectRow).Project.ID
projectIds := make([]int, len(projects))
for i, projectRow := range projects {
projectIds[i] = projectRow.Project.ID
}
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
if err != nil {
@ -156,8 +156,7 @@ func FetchProjects(
}
var res []ProjectAndStuff
for i, iproject := range iprojects {
row := iproject.(*projectRow)
for i, row := range projects {
owners := projectOwners[i].Owners
/*
@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners(
UserID int `db:"user_id"`
ProjectID int `db:"project_id"`
}
iuserprojects, err := db.Query(ctx, tx, userProject{},
userProjects, err := db.Query[userProject](ctx, tx,
`
SELECT $columns
FROM handmade_user_projects
@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners(
// Get the unique user IDs from this set and fetch the users from the db
var userIds []int
for _, iuserproject := range iuserprojects {
userProject := iuserproject.(*userProject)
for _, userProject := range userProjects {
addUserId := true
for _, uid := range userIds {
if uid == userProject.UserID {
@ -364,7 +361,7 @@ func FetchMultipleProjectsOwners(
type userQuery struct {
User models.User `db:"auth_user"`
}
iusers, err := db.Query(ctx, tx, userQuery{},
projectUsers, err := db.Query[userQuery](ctx, tx,
`
SELECT $columns
FROM auth_user
@ -383,9 +380,7 @@ func FetchMultipleProjectsOwners(
for i, pid := range projectIds {
res[i] = ProjectOwners{ProjectID: pid}
}
for _, iuserproject := range iuserprojects {
userProject := iuserproject.(*userProject)
for _, userProject := range userProjects {
// Get a pointer to the existing record in the result
var projectOwners *ProjectOwners
for i := range res {
@ -396,8 +391,8 @@ func FetchMultipleProjectsOwners(
// Get the full user record we fetched
var user *models.User
for _, iuser := range iusers {
u := iuser.(*userQuery).User
for _, projectUser := range projectUsers {
u := projectUser.User
if u.ID == userProject.UserID {
user = &u
}
@ -473,7 +468,7 @@ func SetProjectTag(
resultTag = p.Tag
} else if p.Project.TagID == nil {
// Create a tag
itag, err := db.QueryOne(ctx, tx, models.Tag{},
tag, err := db.QueryOne[models.Tag](ctx, tx,
`
INSERT INTO tags (text) VALUES ($1)
RETURNING $columns
@ -483,7 +478,7 @@ func SetProjectTag(
if err != nil {
return nil, oops.New(err, "failed to create new tag for project")
}
resultTag = itag.(*models.Tag)
resultTag = tag
// Attach it to the project
_, err = tx.Exec(ctx,
@ -499,7 +494,7 @@ func SetProjectTag(
}
} else {
// Update the text of an existing one
itag, err := db.QueryOne(ctx, tx, models.Tag{},
tag, err := db.QueryOne[models.Tag](ctx, tx,
`
UPDATE tags
SET text = $1
@ -511,7 +506,7 @@ func SetProjectTag(
if err != nil {
return nil, oops.New(err, "failed to update existing tag")
}
resultTag = itag.(*models.Tag)
resultTag = tag
}
err = tx.Commit(ctx)

View File

@ -47,7 +47,7 @@ func FetchSnippets(
type snippetIDRow struct {
SnippetID int `db:"snippet_id"`
}
iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{},
snippetIDs, err := db.Query[snippetIDRow](ctx, tx,
`
SELECT DISTINCT snippet_id
FROM
@ -63,13 +63,13 @@ func FetchSnippets(
}
// special early-out: no snippets found for these tags at all
if len(iSnippetIDs) == 0 {
if len(snippetIDs) == 0 {
return nil, nil
}
q.IDs = make([]int, len(iSnippetIDs))
for i := range iSnippetIDs {
q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID
q.IDs = make([]int, len(snippetIDs))
for i := range snippetIDs {
q.IDs[i] = snippetIDs[i].SnippetID
}
}
@ -125,16 +125,14 @@ func FetchSnippets(
DiscordMessage *models.DiscordMessage `db:"discord_message"`
}
iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...)
rows, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...)
if err != nil {
return nil, oops.New(err, "failed to fetch threads")
}
result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not
snippetIDs := make([]int, len(iresults))
for i, iresult := range iresults {
row := *iresult.(*resultRow)
result := make([]SnippetAndStuff, len(rows)) // allocate extra space because why not
snippetIDs := make([]int, len(rows))
for i, row := range rows {
result[i] = SnippetAndStuff{
Snippet: row.Snippet,
Owner: row.Owner,
@ -150,7 +148,7 @@ func FetchSnippets(
SnippetID int `db:"snippet_tags.snippet_id"`
Tag *models.Tag `db:"tags"`
}
iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{},
snippetTags, err := db.Query[snippetTagRow](ctx, tx,
`
SELECT $columns
FROM
@ -170,8 +168,7 @@ func FetchSnippets(
for i := range result {
resultBySnippetId[result[i].Snippet.ID] = &result[i]
}
for _, iSnippetTag := range iSnippetTags {
snippetTag := iSnippetTag.(*snippetTagRow)
for _, snippetTag := range snippetTags {
item := resultBySnippetId[snippetTag.SnippetID]
item.Tags = append(item.Tags, snippetTag.Tag)
}

View File

@ -40,18 +40,12 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
}
itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...)
tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil {
return nil, oops.New(err, "failed to fetch tags")
}
res := make([]*models.Tag, len(itags))
for i, itag := range itags {
tag := itag.(*models.Tag)
res[i] = tag
}
return res, nil
return tags, nil
}
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {

View File

@ -145,15 +145,13 @@ func FetchThreads(
ForumLastReadTime *time.Time `db:"slri.lastread"`
}
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil {
return nil, oops.New(err, "failed to fetch threads")
}
result := make([]ThreadAndStuff, len(iresults))
for i, iresult := range iresults {
row := *iresult.(*resultRow)
result := make([]ThreadAndStuff, len(results))
for i, row := range results {
hasRead := false
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
hasRead = true
@ -405,15 +403,13 @@ func FetchPosts(
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
}
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
results, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
if err != nil {
return nil, oops.New(err, "failed to fetch posts")
}
result := make([]PostAndStuff, len(iresults))
for i, iresult := range iresults {
row := *iresult.(*resultRow)
result := make([]PostAndStuff, len(results))
for i, row := range results {
hasRead := false
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
hasRead = true
@ -611,7 +607,7 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
type postResult struct {
AuthorID *int `db:"post.author_id"`
}
iresult, err := db.QueryOne(ctx, connOrTx, postResult{},
result, err := db.QueryOne[postResult](ctx, connOrTx,
`
SELECT $columns
FROM
@ -629,7 +625,6 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
panic(oops.New(err, "failed to get author of post when checking permissions"))
}
}
result := iresult.(*postResult)
return result.AuthorID != nil && *result.AuthorID == user.ID
}
@ -709,7 +704,7 @@ func DeletePost(
FirstPostID int `db:"first_id"`
Deleted bool `db:"deleted"`
}
ti, err := db.QueryOne(ctx, tx, threadInfo{},
info, err := db.QueryOne[threadInfo](ctx, tx,
`
SELECT $columns
FROM
@ -722,7 +717,6 @@ func DeletePost(
if err != nil {
panic(oops.New(err, "failed to fetch thread info"))
}
info := ti.(*threadInfo)
if info.Deleted {
return true
}
@ -851,7 +845,7 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
type assetId struct {
AssetID uuid.UUID `db:"id"`
}
assetResult, err := db.Query(ctx, tx, assetId{},
assets, err := db.Query[assetId](ctx, tx,
`
SELECT $columns
FROM handmade_asset
@ -865,8 +859,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
var values [][]interface{}
for _, asset := range assetResult {
values = append(values, []interface{}{postId, asset.(*assetId).AssetID})
for _, asset := range assets {
values = append(values, []interface{}{postId, asset.AssetID})
}
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
@ -886,7 +880,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more.
You should probably mark the thread as deleted in this case.
*/
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
postsIter, err := db.Query(ctx, tx, models.Post{},
posts, err := db.Query[models.Post](ctx, tx,
`
SELECT $columns
FROM handmade_post
@ -901,9 +895,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
}
var firstPost, lastPost *models.Post
for _, ipost := range postsIter {
post := ipost.(*models.Post)
for _, post := range posts {
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
firstPost = post
}

View File

@ -47,7 +47,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
type subforumRow struct {
Subforum Subforum `db:"sf"`
}
rowsSlice, err := db.Query(ctx, conn, subforumRow{},
rowsSlice, err := db.Query[subforumRow](ctx, conn,
`
SELECT $columns
FROM
@ -61,7 +61,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice))
for _, row := range rowsSlice {
sf := row.(*subforumRow).Subforum
sf := row.Subforum
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
}
@ -73,7 +73,7 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
for _, row := range rowsSlice {
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
cat := row.(*subforumRow).Subforum
cat := row.Subforum
node := sfTreeMap[cat.ID]
if node.Parent != nil {
node.Parent.Children = append(node.Parent.Children, node)

View File

@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID)
}
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
}
for _, ul := range userLinks {
link := ul.(*models.Link)
for _, link := range userLinks {
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
}
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE auth_user.id = $1
`,
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
}
}
user := u.(*userQuery).User
user := u.User
whatHappened := ""
if action == ApprovalQueueActionApprove {
@ -337,7 +337,7 @@ type UnapprovedPost struct {
}
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
if err != nil {
return nil, oops.New(err, "failed to fetch unapproved posts")
}
var res []*UnapprovedPost
for _, iresult := range it {
res = append(res, iresult.(*UnapprovedPost))
}
return res, nil
}
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
type unapprovedUser struct {
ID int `db:"id"`
}
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
if err != nil {
return nil, oops.New(err, "failed to fetch unapproved users")
}
ownerIDs := make([]int, 0, len(it))
for _, uid := range it {
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
ownerIDs := make([]int, 0, len(uids))
for _, uid := range uids {
ownerIDs = append(ownerIDs, uid.ID)
}
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{

View File

@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
}
} else {
canonicalUsername = userResult.(*userQuery).User.Username
canonicalUsername = userResult.User.Username
}
}

View File

@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE LOWER(username) = LOWER($1)
`,
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
}
}
user := &userRow.(*userQuery).User
user := &userRow.User
success, err := tryLogin(c, user, password)
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
}
}
if userRow != nil {
user = &userRow.(*userQuery).User
user = &userRow.User
}
if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
`
SELECT $columns
FROM handmade_onetimetoken
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
}
}
var resetToken *models.OneTimeToken
if tokenRow != nil {
resetToken = tokenRow.(*models.OneTimeToken)
}
now := time.Now()
if resetToken != nil {
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
`
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5)
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
}
resetToken = tokenRow.(*models.OneTimeToken)
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
if err != nil {
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"auth_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
}
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
return result
}
}
if row != nil {
data := row.(*userAndTokenQuery)
if data != nil {
result.User = &data.User
result.OneTimeToken = data.OneTimeToken
if result.OneTimeToken != nil {