2021-03-09 08:05:07 +00:00
package db
import (
"context"
2021-03-21 20:38:37 +00:00
"errors"
"reflect"
"strings"
2021-03-09 08:05:07 +00:00
2021-03-11 03:39:24 +00:00
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/oops"
2021-03-31 03:55:19 +00:00
"github.com/jackc/pgtype"
2021-03-09 08:05:07 +00:00
"github.com/jackc/pgx/v4"
2021-03-21 20:38:37 +00:00
"github.com/jackc/pgx/v4/log/zerologadapter"
2021-03-09 08:05:07 +00:00
"github.com/jackc/pgx/v4/pgxpool"
2021-03-21 20:38:37 +00:00
"github.com/rs/zerolog/log"
2021-03-09 08:05:07 +00:00
)
2021-03-31 03:55:19 +00:00
// connInfo is the package-level pgtype registry that getColumnNamesAndPaths
// consults to decide whether a `db`-tagged field's type is one pgtype can
// decode directly (and therefore should not be searched for nested tags).
var connInfo = pgtype.NewConnInfo()
2021-03-09 08:05:07 +00:00
func NewConn ( ) * pgx . Conn {
conn , err := pgx . Connect ( context . Background ( ) , config . Config . Postgres . DSN ( ) )
if err != nil {
2021-03-11 03:19:39 +00:00
panic ( oops . New ( err , "failed to connect to database" ) )
2021-03-09 08:05:07 +00:00
}
return conn
}
func NewConnPool ( minConns , maxConns int32 ) * pgxpool . Pool {
2021-03-21 20:38:37 +00:00
cfg , err := pgxpool . ParseConfig ( config . Config . Postgres . DSN ( ) )
2021-03-09 08:05:07 +00:00
2021-03-21 20:38:37 +00:00
cfg . MinConns = minConns
cfg . MaxConns = maxConns
cfg . ConnConfig . Logger = zerologadapter . NewLogger ( log . Logger )
cfg . ConnConfig . LogLevel = config . Config . Postgres . LogLevel
2021-03-09 08:05:07 +00:00
2021-03-21 20:38:37 +00:00
conn , err := pgxpool . ConnectConfig ( context . Background ( ) , cfg )
2021-03-09 08:05:07 +00:00
if err != nil {
2021-03-11 03:19:39 +00:00
panic ( oops . New ( err , "failed to create database connection pool" ) )
2021-03-09 08:05:07 +00:00
}
return conn
}
2021-03-21 20:38:37 +00:00
type StructQueryIterator struct {
2021-03-31 03:55:19 +00:00
fieldPaths [ ] [ ] int
rows pgx . Rows
destType reflect . Type
2021-03-21 20:38:37 +00:00
}
2021-03-31 03:55:19 +00:00
func ( it * StructQueryIterator ) Next ( ) ( interface { } , bool ) {
2021-03-21 20:38:37 +00:00
hasNext := it . rows . Next ( )
if ! hasNext {
2021-03-31 03:55:19 +00:00
return nil , false
2021-03-21 20:38:37 +00:00
}
2021-03-31 03:55:19 +00:00
result := reflect . New ( it . destType )
2021-03-21 20:38:37 +00:00
vals , err := it . rows . Values ( )
if err != nil {
panic ( err )
}
for i , val := range vals {
2021-03-31 03:55:19 +00:00
if val == nil {
continue
}
field := followPathThroughStructs ( result , it . fieldPaths [ i ] )
if field . Kind ( ) == reflect . Ptr {
field . Set ( reflect . New ( field . Type ( ) . Elem ( ) ) )
field = field . Elem ( )
}
2021-03-21 20:38:37 +00:00
switch field . Kind ( ) {
case reflect . Int :
field . SetInt ( reflect . ValueOf ( val ) . Int ( ) )
default :
field . Set ( reflect . ValueOf ( val ) )
}
}
2021-03-31 03:55:19 +00:00
return result . Interface ( ) , true
2021-03-21 20:38:37 +00:00
}
// Close closes the underlying pgx rows. ToSlice and QueryOne call it via
// defer once they are done reading.
func (it *StructQueryIterator) Close() {
	it.rows.Close()
}
2021-03-31 03:55:19 +00:00
func ( it * StructQueryIterator ) ToSlice ( ) [ ] interface { } {
defer it . Close ( )
var result [ ] interface { }
for {
row , ok := it . Next ( )
if ! ok {
2021-04-11 21:46:06 +00:00
err := it . rows . Err ( )
if err != nil {
panic ( oops . New ( err , "error while iterating through db results" ) )
}
2021-03-31 03:55:19 +00:00
break
}
result = append ( result , row )
2021-03-22 03:07:18 +00:00
}
2021-03-31 03:55:19 +00:00
return result
}
2021-03-22 03:07:18 +00:00
2021-03-31 03:55:19 +00:00
// followPathThroughStructs walks structVal along path, where each element is a
// struct field index, and returns the reflect.Value of the field at the end.
//
// Before each step, a pointer to a struct is dereferenced. If that pointer is
// nil, it is first set to a newly allocated zero-valued struct, which requires
// the pointer to be settable. The final field is returned as-is, without
// dereferencing. Panics if path is empty.
func followPathThroughStructs(structVal reflect.Value, path []int) reflect.Value {
	if len(path) < 1 {
		panic("can't follow an empty path")
	}

	val := structVal
	for _, i := range path {
		if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
			if val.IsNil() {
				// Allocate the pointed-to struct type. reflect.New(val.Type())
				// would produce a **T, which cannot be assigned to this *T.
				val.Set(reflect.New(val.Type().Elem()))
			}
			val = val.Elem()
		}
		val = val.Field(i)
	}
	return val
}
2021-04-11 21:46:06 +00:00
func Query ( ctx context . Context , conn * pgxpool . Pool , destExample interface { } , query string , args ... interface { } ) ( * StructQueryIterator , error ) {
2021-03-31 03:55:19 +00:00
destType := reflect . TypeOf ( destExample )
columnNames , fieldPaths , err := getColumnNamesAndPaths ( destType , nil , "" )
if err != nil {
2021-04-11 21:46:06 +00:00
return nil , oops . New ( err , "failed to generate column names" )
2021-03-21 20:38:37 +00:00
}
columnNamesString := strings . Join ( columnNames , ", " )
query = strings . Replace ( query , "$columns" , columnNamesString , - 1 )
rows , err := conn . Query ( ctx , query , args ... )
if err != nil {
2021-03-31 03:55:19 +00:00
if errors . Is ( err , context . DeadlineExceeded ) {
panic ( "query exceeded its deadline" )
}
2021-04-11 21:46:06 +00:00
return nil , err
2021-03-21 20:38:37 +00:00
}
2021-04-11 21:46:06 +00:00
return & StructQueryIterator {
2021-03-31 03:55:19 +00:00
fieldPaths : fieldPaths ,
rows : rows ,
destType : destType ,
2021-03-21 20:38:37 +00:00
} , nil
}
2021-03-31 03:55:19 +00:00
// getColumnNamesAndPaths walks the fields of destType (a struct or a pointer
// to one). It returns two parallel slices: the column names to select and the
// reflect field-index paths that reach the matching fields from the top-level
// struct. Only fields with a non-empty `db:"..."` tag are included.
//
// A tagged field whose type (after looking through one pointer) is a struct
// that pgtype does not recognise is treated as a nested column group. Its own
// tagged fields are collected recursively under the prefix "<columnName>.".
// pathSoFar is the index path to destType and prefix is the column-name
// prefix to apply. Top-level callers pass nil and "".
//
// It returns an error if destType is not a struct or a pointer to a struct.
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix string) ([]string, [][]int, error) {
	var columnNames []string
	var fieldPaths [][]int

	if destType.Kind() == reflect.Ptr {
		destType = destType.Elem()
	}
	if destType.Kind() != reflect.Struct {
		return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
	}

	for i := 0; i < destType.NumField(); i++ {
		field := destType.Field(i)
		columnName := field.Tag.Get("db")
		if columnName == "" {
			continue
		}

		// The full slice expression caps pathSoFar so append always copies.
		// Without it, siblings deep in nested structs append into the same
		// spare capacity and overwrite each other's stored paths.
		path := append(pathSoFar[:len(pathSoFar):len(pathSoFar)], i)

		// Look through pointer fields so *T is classified like T;
		// followPathThroughStructs allocates nil struct pointers on demand.
		fieldType := field.Type
		if fieldType.Kind() == reflect.Ptr {
			fieldType = fieldType.Elem()
		}

		// If pgtype recognises the type, it decodes the column itself and we
		// must not dig in for more `db` tags. DataTypeForValue inspects the
		// dynamic type of its argument, so pass the actual *fieldType via
		// Interface(). Passing a reflect.Value would mean pgtype never matches.
		// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
		_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(fieldType).Interface())

		if fieldType.Kind() == reflect.Struct && !isRecognizedByPgtype {
			// Nested group: only this field's tag becomes the prefix, not the
			// outer prefix (presumably so it can name a table alias in the
			// SQL; confirm against queries).
			subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, columnName+".")
			if err != nil {
				return nil, nil, err
			}
			columnNames = append(columnNames, subCols...)
			fieldPaths = append(fieldPaths, subPaths...)
		} else {
			columnNames = append(columnNames, prefix+columnName)
			fieldPaths = append(fieldPaths, path)
		}
	}

	return columnNames, fieldPaths, nil
}
2021-03-21 20:38:37 +00:00
// ErrNoMatchingRows is returned by QueryOne when the query yields no rows.
var ErrNoMatchingRows = errors.New("no matching rows")
2021-03-31 03:55:19 +00:00
func QueryOne ( ctx context . Context , conn * pgxpool . Pool , destExample interface { } , query string , args ... interface { } ) ( interface { } , error ) {
rows , err := Query ( ctx , conn , destExample , query , args ... )
2021-03-21 20:38:37 +00:00
if err != nil {
2021-03-31 03:55:19 +00:00
return nil , err
2021-03-21 20:38:37 +00:00
}
defer rows . Close ( )
2021-03-31 03:55:19 +00:00
result , hasRow := rows . Next ( )
2021-03-21 20:38:37 +00:00
if ! hasRow {
2021-03-31 03:55:19 +00:00
return nil , ErrNoMatchingRows
2021-03-21 20:38:37 +00:00
}
2021-03-31 03:55:19 +00:00
return result , nil
2021-03-21 20:38:37 +00:00
}