Rework and document the db API

Tests still pass, at least...now for everything else
This commit is contained in:
Ben Visness 2022-04-16 19:46:17 -05:00
parent a2917b98c0
commit 97360a1998
4 changed files with 258 additions and 236 deletions

View File

@ -29,9 +29,9 @@ var NotFound = errors.New("not found")
// This interface should match both a direct pgx connection or a pgx transaction.
type ConnOrTx interface {
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
// Both raw database connections and transactions in pgx can begin/commit
// transactions. For database connections it does the obvious thing; for
@ -42,6 +42,8 @@ type ConnOrTx interface {
var connInfo = pgtype.NewConnInfo()
// Creates a new connection to the HMN database.
// This connection is not safe for concurrent use.
func NewConn() *pgx.Conn {
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
if err != nil {
@ -51,6 +53,8 @@ func NewConn() *pgx.Conn {
return conn
}
// Creates a connection pool for the HMN database.
// The resulting pool is safe for concurrent use.
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
@ -67,154 +71,19 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn
}
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type ResultIterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
closed chan struct{}
}
func (it *ResultIterator[T]) Next() (*T, bool) {
hasNext := it.rows.Next()
if !hasNext {
it.Close()
return nil, false
}
result := reflect.New(it.destType)
vals, err := it.rows.Values()
if err != nil {
panic(err)
}
// Better logging of panics in this confusing reflection process
var currentField reflect.StructField
var currentValue reflect.Value
var currentIdx int
defer func() {
if r := recover(); r != nil {
if currentValue.IsValid() {
logging.Error().
Int("index", currentIdx).
Str("field name", currentField.Name).
Stringer("field type", currentField.Type).
Interface("value", currentValue.Interface()).
Stringer("value type", currentValue.Type()).
Msg("panic in iterator")
}
if currentField.Name != "" {
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
} else {
panic(r)
}
}
}()
for i, val := range vals {
currentIdx = i
if val == nil {
continue
}
var field reflect.Value
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
if field.Kind() == reflect.Ptr {
field.Set(reflect.New(field.Type().Elem()))
field = field.Elem()
}
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
// Regardless, we know it's not nil, so we can get at the contents.
valReflected := reflect.ValueOf(val)
if valReflected.Kind() == reflect.Ptr {
valReflected = valReflected.Elem()
}
currentValue = valReflected
switch field.Kind() {
case reflect.Int:
field.SetInt(valReflected.Int())
default:
field.Set(valReflected)
}
currentField = reflect.StructField{}
currentValue = reflect.Value{}
}
return result.Interface().(*T), true
}
func (it *ResultIterator[any]) Close() {
it.rows.Close()
select {
case it.closed <- struct{}{}:
default:
}
}
func (it *ResultIterator[T]) ToSlice() []*T {
defer it.Close()
var result []*T
for {
row, ok := it.Next()
if !ok {
err := it.rows.Err()
if err != nil {
panic(oops.New(err, "error while iterating through db results"))
}
break
}
result = append(result, row)
}
return result
}
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
if len(path) < 1 {
panic(oops.New(nil, "can't follow an empty path"))
}
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
}
// more informative panic recovery
var field reflect.StructField
defer func() {
if r := recover(); r != nil {
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
}
}()
val := structPtrVal
for _, i := range path {
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
field = val.Type().Field(i)
val = val.Field(i)
}
return val, field
}
/*
Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred.
Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection.
This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar.
*/
func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) ([]*T, error) {
func Query[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) ([]*T, error) {
it, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
@ -223,7 +92,93 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
}
}
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*ResultIterator[T], error) {
/*
Identical to Query, but returns only the first result row. If there are no
rows in the result set, returns NotFound.
*/
func QueryOne[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
return nil, NotFound
}
return result, nil
}
/*
Identical to Query, but returns concrete values instead of pointers. More convenient
for primitive types.
*/
func QueryScalar[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) ([]T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var result []T
for {
val, hasRow := rows.Next()
if !hasRow {
break
}
result = append(result, *val)
}
return result, nil
}
/*
Identical to QueryScalar, but returns only the first result value. If there are
no rows in the result set, returns NotFound.
*/
func QueryOneScalar[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
var zero T
return zero, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
var zero T
return zero, NotFound
}
return *result, nil
}
/*
Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use.
*/
func QueryIterator[T any](
ctx context.Context,
conn ConnOrTx,
query string,
args ...any,
) (*Iterator[T], error) {
var destExample T
destType := reflect.TypeOf(destExample)
@ -237,10 +192,11 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return nil, err
}
it := &ResultIterator[T]{
it := &Iterator[T]{
fieldPaths: compiled.fieldPaths,
rows: rows,
destType: compiled.destType,
destTypeIsScalar: typeIsQueryable(compiled.destType),
closed: make(chan struct{}, 1),
}
@ -261,21 +217,6 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return it, nil
}
func QueryOne[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*T, error) {
rows, err := QueryIterator[T](ctx, conn, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result, hasRow := rows.Next()
if !hasRow {
return nil, NotFound
}
return result, nil
}
// TODO: QueryFunc?
type compiledQuery struct {
@ -295,7 +236,7 @@ func compileQuery(query string, destType reflect.Type) compiledQuery {
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
panic("$columns can only be used when querying into a struct")
}
var prefix []string
@ -436,75 +377,162 @@ func typeIsQueryable(t reflect.Type) bool {
return false
}
// TODO: Delete in favor of `QueryOne`
func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) {
rows, err := conn.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
type columnName []string
if rows.Next() {
vals, err := rows.Values()
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type Iterator[T any] struct {
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function.
closed chan struct{}
}
func (it *Iterator[T]) Next() (*T, bool) {
// TODO(ben): What happens if this panics? Does it leak resources? Do we need
// to put a recover() here and close the rows?
hasNext := it.rows.Next()
if !hasNext {
it.Close()
return nil, false
}
result := reflect.New(it.destType)
vals, err := it.rows.Values()
if err != nil {
panic(err)
}
if it.destTypeIsScalar {
// This type can be directly queried, meaning pgx recognizes it, it's
// a simple scalar thing, and we can just take the easy way out.
if len(vals) != 1 {
return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals))
panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals)))
}
result.Set(reflect.ValueOf(vals[0]))
return result.Interface().(*T), true
} else {
var currentField reflect.StructField
var currentValue reflect.Value
var currentIdx int
// Better logging of panics in this confusing reflection process
defer func() {
if r := recover(); r != nil {
if currentValue.IsValid() {
logging.Error().
Int("index", currentIdx).
Str("field name", currentField.Name).
Stringer("field type", currentField.Type).
Interface("value", currentValue.Interface()).
Stringer("value type", currentValue.Type()).
Msg("panic in iterator")
}
return vals[0], nil
if currentField.Name != "" {
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
} else {
panic(r)
}
}
}()
for i, val := range vals {
currentIdx = i
if val == nil {
continue
}
return nil, NotFound
}
// TODO: Delete in favor of `QueryOne[string]`
func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) {
result, err := QueryScalar(ctx, conn, query, args...)
if err != nil {
return "", err
var field reflect.Value
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
if field.Kind() == reflect.Ptr {
field.Set(reflect.New(field.Type().Elem()))
field = field.Elem()
}
switch r := result.(type) {
case string:
return r, nil
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
// Regardless, we know it's not nil, so we can get at the contents.
valReflected := reflect.ValueOf(val)
if valReflected.Kind() == reflect.Ptr {
valReflected = valReflected.Elem()
}
currentValue = valReflected
switch field.Kind() {
case reflect.Int:
field.SetInt(valReflected.Int())
default:
return "", oops.New(nil, "QueryString got a non-string result: %v", result)
field.Set(valReflected)
}
currentField = reflect.StructField{}
currentValue = reflect.Value{}
}
return result.Interface().(*T), true
}
}
// TODO: Delete in favor of `QueryOne[int]`
func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) {
result, err := QueryScalar(ctx, conn, query, args...)
if err != nil {
return 0, err
}
switch r := result.(type) {
case int:
return r, nil
case int32:
return int(r), nil
case int64:
return int(r), nil
func (it *Iterator[any]) Close() {
it.rows.Close()
select {
case it.closed <- struct{}{}:
default:
return 0, oops.New(nil, "QueryInt got a non-int result: %v", result)
}
}
// TODO: Delete in favor of `QueryOne[bool]`
func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) {
result, err := QueryScalar(ctx, conn, query, args...)
/*
Pulls all the remaining values into a slice, and closes the iterator.
*/
func (it *Iterator[T]) ToSlice() []*T {
defer it.Close()
var result []*T
for {
row, ok := it.Next()
if !ok {
err := it.rows.Err()
if err != nil {
return false, err
panic(oops.New(err, "error while iterating through db results"))
}
break
}
result = append(result, row)
}
return result
}
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
if len(path) < 1 {
panic(oops.New(nil, "can't follow an empty path"))
}
switch r := result.(type) {
case bool:
return r, nil
default:
return false, oops.New(nil, "QueryBool got a non-bool result: %v", result)
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
}
// more informative panic recovery
var field reflect.StructField
defer func() {
if r := recover(); r != nil {
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
}
}()
val := structPtrVal
for _, i := range path {
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
field = val.Type().Field(i)
val = val.Field(i)
}
return val, field
}

View File

@ -7,7 +7,7 @@ import (
type QueryBuilder struct {
sql strings.Builder
args []interface{}
args []any
}
/*
@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number.
foo ARG1 bar ARG2 baz $?
foo ARG1 bar ARG2 baz ARG3
*/
func (qb *QueryBuilder) Add(sql string, args ...interface{}) {
func (qb *QueryBuilder) Add(sql string, args ...any) {
numPlaceholders := strings.Count(sql, "$?")
if numPlaceholders != len(args) {
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string {
return qb.sql.String()
}
func (qb *QueryBuilder) Args() []interface{} {
func (qb *QueryBuilder) Args() []any {
return qb.args
}

View File

@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error {
// Port "jam snippets" to use a tag
//
jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
if err != nil {
return oops.New(err, "failed to create jam tag")
}

View File

@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum {
}
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
type subforumRow struct {
Subforum Subforum `db:"sf"`
}
rowsSlice, err := db.Query(ctx, conn, subforumRow{},
subforums, err := db.Query[Subforum](ctx, conn,
`
SELECT $columns
FROM
handmade_subforum as sf
FROM handmade_subforum
ORDER BY sort, id ASC
`,
)
@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
panic(oops.New(err, "failed to fetch subforum tree"))
}
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice))
for _, row := range rowsSlice {
sf := row.(*subforumRow).Subforum
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums))
for _, sf := range subforums {
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf}
}
for _, node := range sfTreeMap {
@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
}
}
for _, row := range rowsSlice {
for _, cat := range subforums {
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
cat := row.(*subforumRow).Subforum
node := sfTreeMap[cat.ID]
if node.Parent != nil {
node.Parent.Children = append(node.Parent.Children, node)