hmn/src/db/db.go

140 lines
3.3 KiB
Go

package db
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/oops"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/log/zerologadapter"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog/log"
)
// NewConn opens a single, standalone connection to the Postgres database
// described by config.Config.Postgres.DSN(). It panics if the connection
// cannot be established.
func NewConn() *pgx.Conn {
	ctx := context.Background()
	dsn := config.Config.Postgres.DSN()

	conn, err := pgx.Connect(ctx, dsn)
	if err == nil {
		return conn
	}
	panic(oops.New(err, "failed to connect to database"))
}
// NewConnPool creates a connection pool for the Postgres database described by
// config.Config.Postgres.DSN(). The pool keeps between minConns and maxConns
// connections, and pgx's logging is routed through zerolog at
// config.Config.Postgres.LogLevel.
//
// It panics if the DSN cannot be parsed or the pool cannot be created.
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
	cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
	if err != nil {
		// Without this check a bad DSN leaves cfg nil and the assignments
		// below crash with a nil-pointer dereference, hiding the real cause.
		panic(oops.New(err, "failed to parse database connection config"))
	}
	cfg.MinConns = minConns
	cfg.MaxConns = maxConns
	cfg.ConnConfig.Logger = zerologadapter.NewLogger(log.Logger)
	cfg.ConnConfig.LogLevel = config.Config.Postgres.LogLevel

	conn, err := pgxpool.ConnectConfig(context.Background(), cfg)
	if err != nil {
		panic(oops.New(err, "failed to create database connection pool"))
	}
	return conn
}
// StructQueryIterator walks the rows of a query started by QueryToStructs and
// copies each row into a caller-supplied struct. Call Close when finished.
type StructQueryIterator struct {
	fieldIndices []int // struct field index receiving each result column, in column order
	rows         pgx.Rows
}

// Next advances to the next row and copies its column values into dest, which
// must be a pointer to a struct of the type given to QueryToStructs. Column i
// is stored in the i-th `db`-tagged field. Next returns false once the
// underlying rows report no further row; any error from the rows is not
// returned here.
//
// Next panics if dest is not a pointer, if the row values cannot be read, if
// the row has more columns than there are `db`-tagged fields, or if a value
// cannot be stored in its field.
func (it *StructQueryIterator) Next(dest interface{}) bool {
	if !it.rows.Next() {
		return false
	}

	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Ptr {
		panic(oops.New(nil, "Next requires a pointer type; got %v", v.Kind()))
	}

	vals, err := it.rows.Values()
	if err != nil {
		panic(err)
	}
	if len(vals) > len(it.fieldIndices) {
		// Would otherwise surface as an opaque index-out-of-range panic below.
		panic(oops.New(
			fmt.Errorf("got %d columns, have %d db-tagged fields", len(vals), len(it.fieldIndices)),
			"query returned more columns than the destination struct can hold",
		))
	}

	structVal := v.Elem()
	for i, val := range vals {
		assignQueryValue(structVal.Field(it.fieldIndices[i]), val)
	}
	return true
}

// Close closes the underlying result rows. Call it once iteration is done.
func (it *StructQueryIterator) Close() {
	it.rows.Close()
}

// assignQueryValue stores one column value, as returned by pgx.Rows.Values,
// into field:
//   - Signed integer fields of any width accept any signed integer value (e.g.
//     an int32 value can fill an int field); values that do not fit panic.
//   - Pointer fields are set to nil for a nil (NULL) value; otherwise a new
//     element is allocated and filled recursively, so *int behaves like int.
//   - Any other field is set directly, so the value's type must be assignable
//     to the field's type.
func assignQueryValue(field reflect.Value, val interface{}) {
	switch field.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n := reflect.ValueOf(val).Int()
		if field.OverflowInt(n) {
			panic(oops.New(nil, "value %d overflows field of type %s", n, field.Type()))
		}
		field.SetInt(n)
	case reflect.Ptr:
		if val == nil {
			// Reset explicitly so a dest reused across Next calls does not keep
			// the previous row's value for a NULL column.
			field.Set(reflect.Zero(field.Type()))
			return
		}
		elem := reflect.New(field.Type().Elem())
		assignQueryValue(elem.Elem(), val)
		field.Set(elem)
	default:
		field.Set(reflect.ValueOf(val))
	}
}
// QueryToStructs runs query on conn and returns an iterator that decodes each
// result row into a struct.
//
// destType is a struct or a pointer to a struct; only its type is inspected.
// Every field carrying a non-empty `db:"..."` tag maps to one column, in field
// order. Every occurrence of "$columns" in query is replaced with those column
// names joined by ", ", so the SELECT list lines up with the tagged fields.
// args are passed through as query parameters. The returned iterator should be
// closed with Close.
//
// An error is returned if destType is not a struct (or pointer to one), or if
// conn.Query fails.
func QueryToStructs(ctx context.Context, conn *pgxpool.Pool, destType interface{}, query string, args ...interface{}) (StructQueryIterator, error) {
	t := reflect.TypeOf(destType)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	// reflect.TypeOf(nil) is nil; guard it so a nil destType yields an error
	// instead of a nil-pointer panic.
	if t == nil || t.Kind() != reflect.Struct {
		return StructQueryIterator{}, oops.New(nil, "QueryToStructs requires a struct type or a pointer to a struct type")
	}

	fieldIndices := make([]int, 0, t.NumField())
	columnNames := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		if columnName := t.Field(i).Tag.Get("db"); columnName != "" {
			fieldIndices = append(fieldIndices, i)
			columnNames = append(columnNames, columnName)
		}
	}

	query = strings.ReplaceAll(query, "$columns", strings.Join(columnNames, ", "))
	rows, err := conn.Query(ctx, query, args...)
	if err != nil {
		return StructQueryIterator{}, err
	}
	return StructQueryIterator{
		fieldIndices: fieldIndices,
		rows:         rows,
	}, nil
}
// ErrNoMatchingRows is returned by QueryOneToStruct when the query succeeds
// but produces no rows.
var ErrNoMatchingRows = errors.New("no matching rows")

// QueryOneToStruct runs query via QueryToStructs (including "$columns"
// expansion) and decodes the first result row into dest, which must be a
// pointer to a struct. Any additional rows are ignored.
//
// It returns ErrNoMatchingRows if the query produced no rows, or the query's
// error if it failed.
func QueryOneToStruct(ctx context.Context, conn *pgxpool.Pool, dest interface{}, query string, args ...interface{}) error {
	rows, err := QueryToStructs(ctx, conn, dest, query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if !rows.Next(dest) {
		// pgx can defer reporting a failed query until iteration stops, so an
		// empty result is checked against the rows' error before it is
		// reported as "no matching rows".
		if rowsErr := rows.rows.Err(); rowsErr != nil {
			return rowsErr
		}
		return ErrNoMatchingRows
	}
	return nil
}