diff --git a/go.mod b/go.mod index 8c56138a..08de4878 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,8 @@ module git.handmade.network/hmn/hmn -go 1.16 +go 1.18 require ( - github.com/Masterminds/goutils v1.1.1 // indirect - github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible github.com/alecthomas/chroma v0.9.2 github.com/aws/aws-sdk-go-v2 v1.8.1 @@ -16,13 +14,10 @@ require ( github.com/go-stack/stack v1.8.0 github.com/google/uuid v1.2.0 github.com/gorilla/websocket v1.4.2 - github.com/huandu/xstrings v1.3.2 // indirect - github.com/imdario/mergo v0.3.12 // indirect github.com/jackc/pgconn v1.8.0 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.10.1 github.com/jpillora/backoff v1.0.0 - github.com/mitchellh/copystructure v1.1.1 // indirect github.com/rs/zerolog v1.21.0 github.com/spf13/cobra v1.1.3 github.com/stretchr/testify v1.7.0 @@ -35,6 +30,39 @@ require ( golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d ) +require ( + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver v1.5.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect + github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.4.0 // indirect + github.com/huandu/xstrings v1.3.2 // indirect + github.com/imdario/mergo v0.3.12 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.0.6 // indirect + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect + github.com/jackc/puddle v1.1.3 // indirect + github.com/mitchellh/copystructure v1.1.1 // indirect + github.com/mitchellh/reflectwalk v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect + golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect + golang.org/x/text v0.3.6 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) + replace ( github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9 github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e diff --git a/go.sum b/go.sum index 6e6029ee..bb9c2856 100644 --- a/go.sum +++ b/go.sum @@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU= github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= diff --git a/src/db/db.go b/src/db/db.go index 20aee1cc..2048a427 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "strings" "git.handmade.network/hmn/hmn/src/config" @@ -95,14 +96,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { return conn } -type StructQueryIterator struct { - fieldPaths [][]int +type columnName []string + +// A path to a particular field in query's destination type. Each index in the slice +// corresponds to a field index for use with Field on a reflect.Type or reflect.Value. +type fieldPath []int + +type StructQueryIterator[T any] struct { + fieldPaths []fieldPath rows pgx.Rows destType reflect.Type closed chan struct{} } -func (it *StructQueryIterator) Next() (interface{}, bool) { +func (it *StructQueryIterator[T]) Next() (*T, bool) { hasNext := it.rows.Next() if !hasNext { it.Close() @@ -172,10 +179,10 @@ func (it *StructQueryIterator) Next() (interface{}, bool) { currentValue = reflect.Value{} } - return result.Interface(), true + return result.Interface().(*T), true } -func (it *StructQueryIterator) Close() { +func (it *StructQueryIterator[any]) Close() { it.rows.Close() select { case it.closed <- struct{}{}: @@ -183,9 +190,9 @@ func (it *StructQueryIterator) Close() { } } -func (it *StructQueryIterator) ToSlice() []interface{} { +func (it *StructQueryIterator[T]) ToSlice() []*T { defer it.Close() - var result []interface{} + var result []*T for { row, ok := it.Next() if !ok { @@ -231,8 +238,8 @@ func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.V return val, field } -func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) { - it, err := QueryIterator(ctx, conn, destExample, query, args...) +func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) { + it, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err } else { @@ -240,27 +247,13 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st } } -func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) { +func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { + var destExample T destType := reflect.TypeOf(destExample) - columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil) - if err != nil { - return nil, oops.New(err, "failed to generate column names") - } - columns := make([]string, 0, len(columnNames)) - for _, strSlice := range columnNames { - tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") - fullName := strSlice[len(strSlice)-1] - if tableName != "" { - fullName = tableName + "." + fullName - } - columns = append(columns, fullName) - } + compiled := compileQuery(query, destType) - columnNamesString := strings.Join(columns, ", ") - query = strings.Replace(query, "$columns", columnNamesString, -1) - - rows, err := conn.Query(ctx, query, args...) + rows, err := conn.Query(ctx, compiled.query, args...) if err != nil { if errors.Is(err, context.DeadlineExceeded) { panic("query exceeded its deadline") @@ -268,10 +261,10 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, return nil, err } - it := &StructQueryIterator{ - fieldPaths: fieldPaths, + it := &StructQueryIterator[T]{ + fieldPaths: compiled.fieldPaths, rows: rows, - destType: destType, + destType: compiled.destType, closed: make(chan struct{}, 1), } @@ -292,16 +285,70 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, return it, nil } -func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { - var columnNames [][]string - var fieldPaths [][]int +type compiledQuery struct { + query string + destType reflect.Type + fieldPaths []fieldPath +} + +var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`) + +func compileQuery(query string, destType reflect.Type) compiledQuery { + columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query) + hasColumnsPlaceholder := columnsMatch != nil + + if hasColumnsPlaceholder { + // The presence of the $columns placeholder means that the destination type + // must be a struct, and we will plonk that struct's fields into the query. + + if destType.Kind() != reflect.Struct { + panic("$columns can only be used when querying into some kind of struct") + } + + var prefix []string + prefixText := columnsMatch[2] + if prefixText != "" { + prefix = []string{prefixText} + } + + columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix) + + columns := make([]string, 0, len(columnNames)) + for _, strSlice := range columnNames { + tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") + fullName := strSlice[len(strSlice)-1] + if tableName != "" { + fullName = tableName + "." + fullName + } + columns = append(columns, fullName) + } + + columnNamesString := strings.Join(columns, ", ") + query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString) + + return compiledQuery{ + query: query, + destType: destType, + fieldPaths: fieldPaths, + } + } else { + return compiledQuery{ + query: query, + destType: destType, + } + } +} + +func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) { + var columnNames []columnName + var fieldPaths []fieldPath if destType.Kind() == reflect.Ptr { destType = destType.Elem() } if destType.Kind() != reflect.Struct { - return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) + panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)) } type AnonPrefix struct { @@ -348,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str columnNames = append(columnNames, fieldColumnNames) fieldPaths = append(fieldPaths, path) } else if fieldType.Kind() == reflect.Struct { - subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) - if err != nil { - return nil, nil, err - } + subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) columnNames = append(columnNames, subCols...) fieldPaths = append(fieldPaths, subPaths...) } else { - return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) + panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)) } } } - return columnNames, fieldPaths, nil + return columnNames, fieldPaths } /* @@ -370,8 +414,8 @@ result but find nothing. */ var NotFound = errors.New("not found") -func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { - rows, err := QueryIterator(ctx, conn, destExample, query, args...) +func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err } diff --git a/src/db/db_test.go b/src/db/db_test.go index 41b9491d..edc647ed 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -10,52 +10,143 @@ import ( func TestPaths(t *testing.T) { type CustomInt int - type S struct { - I int `db:"I"` - PI *int `db:"PI"` - CI CustomInt `db:"CI"` - PCI *CustomInt `db:"PCI"` - B bool `db:"B"` - PB *bool `db:"PB"` + type S2 struct { + B bool `db:"B"` // field 0 + PB *bool `db:"PB"` // field 1 - NoTag int + NoTag string // field 2 + } + type S struct { + I int `db:"I"` // field 0 + PI *int `db:"PI"` // field 1 + CI CustomInt `db:"CI"` // field 2 + PCI *CustomInt `db:"PCI"` // field 3 + S2 `db:"S2"` // field 4 (embedded!) + PS2 *S2 `db:"PS2"` // field 5 + + NoTag int // field 6 } type Nested struct { - S S `db:"S"` - PS *S `db:"PS"` + S S `db:"S"` // field 0 + PS *S `db:"PS"` // field 1 - NoTag S + NoTag S // field 2 } type Embedded struct { - NoTag S - Nested + NoTag S // field 0 + Nested // field 1 } - names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") - if assert.Nil(t, err) { - assert.Equal(t, []string{ - "S.I", "S.PI", - "S.CI", "S.PCI", - "S.B", "S.PB", - "PS.I", "PS.PI", - "PS.CI", "PS.PCI", - "PS.B", "PS.PB", - }, names) - assert.Equal(t, [][]int{ - {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, - {1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, - }, paths) - assert.True(t, len(names) == len(paths)) - } + names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil) + assert.Equal(t, []columnName{ + {"S", "I"}, {"S", "PI"}, + {"S", "CI"}, {"S", "PCI"}, + {"S", "S2", "B"}, {"S", "S2", "PB"}, + {"S", "PS2", "B"}, {"S", "PS2", "PB"}, + {"PS", "I"}, {"PS", "PI"}, + {"PS", "CI"}, {"PS", "PCI"}, + {"PS", "S2", "B"}, {"PS", "S2", "PB"}, + {"PS", "PS2", "B"}, {"PS", "PS2", "PB"}, + }, names) + assert.Equal(t, []fieldPath{ + {1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI + {1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI + {1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB + {1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB + {1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI + {1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI + {1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB + {1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB + }, paths) + assert.True(t, len(names) == len(paths)) testStruct := Embedded{} for i, path := range paths { val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) assert.True(t, val.IsValid()) - assert.True(t, strings.Contains(names[i], field.Name)) + assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name)) } } +func TestCompileQuery(t *testing.T) { + t.Run("simple struct", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query) + }) + t.Run("complex structs", func(t *testing.T) { + type CustomInt int + type S2 struct { + B bool `db:"B"` + PB *bool `db:"PB"` + + NoTag string + } + type S struct { + I int `db:"I"` + PI *int `db:"PI"` + CI CustomInt `db:"CI"` + PCI *CustomInt `db:"PCI"` + S2 `db:"S2"` // embedded! + PS2 *S2 `db:"PS2"` + + NoTag int + } + type Nested struct { + S S `db:"S"` + PS *S `db:"PS"` + + NoTag S + } + type Dest struct { + NoTag S + Nested + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query) + }) + t.Run("int", func(t *testing.T) { + type Dest int + + // There should be no error here because we do not need to extract columns from + // the destination type. There may be errors down the line in value iteration, but + // that is always the case if the Go types don't match the query. + compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0))) + assert.Equal(t, "SELECT id FROM greeblies", compiled.query) + }) + t.Run("just one table", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + // The prefix is necessary because otherwise we would have to provide a struct with + // a db tag in order to provide the query with the `greeblies.` prefix in the + // final query. This comes up a lot when we do a JOIN to help with a condition, but + // don't actually care about any of the data we joined to. + compiled := compileQuery( + "SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props", + reflect.TypeOf(Dest{}), + ) + assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query) + }) + + t.Run("using $columns without a struct is not allowed", func(t *testing.T) { + type Dest int + + assert.Panics(t, func() { + compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0))) + }) + }) +} + func TestQueryBuilder(t *testing.T) { t.Run("happy time", func(t *testing.T) { var qb QueryBuilder diff --git a/src/db/doc.go b/src/db/doc.go new file mode 100644 index 00000000..07140f79 --- /dev/null +++ b/src/db/doc.go @@ -0,0 +1,4 @@ +/* +Wow so dobument +*/ +package db