package db

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strings"

	"git.handmade.network/hmn/hmn/src/config"
	"git.handmade.network/hmn/hmn/src/logging"
	"git.handmade.network/hmn/hmn/src/oops"
	"github.com/google/uuid"
	"github.com/jackc/pgconn"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/log/zerologadapter"
	"github.com/jackc/pgx/v4/pgxpool"
	"github.com/rs/zerolog/log"
)

// queryableKinds lists reflect kinds that may be queried even when pgtype does
// not directly recognize the Go type. This is common for custom named types
// such as:
//
//	type ThreadType int
var queryableKinds = []reflect.Kind{
	reflect.Int,
}

// typeIsQueryable reports whether a value of type t can be read directly from a
// single result column.
//
// It applies only to primitive (leaf) types, not structs. The database returns
// individual primitive values, and stitching those back together into structs is
// done separately.
func typeIsQueryable(t reflect.Type) bool {
	// NOTE: it would be nice to avoid reflect.New here, since pgtype is just
	// doing reflection on the value anyway.
	zeroValue := reflect.New(t).Elem().Interface()
	if _, known := connInfo.DataTypeForValue(zeroValue); known {
		// pgtype handles this type itself, so there is no need to dig into it
		// looking for further `db` tags.
		return true
	}

	if t == reflect.TypeOf(uuid.UUID{}) {
		return true
	}

	// pgtype doesn't recognize the type, but its underlying kind may still be
	// one we know how to convert.
	kind := t.Kind()
	for i := range queryableKinds {
		if queryableKinds[i] == kind {
			return true
		}
	}
	return false
}

// ConnOrTx abstracts over a direct pgx connection and a pgx transaction: anything
// offering these Query/QueryRow/Exec methods can be passed to the query helpers.
type ConnOrTx interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) } var connInfo = pgtype.NewConnInfo() func NewConn() *pgx.Conn { conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN()) if err != nil { panic(oops.New(err, "failed to connect to database")) } return conn } func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN()) cfg.MinConns = minConns cfg.MaxConns = maxConns cfg.ConnConfig.Logger = zerologadapter.NewLogger(log.Logger) cfg.ConnConfig.LogLevel = config.Config.Postgres.LogLevel conn, err := pgxpool.ConnectConfig(context.Background(), cfg) if err != nil { panic(oops.New(err, "failed to create database connection pool")) } return conn } type StructQueryIterator struct { fieldPaths [][]int rows pgx.Rows destType reflect.Type } func (it *StructQueryIterator) Next() (interface{}, bool) { hasNext := it.rows.Next() if !hasNext { return nil, false } result := reflect.New(it.destType) vals, err := it.rows.Values() if err != nil { panic(err) } // Better logging of panics in this confusing reflection process var currentField reflect.StructField var currentValue reflect.Value var currentIdx int defer func() { if r := recover(); r != nil { if currentValue.IsValid() { logging.Error(). Int("index", currentIdx). Str("field name", currentField.Name). Stringer("field type", currentField.Type). Interface("value", currentValue.Interface()). Stringer("value type", currentValue.Type()). 
Msg("panic in iterator") } if currentField.Name != "" { panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r)) } else { panic(r) } } }() for i, val := range vals { currentIdx = i if val == nil { continue } var field reflect.Value field, currentField = followPathThroughStructs(result, it.fieldPaths[i]) if field.Kind() == reflect.Ptr { field.Set(reflect.New(field.Type().Elem())) field = field.Elem() } // Some actual values still come through as pointers (like net.IPNet). Dunno why. // Regardless, we know it's not nil, so we can get at the contents. valReflected := reflect.ValueOf(val) if valReflected.Kind() == reflect.Ptr { valReflected = valReflected.Elem() } currentValue = valReflected switch field.Kind() { case reflect.Int: field.SetInt(valReflected.Int()) default: field.Set(valReflected) } currentField = reflect.StructField{} currentValue = reflect.Value{} } return result.Interface(), true } func (it *StructQueryIterator) Close() { it.rows.Close() } func (it *StructQueryIterator) ToSlice() []interface{} { defer it.Close() var result []interface{} for { row, ok := it.Next() if !ok { err := it.rows.Err() if err != nil { panic(oops.New(err, "error while iterating through db results")) } break } result = append(result, row) } return result } func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) { if len(path) < 1 { panic(oops.New(nil, "can't follow an empty path")) } if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct { panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type())) } // more informative panic recovery var field reflect.StructField defer func() { if r := recover(); r != nil { panic(oops.New(nil, "panic at field '%s': %v", field.Name, r)) } }() val := structPtrVal for _, i := range path { if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct { if val.IsNil() { 
val.Set(reflect.New(val.Type().Elem())) } val = val.Elem() } field = val.Type().Field(i) val = val.Field(i) } return val, field } func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) { destType := reflect.TypeOf(destExample) columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, "") if err != nil { return nil, oops.New(err, "failed to generate column names") } columnNamesString := strings.Join(columnNames, ", ") query = strings.Replace(query, "$columns", columnNamesString, -1) rows, err := conn.Query(ctx, query, args...) if err != nil { if errors.Is(err, context.DeadlineExceeded) { panic("query exceeded its deadline") } return nil, err } it := &StructQueryIterator{ fieldPaths: fieldPaths, rows: rows, destType: destType, } // Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold // open connections even after a request is cancelled, causing the app to deadlock. 
go func() { done := ctx.Done() if done == nil { return } <-done it.Close() }() return it, nil } func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix string) (names []string, paths [][]int, err error) { var columnNames []string var fieldPaths [][]int if destType.Kind() == reflect.Ptr { destType = destType.Elem() } if destType.Kind() != reflect.Struct { return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) } for i := 0; i < destType.NumField(); i++ { field := destType.Field(i) path := append(pathSoFar, i) if columnName := field.Tag.Get("db"); columnName != "" { fieldType := field.Type if fieldType.Kind() == reflect.Ptr { fieldType = fieldType.Elem() } if typeIsQueryable(fieldType) { columnNames = append(columnNames, prefix+columnName) fieldPaths = append(fieldPaths, path) } else if fieldType.Kind() == reflect.Struct { subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, columnName+".") if err != nil { return nil, nil, err } columnNames = append(columnNames, subCols...) fieldPaths = append(fieldPaths, subPaths...) } else { return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) } } } return columnNames, fieldPaths, nil } var ErrNoMatchingRows = errors.New("no matching rows") func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) { rows, err := Query(ctx, conn, destExample, query, args...) if err != nil { return nil, err } defer rows.Close() result, hasRow := rows.Next() if !hasRow { return nil, ErrNoMatchingRows } return result, nil } func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) { rows, err := conn.Query(ctx, query, args...) 
if err != nil { return nil, err } defer rows.Close() if rows.Next() { vals, err := rows.Values() if err != nil { panic(err) } if len(vals) != 1 { return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals)) } return vals[0], nil } return nil, ErrNoMatchingRows } func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) { result, err := QueryScalar(ctx, conn, query, args...) if err != nil { return "", err } switch r := result.(type) { case string: return r, nil default: return "", oops.New(nil, "QueryString got a non-string result: %v", result) } } func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) { result, err := QueryScalar(ctx, conn, query, args...) if err != nil { return 0, err } switch r := result.(type) { case int: return r, nil case int32: return int(r), nil case int64: return int(r), nil default: return 0, oops.New(nil, "QueryInt got a non-int result: %v", result) } } func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) { result, err := QueryScalar(ctx, conn, query, args...) if err != nil { return false, err } switch r := result.(type) { case bool: return r, nil default: return false, oops.New(nil, "QueryBool got a non-bool result: %v", result) } }