From 97360a1998772fac3db1497f1aeb5a6bd27972df Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Sat, 16 Apr 2022 19:46:17 -0500 Subject: [PATCH] Rework and document the db API Tests still pass, at least...now for everything else --- src/db/db.go | 468 ++++++++++-------- src/db/query_builder.go | 6 +- .../2021-11-06T033930Z_PersonalProjects.go | 2 +- src/models/subforum.go | 18 +- 4 files changed, 258 insertions(+), 236 deletions(-) diff --git a/src/db/db.go b/src/db/db.go index 440c70f..c4f83d5 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -29,9 +29,9 @@ var NotFound = errors.New("not found") // This interface should match both a direct pgx connection or a pgx transaction. type ConnOrTx interface { - Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row - Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) // Both raw database connections and transactions in pgx can begin/commit // transactions. For database connections it does the obvious thing; for @@ -42,6 +42,8 @@ type ConnOrTx interface { var connInfo = pgtype.NewConnInfo() +// Creates a new connection to the HMN database. +// This connection is not safe for concurrent use. func NewConn() *pgx.Conn { conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN()) if err != nil { @@ -51,6 +53,8 @@ func NewConn() *pgx.Conn { return conn } +// Creates a connection pool for the HMN database. +// The resulting pool is safe for concurrent use. func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN()) @@ -67,154 +71,19 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { return conn } -type columnName []string - -// A path to a particular field in query's destination type. Each index in the slice -// corresponds to a field index for use with Field on a reflect.Type or reflect.Value. -type fieldPath []int - -type ResultIterator[T any] struct { - fieldPaths []fieldPath - rows pgx.Rows - destType reflect.Type - closed chan struct{} -} - -func (it *ResultIterator[T]) Next() (*T, bool) { - hasNext := it.rows.Next() - if !hasNext { - it.Close() - return nil, false - } - - result := reflect.New(it.destType) - - vals, err := it.rows.Values() - if err != nil { - panic(err) - } - - // Better logging of panics in this confusing reflection process - var currentField reflect.StructField - var currentValue reflect.Value - var currentIdx int - defer func() { - if r := recover(); r != nil { - if currentValue.IsValid() { - logging.Error(). - Int("index", currentIdx). - Str("field name", currentField.Name). - Stringer("field type", currentField.Type). - Interface("value", currentValue.Interface()). - Stringer("value type", currentValue.Type()). - Msg("panic in iterator") - } - - if currentField.Name != "" { - panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r)) - } else { - panic(r) - } - } - }() - - for i, val := range vals { - currentIdx = i - if val == nil { - continue - } - - var field reflect.Value - field, currentField = followPathThroughStructs(result, it.fieldPaths[i]) - if field.Kind() == reflect.Ptr { - field.Set(reflect.New(field.Type().Elem())) - field = field.Elem() - } - - // Some actual values still come through as pointers (like net.IPNet). Dunno why. - // Regardless, we know it's not nil, so we can get at the contents. - valReflected := reflect.ValueOf(val) - if valReflected.Kind() == reflect.Ptr { - valReflected = valReflected.Elem() - } - currentValue = valReflected - - switch field.Kind() { - case reflect.Int: - field.SetInt(valReflected.Int()) - default: - field.Set(valReflected) - } - - currentField = reflect.StructField{} - currentValue = reflect.Value{} - } - - return result.Interface().(*T), true -} - -func (it *ResultIterator[any]) Close() { - it.rows.Close() - select { - case it.closed <- struct{}{}: - default: - } -} - -func (it *ResultIterator[T]) ToSlice() []*T { - defer it.Close() - var result []*T - for { - row, ok := it.Next() - if !ok { - err := it.rows.Err() - if err != nil { - panic(oops.New(err, "error while iterating through db results")) - } - break - } - result = append(result, row) - } - return result -} - -func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) { - if len(path) < 1 { - panic(oops.New(nil, "can't follow an empty path")) - } - - if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct { - panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type())) - } - - // more informative panic recovery - var field reflect.StructField - defer func() { - if r := recover(); r != nil { - panic(oops.New(nil, "panic at field '%s': %v", field.Name, r)) - } - }() - - val := structPtrVal - for _, i := range path { - if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct { - if val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - val = val.Elem() - } - field = val.Type().Field(i) - val = val.Field(i) - } - return val, field -} - /* Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred. Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection. + +This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar. */ -func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) { +func Query[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) ([]*T, error) { it, err := QueryIterator[T](ctx, conn, query, args...) if err != nil { return nil, err @@ -223,7 +92,93 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte } } -func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*ResultIterator[T], error) { +/* +Identical to Query, but returns only the first result row. If there are no +rows in the result set, returns NotFound. +*/ +func QueryOne[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (*T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + result, hasRow := rows.Next() + if !hasRow { + return nil, NotFound + } + + return result, nil +} + +/* +Identical to Query, but returns concrete values instead of pointers. More convenient +for primitive types. +*/ +func QueryScalar[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) ([]T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []T + for { + val, hasRow := rows.Next() + if !hasRow { + break + } + result = append(result, *val) + } + + return result, nil +} + +/* +Identical to QueryScalar, but returns only the first result value. If there are +no rows in the result set, returns NotFound. +*/ +func QueryOneScalar[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (T, error) { + rows, err := QueryIterator[T](ctx, conn, query, args...) + if err != nil { + var zero T + return zero, err + } + defer rows.Close() + + result, hasRow := rows.Next() + if !hasRow { + var zero T + return zero, NotFound + } + + return *result, nil +} + +/* +Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use. +*/ +func QueryIterator[T any]( + ctx context.Context, + conn ConnOrTx, + query string, + args ...any, +) (*Iterator[T], error) { var destExample T destType := reflect.TypeOf(destExample) @@ -237,11 +192,12 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args return nil, err } - it := &ResultIterator[T]{ - fieldPaths: compiled.fieldPaths, - rows: rows, - destType: compiled.destType, - closed: make(chan struct{}, 1), + it := &Iterator[T]{ + fieldPaths: compiled.fieldPaths, + rows: rows, + destType: compiled.destType, + destTypeIsScalar: typeIsQueryable(compiled.destType), + closed: make(chan struct{}, 1), } // Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold @@ -261,21 +217,6 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args return it, nil } -func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) { - rows, err := QueryIterator[T](ctx, conn, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - result, hasRow := rows.Next() - if !hasRow { - return nil, NotFound - } - - return result, nil -} - // TODO: QueryFunc? type compiledQuery struct { @@ -295,7 +236,7 @@ func compileQuery(query string, destType reflect.Type) compiledQuery { // must be a struct, and we will plonk that struct's fields into the query. if destType.Kind() != reflect.Struct { - panic("$columns can only be used when querying into some kind of struct") + panic("$columns can only be used when querying into a struct") } var prefix []string @@ -436,75 +377,162 @@ func typeIsQueryable(t reflect.Type) bool { return false } -// TODO: Delete in favor of `QueryOne` -func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) { - rows, err := conn.Query(ctx, query, args...) - if err != nil { - return nil, err +type columnName []string + +// A path to a particular field in query's destination type. Each index in the slice +// corresponds to a field index for use with Field on a reflect.Type or reflect.Value. +type fieldPath []int + +type Iterator[T any] struct { + fieldPaths []fieldPath + rows pgx.Rows + destType reflect.Type + destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function. + closed chan struct{} +} + +func (it *Iterator[T]) Next() (*T, bool) { + // TODO(ben): What happens if this panics? Does it leak resources? Do we need + // to put a recover() here and close the rows? + + hasNext := it.rows.Next() + if !hasNext { + it.Close() + return nil, false } - defer rows.Close() - if rows.Next() { - vals, err := rows.Values() - if err != nil { - panic(err) - } + result := reflect.New(it.destType) + vals, err := it.rows.Values() + if err != nil { + panic(err) + } + + if it.destTypeIsScalar { + // This type can be directly queried, meaning pgx recognizes it, it's + // a simple scalar thing, and we can just take the easy way out. if len(vals) != 1 { - return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals)) + panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals))) + } + result.Set(reflect.ValueOf(vals[0])) + return result.Interface().(*T), true + } else { + var currentField reflect.StructField + var currentValue reflect.Value + var currentIdx int + + // Better logging of panics in this confusing reflection process + defer func() { + if r := recover(); r != nil { + if currentValue.IsValid() { + logging.Error(). + Int("index", currentIdx). + Str("field name", currentField.Name). + Stringer("field type", currentField.Type). + Interface("value", currentValue.Interface()). + Stringer("value type", currentValue.Type()). + Msg("panic in iterator") + } + + if currentField.Name != "" { + panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r)) + } else { + panic(r) + } + } + }() + + for i, val := range vals { + currentIdx = i + if val == nil { + continue + } + + var field reflect.Value + field, currentField = followPathThroughStructs(result, it.fieldPaths[i]) + if field.Kind() == reflect.Ptr { + field.Set(reflect.New(field.Type().Elem())) + field = field.Elem() + } + + // Some actual values still come through as pointers (like net.IPNet). Dunno why. + // Regardless, we know it's not nil, so we can get at the contents. + valReflected := reflect.ValueOf(val) + if valReflected.Kind() == reflect.Ptr { + valReflected = valReflected.Elem() + } + currentValue = valReflected + + switch field.Kind() { + case reflect.Int: + field.SetInt(valReflected.Int()) + default: + field.Set(valReflected) + } + + currentField = reflect.StructField{} + currentValue = reflect.Value{} } - return vals[0], nil + return result.Interface().(*T), true } - - return nil, NotFound } -// TODO: Delete in favor of `QueryOne[string]` -func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return "", err - } - - switch r := result.(type) { - case string: - return r, nil +func (it *Iterator[any]) Close() { + it.rows.Close() + select { + case it.closed <- struct{}{}: default: - return "", oops.New(nil, "QueryString got a non-string result: %v", result) } } -// TODO: Delete in favor of `QueryOne[int]` -func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return 0, err - } - - switch r := result.(type) { - case int: - return r, nil - case int32: - return int(r), nil - case int64: - return int(r), nil - default: - return 0, oops.New(nil, "QueryInt got a non-int result: %v", result) +/* +Pulls all the remaining values into a slice, and closes the iterator. +*/ +func (it *Iterator[T]) ToSlice() []*T { + defer it.Close() + var result []*T + for { + row, ok := it.Next() + if !ok { + err := it.rows.Err() + if err != nil { + panic(oops.New(err, "error while iterating through db results")) + } + break + } + result = append(result, row) } + return result } -// TODO: Delete in favor of `QueryOne[bool]` -func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) { - result, err := QueryScalar(ctx, conn, query, args...) - if err != nil { - return false, err +func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) { + if len(path) < 1 { + panic(oops.New(nil, "can't follow an empty path")) } - switch r := result.(type) { - case bool: - return r, nil - default: - return false, oops.New(nil, "QueryBool got a non-bool result: %v", result) + if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct { + panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type())) } + + // more informative panic recovery + var field reflect.StructField + defer func() { + if r := recover(); r != nil { + panic(oops.New(nil, "panic at field '%s': %v", field.Name, r)) + } + }() + + val := structPtrVal + for _, i := range path { + if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + field = val.Type().Field(i) + val = val.Field(i) + } + return val, field } diff --git a/src/db/query_builder.go b/src/db/query_builder.go index 80a687f..d8dcb41 100644 --- a/src/db/query_builder.go +++ b/src/db/query_builder.go @@ -7,7 +7,7 @@ import ( type QueryBuilder struct { sql strings.Builder - args []interface{} + args []any } /* @@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number. foo ARG1 bar ARG2 baz $? foo ARG1 bar ARG2 baz ARG3 */ -func (qb *QueryBuilder) Add(sql string, args ...interface{}) { +func (qb *QueryBuilder) Add(sql string, args ...any) { numPlaceholders := strings.Count(sql, "$?") if numPlaceholders != len(args) { panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args))) @@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string { return qb.sql.String() } -func (qb *QueryBuilder) Args() []interface{} { +func (qb *QueryBuilder) Args() []any { return qb.args } diff --git a/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go b/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go index 552319b..29168d0 100644 --- a/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go +++ b/src/migration/migrations/2021-11-06T033930Z_PersonalProjects.go @@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error { // Port "jam snippets" to use a tag // - jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`) + jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`) if err != nil { return oops.New(err, "failed to create jam tag") } diff --git a/src/models/subforum.go b/src/models/subforum.go index d65346a..1bed9e2 100644 --- a/src/models/subforum.go +++ b/src/models/subforum.go @@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum { } func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { - type subforumRow struct { - Subforum Subforum `db:"sf"` - } - rowsSlice, err := db.Query(ctx, conn, subforumRow{}, + subforums, err := db.Query[Subforum](ctx, conn, ` SELECT $columns - FROM - handmade_subforum as sf + FROM handmade_subforum ORDER BY sort, id ASC `, ) @@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { panic(oops.New(err, "failed to fetch subforum tree")) } - sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice)) - for _, row := range rowsSlice { - sf := row.(*subforumRow).Subforum - sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf} + sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums)) + for _, sf := range subforums { + sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf} } for _, node := range sfTreeMap { @@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { } } - for _, row := range rowsSlice { + for _, cat := range subforums { // NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order. - cat := row.(*subforumRow).Subforum node := sfTreeMap[cat.ID] if node.Parent != nil { node.Parent.Children = append(node.Parent.Children, node)