Rework and document the db API

Tests still pass, at least...now for everything else
This commit is contained in:
Ben Visness 2022-04-16 19:46:17 -05:00
parent a2917b98c0
commit 97360a1998
4 changed files with 258 additions and 236 deletions

View File

@ -29,9 +29,9 @@ var NotFound = errors.New("not found")
// This interface should match both a direct pgx connection or a pgx transaction. // This interface should match both a direct pgx connection or a pgx transaction.
type ConnOrTx interface { type ConnOrTx interface {
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
// Both raw database connections and transactions in pgx can begin/commit // Both raw database connections and transactions in pgx can begin/commit
// transactions. For database connections it does the obvious thing; for // transactions. For database connections it does the obvious thing; for
@ -42,6 +42,8 @@ type ConnOrTx interface {
var connInfo = pgtype.NewConnInfo() var connInfo = pgtype.NewConnInfo()
// Creates a new connection to the HMN database.
// This connection is not safe for concurrent use.
func NewConn() *pgx.Conn { func NewConn() *pgx.Conn {
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN()) conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
if err != nil { if err != nil {
@ -51,6 +53,8 @@ func NewConn() *pgx.Conn {
return conn return conn
} }
// Creates a connection pool for the HMN database.
// The resulting pool is safe for concurrent use.
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN()) cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
@ -67,154 +71,19 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn return conn
} }
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type ResultIterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
closed chan struct{}
}
func (it *ResultIterator[T]) Next() (*T, bool) {
hasNext := it.rows.Next()
if !hasNext {
it.Close()
return nil, false
}
result := reflect.New(it.destType)
vals, err := it.rows.Values()
if err != nil {
panic(err)
}
// Better logging of panics in this confusing reflection process
var currentField reflect.StructField
var currentValue reflect.Value
var currentIdx int
defer func() {
if r := recover(); r != nil {
if currentValue.IsValid() {
logging.Error().
Int("index", currentIdx).
Str("field name", currentField.Name).
Stringer("field type", currentField.Type).
Interface("value", currentValue.Interface()).
Stringer("value type", currentValue.Type()).
Msg("panic in iterator")
}
if currentField.Name != "" {
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
} else {
panic(r)
}
}
}()
for i, val := range vals {
currentIdx = i
if val == nil {
continue
}
var field reflect.Value
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
if field.Kind() == reflect.Ptr {
field.Set(reflect.New(field.Type().Elem()))
field = field.Elem()
}
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
// Regardless, we know it's not nil, so we can get at the contents.
valReflected := reflect.ValueOf(val)
if valReflected.Kind() == reflect.Ptr {
valReflected = valReflected.Elem()
}
currentValue = valReflected
switch field.Kind() {
case reflect.Int:
field.SetInt(valReflected.Int())
default:
field.Set(valReflected)
}
currentField = reflect.StructField{}
currentValue = reflect.Value{}
}
return result.Interface().(*T), true
}
func (it *ResultIterator[any]) Close() {
it.rows.Close()
select {
case it.closed <- struct{}{}:
default:
}
}
func (it *ResultIterator[T]) ToSlice() []*T {
defer it.Close()
var result []*T
for {
row, ok := it.Next()
if !ok {
err := it.rows.Err()
if err != nil {
panic(oops.New(err, "error while iterating through db results"))
}
break
}
result = append(result, row)
}
return result
}
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
if len(path) < 1 {
panic(oops.New(nil, "can't follow an empty path"))
}
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
}
// more informative panic recovery
var field reflect.StructField
defer func() {
if r := recover(); r != nil {
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
}
}()
val := structPtrVal
for _, i := range path {
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
field = val.Type().Field(i)
val = val.Field(i)
}
return val, field
}
/* /*
Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred. Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred.
Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection. Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection.
This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar.
*/ */
func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) { func Query[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) ([]*T, error) {
it, err := QueryIterator[T](ctx, conn, query, args...) it, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -223,7 +92,93 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
} }
} }
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*ResultIterator[T], error) { /*
Identical to Query, but returns only the first result row. If there are no
rows in the result set, returns NotFound.
*/
func QueryOne[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
return nil, NotFound
}
return result, nil
}
/*
Identical to Query, but returns concrete values instead of pointers. More convenient
for primitive types.
*/
func QueryScalar[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) ([]T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var result []T
for {
val, hasRow := rows.Next()
if !hasRow {
break
}
result = append(result, *val)
}
return result, nil
}
/*
Identical to QueryScalar, but returns only the first result value. If there are
no rows in the result set, returns NotFound.
*/
func QueryOneScalar[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
var zero T
return zero, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
var zero T
return zero, NotFound
}
return *result, nil
}
/*
Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use.
*/
func QueryIterator[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (*Iterator[T], error) {
var destExample T var destExample T
destType := reflect.TypeOf(destExample) destType := reflect.TypeOf(destExample)
@ -237,10 +192,11 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return nil, err return nil, err
} }
it := &ResultIterator[T]{ it := &Iterator[T]{
fieldPaths: compiled.fieldPaths, fieldPaths: compiled.fieldPaths,
rows: rows, rows: rows,
destType: compiled.destType, destType: compiled.destType,
destTypeIsScalar: typeIsQueryable(compiled.destType),
closed: make(chan struct{}, 1), closed: make(chan struct{}, 1),
} }
@ -261,21 +217,6 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return it, nil return it, nil
} }
func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
return nil, NotFound
}
return result, nil
}
// TODO: QueryFunc? // TODO: QueryFunc?
type compiledQuery struct { type compiledQuery struct {
@ -295,7 +236,7 @@ func compileQuery(query string, destType reflect.Type) compiledQuery {
// must be a struct, and we will plonk that struct's fields into the query. // must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct { if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct") panic("$columns can only be used when querying into a struct")
} }
var prefix []string var prefix []string
@ -436,75 +377,162 @@ func typeIsQueryable(t reflect.Type) bool {
return false return false
} }
// TODO: Delete in favor of `QueryOne` type columnName []string
func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) {
rows, err := conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
if rows.Next() { // A path to a particular field in query's destination type. Each index in the slice
vals, err := rows.Values() // corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type Iterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function.
closed chan struct{}
}
func (it *Iterator[T]) Next() (*T, bool) {
// TODO(ben): What happens if this panics? Does it leak resources? Do we need
// to put a recover() here and close the rows?
hasNext := it.rows.Next()
if !hasNext {
it.Close()
return nil, false
}
result := reflect.New(it.destType)
vals, err := it.rows.Values()
if err != nil { if err != nil {
panic(err) panic(err)
} }
if it.destTypeIsScalar {
// This type can be directly queried, meaning pgx recognizes it, it's
// a simple scalar thing, and we can just take the easy way out.
if len(vals) != 1 { if len(vals) != 1 {
return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals)) panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals)))
}
result.Set(reflect.ValueOf(vals[0]))
return result.Interface().(*T), true
} else {
var currentField reflect.StructField
var currentValue reflect.Value
var currentIdx int
// Better logging of panics in this confusing reflection process
defer func() {
if r := recover(); r != nil {
if currentValue.IsValid() {
logging.Error().
Int("index", currentIdx).
Str("field name", currentField.Name).
Stringer("field type", currentField.Type).
Interface("value", currentValue.Interface()).
Stringer("value type", currentValue.Type()).
Msg("panic in iterator")
} }
return vals[0], nil if currentField.Name != "" {
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
} else {
panic(r)
}
}
}()
for i, val := range vals {
currentIdx = i
if val == nil {
continue
} }
return nil, NotFound var field reflect.Value
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
if field.Kind() == reflect.Ptr {
field.Set(reflect.New(field.Type().Elem()))
field = field.Elem()
} }
// TODO: Delete in favor of `QueryOne[string]` // Some actual values still come through as pointers (like net.IPNet). Dunno why.
func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { // Regardless, we know it's not nil, so we can get at the contents.
result, err := QueryScalar(ctx, conn, query, args...) valReflected := reflect.ValueOf(val)
if err != nil { if valReflected.Kind() == reflect.Ptr {
return "", err valReflected = valReflected.Elem()
} }
currentValue = valReflected
switch r := result.(type) { switch field.Kind() {
case string: case reflect.Int:
return r, nil field.SetInt(valReflected.Int())
default: default:
return "", oops.New(nil, "QueryString got a non-string result: %v", result) field.Set(valReflected)
}
currentField = reflect.StructField{}
currentValue = reflect.Value{}
}
return result.Interface().(*T), true
} }
} }
// TODO: Delete in favor of `QueryOne[int]` func (it *Iterator[any]) Close() {
func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) { it.rows.Close()
result, err := QueryScalar(ctx, conn, query, args...) select {
if err != nil { case it.closed <- struct{}{}:
return 0, err
}
switch r := result.(type) {
case int:
return r, nil
case int32:
return int(r), nil
case int64:
return int(r), nil
default: default:
return 0, oops.New(nil, "QueryInt got a non-int result: %v", result)
} }
} }
// TODO: Delete in favor of `QueryOne[bool]` /*
func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) { Pulls all the remaining values into a slice, and closes the iterator.
result, err := QueryScalar(ctx, conn, query, args...) */
func (it *Iterator[T]) ToSlice() []*T {
defer it.Close()
var result []*T
for {
row, ok := it.Next()
if !ok {
err := it.rows.Err()
if err != nil { if err != nil {
return false, err panic(oops.New(err, "error while iterating through db results"))
}
break
}
result = append(result, row)
}
return result
} }
switch r := result.(type) { func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
case bool: if len(path) < 1 {
return r, nil panic(oops.New(nil, "can't follow an empty path"))
default:
return false, oops.New(nil, "QueryBool got a non-bool result: %v", result)
} }
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
}
// more informative panic recovery
var field reflect.StructField
defer func() {
if r := recover(); r != nil {
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
}
}()
val := structPtrVal
for _, i := range path {
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
field = val.Type().Field(i)
val = val.Field(i)
}
return val, field
} }

View File

@ -7,7 +7,7 @@ import (
type QueryBuilder struct { type QueryBuilder struct {
sql strings.Builder sql strings.Builder
args []interface{} args []any
} }
/* /*
@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number.
foo ARG1 bar ARG2 baz $? foo ARG1 bar ARG2 baz $?
foo ARG1 bar ARG2 baz ARG3 foo ARG1 bar ARG2 baz ARG3
*/ */
func (qb *QueryBuilder) Add(sql string, args ...interface{}) { func (qb *QueryBuilder) Add(sql string, args ...any) {
numPlaceholders := strings.Count(sql, "$?") numPlaceholders := strings.Count(sql, "$?")
if numPlaceholders != len(args) { if numPlaceholders != len(args) {
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args))) panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string {
return qb.sql.String() return qb.sql.String()
} }
func (qb *QueryBuilder) Args() []interface{} { func (qb *QueryBuilder) Args() []any {
return qb.args return qb.args
} }

View File

@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error {
// Port "jam snippets" to use a tag // Port "jam snippets" to use a tag
// //
jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`) jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
if err != nil { if err != nil {
return oops.New(err, "failed to create jam tag") return oops.New(err, "failed to create jam tag")
} }

View File

@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum {
} }
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree { func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
type subforumRow struct { subforums, err := db.Query[Subforum](ctx, conn,
Subforum Subforum `db:"sf"`
}
rowsSlice, err := db.Query(ctx, conn, subforumRow{},
` `
SELECT $columns SELECT $columns
FROM FROM handmade_subforum
handmade_subforum as sf
ORDER BY sort, id ASC ORDER BY sort, id ASC
`, `,
) )
@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
panic(oops.New(err, "failed to fetch subforum tree")) panic(oops.New(err, "failed to fetch subforum tree"))
} }
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice)) sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums))
for _, row := range rowsSlice { for _, sf := range subforums {
sf := row.(*subforumRow).Subforum sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf}
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
} }
for _, node := range sfTreeMap { for _, node := range sfTreeMap {
@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
} }
} }
for _, row := range rowsSlice { for _, cat := range subforums {
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order. // NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
cat := row.(*subforumRow).Subforum
node := sfTreeMap[cat.ID] node := sfTreeMap[cat.ID]
if node.Parent != nil { if node.Parent != nil {
node.Parent.Children = append(node.Parent.Children, node) node.Parent.Children = append(node.Parent.Children, node)