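// Package db provides helpers for connecting to Postgres through pgx and for
// running queries whose result rows are mapped onto structs via `db` field tags.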
package db
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.handmade.network/hmn/hmn/src/config"
|
|
"git.handmade.network/hmn/hmn/src/oops"
|
|
"github.com/jackc/pgtype"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/log/zerologadapter"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// connInfo is a pgtype connection-info registry created once at package load.
// getColumnNamesAndPaths consults it to decide whether a field type is
// something pgtype already knows about.
var connInfo = pgtype.NewConnInfo()
|
|
|
|
func NewConn() *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
|
if err != nil {
|
|
panic(oops.New(err, "failed to connect to database"))
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
|
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
|
|
|
cfg.MinConns = minConns
|
|
cfg.MaxConns = maxConns
|
|
cfg.ConnConfig.Logger = zerologadapter.NewLogger(log.Logger)
|
|
cfg.ConnConfig.LogLevel = config.Config.Postgres.LogLevel
|
|
|
|
conn, err := pgxpool.ConnectConfig(context.Background(), cfg)
|
|
if err != nil {
|
|
panic(oops.New(err, "failed to create database connection pool"))
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
type StructQueryIterator struct {
|
|
fieldPaths [][]int
|
|
rows pgx.Rows
|
|
destType reflect.Type
|
|
}
|
|
|
|
func (it *StructQueryIterator) Next() (interface{}, bool) {
|
|
hasNext := it.rows.Next()
|
|
if !hasNext {
|
|
return nil, false
|
|
}
|
|
|
|
result := reflect.New(it.destType)
|
|
|
|
vals, err := it.rows.Values()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for i, val := range vals {
|
|
if val == nil {
|
|
continue
|
|
}
|
|
|
|
field := followPathThroughStructs(result, it.fieldPaths[i])
|
|
if field.Kind() == reflect.Ptr {
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field = field.Elem()
|
|
}
|
|
|
|
switch field.Kind() {
|
|
case reflect.Int:
|
|
field.SetInt(reflect.ValueOf(val).Int())
|
|
default:
|
|
field.Set(reflect.ValueOf(val))
|
|
}
|
|
}
|
|
|
|
return result.Interface(), true
|
|
}
|
|
|
|
func (it *StructQueryIterator) Close() {
|
|
it.rows.Close()
|
|
}
|
|
|
|
func (it *StructQueryIterator) ToSlice() []interface{} {
|
|
defer it.Close()
|
|
var result []interface{}
|
|
for {
|
|
row, ok := it.Next()
|
|
if !ok {
|
|
break
|
|
}
|
|
result = append(result, row)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// followPathThroughStructs walks from structVal through nested struct fields,
// using each element of path as a field index, and returns the field reached
// at the end of the path.
//
// Whenever the current value is a pointer to a struct it is dereferenced
// first; a nil pointer is replaced by a freshly allocated struct so the walk
// can continue. That requires the pointer to be settable.
//
// It panics if path is empty.
func followPathThroughStructs(structVal reflect.Value, path []int) reflect.Value {
	if len(path) < 1 {
		panic("can't follow an empty path")
	}

	val := structVal
	for _, i := range path {
		if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
			if val.IsNil() {
				// reflect.New(T) yields a *T, so it must be given the pointee
				// type; passing the pointer type itself would produce a **T
				// that cannot be assigned to val.
				val.Set(reflect.New(val.Type().Elem()))
			}
			val = val.Elem()
		}
		val = val.Field(i)
	}

	return val
}
|
|
|
|
func Query(ctx context.Context, conn *pgxpool.Pool, destExample interface{}, query string, args ...interface{}) (StructQueryIterator, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
destType := reflect.TypeOf(destExample)
|
|
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, "")
|
|
if err != nil {
|
|
return StructQueryIterator{}, oops.New(err, "failed to generate column names")
|
|
}
|
|
|
|
columnNamesString := strings.Join(columnNames, ", ")
|
|
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
|
|
|
rows, err := conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
panic("query exceeded its deadline")
|
|
}
|
|
return StructQueryIterator{}, err
|
|
}
|
|
|
|
return StructQueryIterator{
|
|
fieldPaths: fieldPaths,
|
|
rows: rows,
|
|
destType: destType,
|
|
}, nil
|
|
}
|
|
|
|
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix string) ([]string, [][]int, error) {
|
|
var columnNames []string
|
|
var fieldPaths [][]int
|
|
|
|
if destType.Kind() == reflect.Ptr {
|
|
destType = destType.Elem()
|
|
}
|
|
|
|
if destType.Kind() != reflect.Struct {
|
|
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
|
|
}
|
|
|
|
for i := 0; i < destType.NumField(); i++ {
|
|
field := destType.Field(i)
|
|
path := append(pathSoFar, i)
|
|
|
|
if columnName := field.Tag.Get("db"); columnName != "" {
|
|
fieldType := field.Type
|
|
if destType.Kind() == reflect.Ptr {
|
|
fieldType = destType.Elem()
|
|
}
|
|
|
|
_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(fieldType)) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
|
|
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
|
|
|
|
if fieldType.Kind() == reflect.Struct && !isRecognizedByPgtype {
|
|
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, columnName+".")
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
columnNames = append(columnNames, subCols...)
|
|
fieldPaths = append(fieldPaths, subPaths...)
|
|
} else {
|
|
columnNames = append(columnNames, prefix+columnName)
|
|
fieldPaths = append(fieldPaths, path)
|
|
}
|
|
}
|
|
}
|
|
|
|
return columnNames, fieldPaths, nil
|
|
}
|
|
|
|
// ErrNoMatchingRows is returned by QueryOne when the query produces no row.
var ErrNoMatchingRows = errors.New("no matching rows")
|
|
|
|
func QueryOne(ctx context.Context, conn *pgxpool.Pool, destExample interface{}, query string, args ...interface{}) (interface{}, error) {
|
|
rows, err := Query(ctx, conn, destExample, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
result, hasRow := rows.Next()
|
|
if !hasRow {
|
|
return nil, ErrNoMatchingRows
|
|
}
|
|
|
|
return result, nil
|
|
}
|