2021-03-09 08:05:07 +00:00
package db
import (
"context"
2021-03-21 20:38:37 +00:00
"errors"
2021-05-03 14:51:07 +00:00
"fmt"
2021-03-21 20:38:37 +00:00
"reflect"
"strings"
2021-03-09 08:05:07 +00:00
2021-03-11 03:39:24 +00:00
"git.handmade.network/hmn/hmn/src/config"
2021-05-03 14:51:07 +00:00
"git.handmade.network/hmn/hmn/src/logging"
2021-03-11 03:39:24 +00:00
"git.handmade.network/hmn/hmn/src/oops"
2021-06-22 09:50:40 +00:00
"github.com/google/uuid"
2021-03-31 03:55:19 +00:00
"github.com/jackc/pgtype"
2021-03-09 08:05:07 +00:00
"github.com/jackc/pgx/v4"
2021-03-21 20:38:37 +00:00
"github.com/jackc/pgx/v4/log/zerologadapter"
2021-03-09 08:05:07 +00:00
"github.com/jackc/pgx/v4/pgxpool"
2021-03-21 20:38:37 +00:00
"github.com/rs/zerolog/log"
2021-03-09 08:05:07 +00:00
)
2021-05-03 14:51:07 +00:00
// queryableKinds lists the reflect.Kinds whose values may be queried even when
// pgtype does not directly recognize the Go type. This is what lets custom
// named types built on a primitive be read from the database, e.g.:
//
//	type CategoryKind int
var queryableKinds = []reflect.Kind{reflect.Int}
/ *
Checks if we are able to handle a particular type in a database query . This applies only to
primitive types and not structs , since the database only returns individual primitive types
and it is our job to stitch them back together into structs later .
* /
func typeIsQueryable ( t reflect . Type ) bool {
_ , isRecognizedByPgtype := connInfo . DataTypeForValue ( reflect . New ( t ) . Elem ( ) . Interface ( ) ) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
if isRecognizedByPgtype {
return true
2021-06-22 09:50:40 +00:00
} else if t == reflect . TypeOf ( uuid . UUID { } ) {
return true
2021-05-03 14:51:07 +00:00
}
// pgtype doesn't recognize it, but maybe it's a primitive type we can deal with
k := t . Kind ( )
for _ , qk := range queryableKinds {
if k == qk {
return true
}
}
return false
}
2021-07-22 04:42:34 +00:00
// ConnOrTx is the query capability shared by a direct pgx connection and a pgx
// transaction, so the query helpers in this package can run against either one.
type ConnOrTx interface {
	Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error)
}
2021-03-31 03:55:19 +00:00
// connInfo is the pgtype type registry that typeIsQueryable consults to decide
// whether a Go type maps directly onto a database column.
var connInfo *pgtype.ConnInfo = pgtype.NewConnInfo()
2021-03-09 08:05:07 +00:00
func NewConn ( ) * pgx . Conn {
conn , err := pgx . Connect ( context . Background ( ) , config . Config . Postgres . DSN ( ) )
if err != nil {
2021-03-11 03:19:39 +00:00
panic ( oops . New ( err , "failed to connect to database" ) )
2021-03-09 08:05:07 +00:00
}
return conn
}
func NewConnPool ( minConns , maxConns int32 ) * pgxpool . Pool {
2021-03-21 20:38:37 +00:00
cfg , err := pgxpool . ParseConfig ( config . Config . Postgres . DSN ( ) )
2021-03-09 08:05:07 +00:00
2021-03-21 20:38:37 +00:00
cfg . MinConns = minConns
cfg . MaxConns = maxConns
cfg . ConnConfig . Logger = zerologadapter . NewLogger ( log . Logger )
cfg . ConnConfig . LogLevel = config . Config . Postgres . LogLevel
2021-03-09 08:05:07 +00:00
2021-03-21 20:38:37 +00:00
conn , err := pgxpool . ConnectConfig ( context . Background ( ) , cfg )
2021-03-09 08:05:07 +00:00
if err != nil {
2021-03-11 03:19:39 +00:00
panic ( oops . New ( err , "failed to create database connection pool" ) )
2021-03-09 08:05:07 +00:00
}
return conn
}
2021-03-21 20:38:37 +00:00
type StructQueryIterator struct {
2021-03-31 03:55:19 +00:00
fieldPaths [ ] [ ] int
rows pgx . Rows
destType reflect . Type
2021-03-21 20:38:37 +00:00
}
2021-03-31 03:55:19 +00:00
// Next advances to the next row and decodes it into a newly allocated value of
// it.destType, returning a pointer to that value (as interface{}) and true. When
// there are no more rows it returns nil, false. A false return can also mean the
// query failed, which is why ToSlice checks it.rows.Err() afterward.
//
// It panics if the row's values cannot be read or a column value cannot be stored
// into its destination field. Panics raised while filling a field are logged and
// re-panicked, annotated with that field's name.
func (it *StructQueryIterator) Next() (interface{}, bool) {
	hasNext := it.rows.Next()
	if !hasNext {
		return nil, false
	}

	// result is a pointer to a new zero value of destType; fields are filled in below.
	result := reflect.New(it.destType)

	vals, err := it.rows.Values()
	if err != nil {
		panic(err)
	}

	// Better logging of panics in this confusing reflection process
	// These track the column currently being processed so the deferred handler
	// can report which index/field/value was involved when a panic occurred.
	var currentField reflect.StructField
	var currentValue reflect.Value
	var currentIdx int
	defer func() {
		if r := recover(); r != nil {
			if currentValue.IsValid() {
				logging.Error().
					Int("index", currentIdx).
					Str("field name", currentField.Name).
					Stringer("field type", currentField.Type).
					Interface("value", currentValue.Interface()).
					Stringer("value type", currentValue.Type()).
					Msg("panic in iterator")
			}
			// Re-panic, annotated with the field name when one is known.
			if currentField.Name != "" {
				panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
			} else {
				panic(r)
			}
		}
	}()

	for i, val := range vals {
		currentIdx = i
		// A nil value (presumably SQL NULL) leaves the destination at its zero value.
		if val == nil {
			continue
		}

		var field reflect.Value
		field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
		// Pointer fields get a newly allocated target, and the value is written through it.
		if field.Kind() == reflect.Ptr {
			field.Set(reflect.New(field.Type().Elem()))
			field = field.Elem()
		}

		// Some actual values still come through as pointers (like net.IPNet). Dunno why.
		// Regardless, we know it's not nil, so we can get at the contents.
		valReflected := reflect.ValueOf(val)
		if valReflected.Kind() == reflect.Ptr {
			valReflected = valReflected.Elem()
		}
		currentValue = valReflected

		switch field.Kind() {
		case reflect.Int:
			// SetInt accepts any signed integer value, so an int field (including a
			// named int type listed via queryableKinds) can take int16/int32/int64.
			field.SetInt(valReflected.Int())
		default:
			// reflect.Value.Set requires the value's type to be assignable to the field's type.
			field.Set(valReflected)
		}

		// Clear the tracking state so a later panic isn't attributed to this field.
		currentField = reflect.StructField{}
		currentValue = reflect.Value{}
	}

	return result.Interface(), true
}
// Close releases the iterator's underlying pgx rows. In this file it is invoked
// by ToSlice and QueryOne via defer, and by Query's watcher goroutine once the
// query's context is done.
func (it *StructQueryIterator) Close() {
	rows := it.rows
	rows.Close()
}
2021-03-31 03:55:19 +00:00
func ( it * StructQueryIterator ) ToSlice ( ) [ ] interface { } {
defer it . Close ( )
var result [ ] interface { }
for {
row , ok := it . Next ( )
if ! ok {
2021-04-11 21:46:06 +00:00
err := it . rows . Err ( )
if err != nil {
panic ( oops . New ( err , "error while iterating through db results" ) )
}
2021-03-31 03:55:19 +00:00
break
}
result = append ( result , row )
2021-03-22 03:07:18 +00:00
}
2021-03-31 03:55:19 +00:00
return result
}
2021-03-22 03:07:18 +00:00
2021-05-03 14:51:07 +00:00
// followPathThroughStructs starts at the struct that structPtrVal points to and
// follows path, a chain of field indices as accepted by reflect.Value.Field. It
// returns the final field's reflect.Value together with its reflect.StructField.
//
// Pointer-to-struct values encountered along the way are dereferenced. Nil ones
// are first set to a newly allocated struct so the walk can continue.
//
// It panics if path is empty or structPtrVal is not a pointer to a struct. Those
// panics happen before the recovery handler below is installed. Any panic during
// the walk itself is re-panicked with the name of the last field resolved.
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
	if len(path) < 1 {
		panic(oops.New(nil, "can't follow an empty path"))
	}
	if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
		panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
	}

	// more informative panic recovery
	// `field` holds the most recently resolved field, so the message names the last
	// field successfully reached before the failure.
	var field reflect.StructField
	defer func() {
		if r := recover(); r != nil {
			panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
		}
	}()

	val := structPtrVal
	for _, i := range path {
		// Step through a pointer to a struct, allocating it first if it is nil.
		if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
			if val.IsNil() {
				val.Set(reflect.New(val.Type().Elem()))
			}
			val = val.Elem()
		}
		field = val.Type().Field(i)
		val = val.Field(i)
	}
	return val, field
}
2021-07-22 04:42:34 +00:00
func Query ( ctx context . Context , conn ConnOrTx , destExample interface { } , query string , args ... interface { } ) ( * StructQueryIterator , error ) {
2021-03-31 03:55:19 +00:00
destType := reflect . TypeOf ( destExample )
columnNames , fieldPaths , err := getColumnNamesAndPaths ( destType , nil , "" )
if err != nil {
2021-04-11 21:46:06 +00:00
return nil , oops . New ( err , "failed to generate column names" )
2021-03-21 20:38:37 +00:00
}
columnNamesString := strings . Join ( columnNames , ", " )
query = strings . Replace ( query , "$columns" , columnNamesString , - 1 )
rows , err := conn . Query ( ctx , query , args ... )
if err != nil {
2021-03-31 03:55:19 +00:00
if errors . Is ( err , context . DeadlineExceeded ) {
panic ( "query exceeded its deadline" )
}
2021-04-11 21:46:06 +00:00
return nil , err
2021-03-21 20:38:37 +00:00
}
2021-07-23 16:33:53 +00:00
it := & StructQueryIterator {
2021-03-31 03:55:19 +00:00
fieldPaths : fieldPaths ,
rows : rows ,
destType : destType ,
2021-07-23 16:33:53 +00:00
}
// Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold
// open connections even after a request is cancelled, causing the app to deadlock.
go func ( ) {
done := ctx . Done ( )
if done == nil {
return
}
<- done
it . Close ( )
} ( )
return it , nil
2021-03-21 20:38:37 +00:00
}
2021-05-03 14:51:07 +00:00
func getColumnNamesAndPaths ( destType reflect . Type , pathSoFar [ ] int , prefix string ) ( names [ ] string , paths [ ] [ ] int , err error ) {
2021-03-31 03:55:19 +00:00
var columnNames [ ] string
var fieldPaths [ ] [ ] int
if destType . Kind ( ) == reflect . Ptr {
destType = destType . Elem ( )
}
if destType . Kind ( ) != reflect . Struct {
return nil , nil , oops . New ( nil , "can only get column names and paths from a struct, got type '%v' (at prefix '%v')" , destType . Name ( ) , prefix )
}
for i := 0 ; i < destType . NumField ( ) ; i ++ {
field := destType . Field ( i )
path := append ( pathSoFar , i )
if columnName := field . Tag . Get ( "db" ) ; columnName != "" {
fieldType := field . Type
2021-05-03 14:51:07 +00:00
if fieldType . Kind ( ) == reflect . Ptr {
fieldType = fieldType . Elem ( )
2021-03-31 03:55:19 +00:00
}
2021-05-03 14:51:07 +00:00
if typeIsQueryable ( fieldType ) {
columnNames = append ( columnNames , prefix + columnName )
fieldPaths = append ( fieldPaths , path )
} else if fieldType . Kind ( ) == reflect . Struct {
2021-03-31 03:55:19 +00:00
subCols , subPaths , err := getColumnNamesAndPaths ( fieldType , path , columnName + "." )
if err != nil {
return nil , nil , err
}
columnNames = append ( columnNames , subCols ... )
fieldPaths = append ( fieldPaths , subPaths ... )
} else {
2021-05-03 14:51:07 +00:00
return nil , nil , oops . New ( nil , "field '%s' in type %s has invalid type '%s'" , field . Name , destType , field . Type )
2021-03-31 03:55:19 +00:00
}
}
}
return columnNames , fieldPaths , nil
}
2021-03-21 20:38:37 +00:00
// ErrNoMatchingRows is returned by QueryOne and QueryScalar when the query yields
// no rows. QueryInt and QueryBool pass it through from QueryScalar.
var ErrNoMatchingRows error = errors.New("no matching rows")
2021-07-22 04:42:34 +00:00
func QueryOne ( ctx context . Context , conn ConnOrTx , destExample interface { } , query string , args ... interface { } ) ( interface { } , error ) {
2021-03-31 03:55:19 +00:00
rows , err := Query ( ctx , conn , destExample , query , args ... )
2021-03-21 20:38:37 +00:00
if err != nil {
2021-03-31 03:55:19 +00:00
return nil , err
2021-03-21 20:38:37 +00:00
}
defer rows . Close ( )
2021-03-31 03:55:19 +00:00
result , hasRow := rows . Next ( )
2021-03-21 20:38:37 +00:00
if ! hasRow {
2021-03-31 03:55:19 +00:00
return nil , ErrNoMatchingRows
2021-03-21 20:38:37 +00:00
}
2021-03-31 03:55:19 +00:00
return result , nil
2021-03-21 20:38:37 +00:00
}
2021-04-25 19:33:22 +00:00
2021-07-22 04:42:34 +00:00
func QueryScalar ( ctx context . Context , conn ConnOrTx , query string , args ... interface { } ) ( interface { } , error ) {
2021-04-25 19:33:22 +00:00
rows , err := conn . Query ( ctx , query , args ... )
if err != nil {
return nil , err
}
defer rows . Close ( )
if rows . Next ( ) {
vals , err := rows . Values ( )
if err != nil {
panic ( err )
}
if len ( vals ) != 1 {
return nil , oops . New ( nil , "you must query exactly one field with QueryScalar, not %v" , len ( vals ) )
}
return vals [ 0 ] , nil
}
return nil , ErrNoMatchingRows
}
2021-07-22 04:42:34 +00:00
func QueryInt ( ctx context . Context , conn ConnOrTx , query string , args ... interface { } ) ( int , error ) {
2021-04-25 19:33:22 +00:00
result , err := QueryScalar ( ctx , conn , query , args ... )
if err != nil {
return 0 , err
}
switch r := result . ( type ) {
case int :
return r , nil
case int32 :
return int ( r ) , nil
case int64 :
return int ( r ) , nil
default :
return 0 , oops . New ( nil , "QueryInt got a non-int result: %v" , result )
}
}
2021-07-22 04:42:34 +00:00
// QueryBool runs query via QueryScalar and returns its result as a bool.
// QueryScalar's errors (including ErrNoMatchingRows) are passed through. Any
// non-bool scalar, including a NULL (nil) result, is an error.
func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) {
	result, err := QueryScalar(ctx, conn, query, args...)
	if err != nil {
		return false, err
	}
	if b, ok := result.(bool); ok {
		return b, nil
	}
	return false, oops.New(nil, "QueryBool got a non-bool result: %v", result)
}