140 lines
3.3 KiB
Go
140 lines
3.3 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"git.handmade.network/hmn/hmn/src/config"
|
|
"git.handmade.network/hmn/hmn/src/oops"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/log/zerologadapter"
|
|
"github.com/jackc/pgx/v4/pgxpool"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
func NewConn() *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
|
if err != nil {
|
|
panic(oops.New(err, "failed to connect to database"))
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
|
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
|
|
|
cfg.MinConns = minConns
|
|
cfg.MaxConns = maxConns
|
|
cfg.ConnConfig.Logger = zerologadapter.NewLogger(log.Logger)
|
|
cfg.ConnConfig.LogLevel = config.Config.Postgres.LogLevel
|
|
|
|
conn, err := pgxpool.ConnectConfig(context.Background(), cfg)
|
|
if err != nil {
|
|
panic(oops.New(err, "failed to create database connection pool"))
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
type StructQueryIterator struct {
|
|
fieldIndices []int
|
|
rows pgx.Rows
|
|
}
|
|
|
|
func (it *StructQueryIterator) Next(dest interface{}) bool {
|
|
hasNext := it.rows.Next()
|
|
if !hasNext {
|
|
return false
|
|
}
|
|
|
|
v := reflect.ValueOf(dest)
|
|
if v.Kind() != reflect.Ptr {
|
|
panic(oops.New(nil, "Next requires a pointer type; got %v", v.Kind()))
|
|
}
|
|
|
|
vals, err := it.rows.Values()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
fmt.Printf("%#v\n", vals)
|
|
|
|
for i, val := range vals {
|
|
field := v.Elem().Field(it.fieldIndices[i])
|
|
switch field.Kind() {
|
|
case reflect.Int:
|
|
field.SetInt(reflect.ValueOf(val).Int())
|
|
case reflect.Ptr:
|
|
// TODO: I'm pretty sure we don't handle nullable ints correctly lol. Maybe this needs to be a function somehow, and recurse onto itself?? Reflection + recursion sounds like a great idea
|
|
if val != nil {
|
|
field.Set(reflect.New(field.Type().Elem()))
|
|
field.Elem().Set(reflect.ValueOf(val))
|
|
}
|
|
default:
|
|
field.Set(reflect.ValueOf(val))
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (it *StructQueryIterator) Close() {
|
|
it.rows.Close()
|
|
}
|
|
|
|
func QueryToStructs(ctx context.Context, conn *pgxpool.Pool, destType interface{}, query string, args ...interface{}) (StructQueryIterator, error) {
|
|
var fieldIndices []int
|
|
var columnNames []string
|
|
|
|
t := reflect.TypeOf(destType)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
|
|
if t.Kind() != reflect.Struct {
|
|
return StructQueryIterator{}, oops.New(nil, "QueryToStructs requires a struct type or a pointer to a struct type")
|
|
}
|
|
|
|
for i := 0; i < t.NumField(); i++ {
|
|
f := t.Field(i)
|
|
if columnName := f.Tag.Get("db"); columnName != "" {
|
|
fieldIndices = append(fieldIndices, i)
|
|
columnNames = append(columnNames, columnName)
|
|
}
|
|
}
|
|
|
|
columnNamesString := strings.Join(columnNames, ", ")
|
|
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
|
|
|
rows, err := conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return StructQueryIterator{}, err
|
|
}
|
|
|
|
return StructQueryIterator{
|
|
fieldIndices: fieldIndices,
|
|
rows: rows,
|
|
}, nil
|
|
}
|
|
|
|
// ErrNoMatchingRows is returned by QueryOneToStruct when the query produces
// no row to copy into the destination struct.
var ErrNoMatchingRows = errors.New("no matching rows")
|
|
|
|
func QueryOneToStruct(ctx context.Context, conn *pgxpool.Pool, dest interface{}, query string, args ...interface{}) error {
|
|
rows, err := QueryToStructs(ctx, conn, dest, query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
hasRow := rows.Next(dest)
|
|
if !hasRow {
|
|
return ErrNoMatchingRows
|
|
}
|
|
|
|
return nil
|
|
}
|