Start of db rework: clean up, start generics, improve tests

This commit is contained in:
Ben Visness 2022-04-16 12:49:29 -05:00
parent 38e93d5208
commit b9a4cb2361
5 changed files with 245 additions and 80 deletions

40
go.mod
View File

@ -1,10 +1,8 @@
module git.handmade.network/hmn/hmn
go 1.16
go 1.18
require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/Masterminds/sprig v2.22.0+incompatible
github.com/alecthomas/chroma v0.9.2
github.com/aws/aws-sdk-go-v2 v1.8.1
@ -16,13 +14,10 @@ require (
github.com/go-stack/stack v1.8.0
github.com/google/uuid v1.2.0
github.com/gorilla/websocket v1.4.2
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/jackc/pgconn v1.8.0
github.com/jackc/pgtype v1.6.2
github.com/jackc/pgx/v4 v4.10.1
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/rs/zerolog v1.21.0
github.com/spf13/cobra v1.1.3
github.com/stretchr/testify v1.7.0
@ -35,6 +30,39 @@ require (
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
)
require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver v1.5.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.0.6 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/puddle v1.1.3 // indirect
github.com/mitchellh/copystructure v1.1.1 // indirect
github.com/mitchellh/reflectwalk v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
replace (
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e

2
go.sum
View File

@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"git.handmade.network/hmn/hmn/src/config"
@ -95,14 +96,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn
}
type StructQueryIterator struct {
fieldPaths [][]int
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type StructQueryIterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
closed chan struct{}
}
func (it *StructQueryIterator) Next() (interface{}, bool) {
func (it *StructQueryIterator[T]) Next() (*T, bool) {
hasNext := it.rows.Next()
if !hasNext {
it.Close()
@ -172,10 +179,10 @@ func (it *StructQueryIterator) Next() (interface{}, bool) {
currentValue = reflect.Value{}
}
return result.Interface(), true
return result.Interface().(*T), true
}
func (it *StructQueryIterator) Close() {
func (it *StructQueryIterator[any]) Close() {
it.rows.Close()
select {
case it.closed <- struct{}{}:
@ -183,9 +190,9 @@ func (it *StructQueryIterator) Close() {
}
}
func (it *StructQueryIterator) ToSlice() []interface{} {
func (it *StructQueryIterator[T]) ToSlice() []*T {
defer it.Close()
var result []interface{}
var result []*T
for {
row, ok := it.Next()
if !ok {
@ -231,8 +238,8 @@ func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.V
return val, field
}
func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) {
it, err := QueryIterator(ctx, conn, destExample, query, args...)
func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) {
it, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
} else {
@ -240,27 +247,13 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st
}
}
func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) {
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
var destExample T
destType := reflect.TypeOf(destExample)
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
if err != nil {
return nil, oops.New(err, "failed to generate column names")
}
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
compiled := compileQuery(query, destType)
columnNamesString := strings.Join(columns, ", ")
query = strings.Replace(query, "$columns", columnNamesString, -1)
rows, err := conn.Query(ctx, query, args...)
rows, err := conn.Query(ctx, compiled.query, args...)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
panic("query exceeded its deadline")
@ -268,10 +261,10 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return nil, err
}
it := &StructQueryIterator{
fieldPaths: fieldPaths,
it := &StructQueryIterator[T]{
fieldPaths: compiled.fieldPaths,
rows: rows,
destType: destType,
destType: compiled.destType,
closed: make(chan struct{}, 1),
}
@ -292,16 +285,70 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
return it, nil
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
var columnNames [][]string
var fieldPaths [][]int
type compiledQuery struct {
query string
destType reflect.Type
fieldPaths []fieldPath
}
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
func compileQuery(query string, destType reflect.Type) compiledQuery {
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
hasColumnsPlaceholder := columnsMatch != nil
if hasColumnsPlaceholder {
// The presence of the $columns placeholder means that the destination type
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
}
var prefix []string
prefixText := columnsMatch[2]
if prefixText != "" {
prefix = []string{prefixText}
}
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ")
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
return compiledQuery{
query: query,
destType: destType,
fieldPaths: fieldPaths,
}
} else {
return compiledQuery{
query: query,
destType: destType,
}
}
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
var columnNames []columnName
var fieldPaths []fieldPath
if destType.Kind() == reflect.Ptr {
destType = destType.Elem()
}
if destType.Kind() != reflect.Struct {
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
}
type AnonPrefix struct {
@ -348,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
columnNames = append(columnNames, fieldColumnNames)
fieldPaths = append(fieldPaths, path)
} else if fieldType.Kind() == reflect.Struct {
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
if err != nil {
return nil, nil, err
}
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
columnNames = append(columnNames, subCols...)
fieldPaths = append(fieldPaths, subPaths...)
} else {
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
}
}
}
return columnNames, fieldPaths, nil
return columnNames, fieldPaths
}
/*
@ -370,8 +414,8 @@ result but find nothing.
*/
var NotFound = errors.New("not found")
func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) {
rows, err := QueryIterator(ctx, conn, destExample, query, args...)
func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}

View File

@ -10,52 +10,143 @@ import (
func TestPaths(t *testing.T) {
type CustomInt int
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
B bool `db:"B"`
PB *bool `db:"PB"`
type S2 struct {
B bool `db:"B"` // field 0
PB *bool `db:"PB"` // field 1
NoTag int
NoTag string // field 2
}
type S struct {
I int `db:"I"` // field 0
PI *int `db:"PI"` // field 1
CI CustomInt `db:"CI"` // field 2
PCI *CustomInt `db:"PCI"` // field 3
S2 `db:"S2"` // field 4 (embedded!)
PS2 *S2 `db:"PS2"` // field 5
NoTag int // field 6
}
type Nested struct {
S S `db:"S"`
PS *S `db:"PS"`
S S `db:"S"` // field 0
PS *S `db:"PS"` // field 1
NoTag S
NoTag S // field 2
}
type Embedded struct {
NoTag S
Nested
NoTag S // field 0
Nested // field 1
}
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
if assert.Nil(t, err) {
assert.Equal(t, []string{
"S.I", "S.PI",
"S.CI", "S.PCI",
"S.B", "S.PB",
"PS.I", "PS.PI",
"PS.CI", "PS.PCI",
"PS.B", "PS.PB",
}, names)
assert.Equal(t, [][]int{
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
}, paths)
assert.True(t, len(names) == len(paths))
}
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
assert.Equal(t, []columnName{
{"S", "I"}, {"S", "PI"},
{"S", "CI"}, {"S", "PCI"},
{"S", "S2", "B"}, {"S", "S2", "PB"},
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
{"PS", "I"}, {"PS", "PI"},
{"PS", "CI"}, {"PS", "PCI"},
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
}, names)
assert.Equal(t, []fieldPath{
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
}, paths)
assert.True(t, len(names) == len(paths))
testStruct := Embedded{}
for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i], field.Name))
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
}
}
func TestCompileQuery(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
})
t.Run("complex structs", func(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag string
}
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
S2 `db:"S2"` // embedded!
PS2 *S2 `db:"PS2"`
NoTag int
}
type Nested struct {
S S `db:"S"`
PS *S `db:"PS"`
NoTag S
}
type Dest struct {
NoTag S
Nested
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
})
t.Run("int", func(t *testing.T) {
type Dest int
// There should be no error here because we do not need to extract columns from
// the destination type. There may be errors down the line in value iteration, but
// that is always the case if the Go types don't match the query.
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
})
t.Run("just one table", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
// The prefix is necessary because otherwise we would have to provide a struct with
// a db tag in order to provide the query with the `greeblies.` prefix in the
// final query. This comes up a lot when we do a JOIN to help with a condition, but
// don't actually care about any of the data we joined to.
compiled := compileQuery(
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
reflect.TypeOf(Dest{}),
)
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
})
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
type Dest int
assert.Panics(t, func() {
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
})
})
}
func TestQueryBuilder(t *testing.T) {
t.Run("happy time", func(t *testing.T) {
var qb QueryBuilder

4
src/db/doc.go Normal file
View File

@ -0,0 +1,4 @@
/*
Wow so dobument
*/
package db