Start improvements to db API

Allowing for querying of scalars with db.Query, and a choice of prefix
in $columns to remove the need for single-field structs in many queries.
This commit is contained in:
Ben Visness 2022-03-27 17:20:35 -05:00
parent e74b3273cb
commit bd6c95203c
5 changed files with 217 additions and 92 deletions

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings" "strings"
"git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/config"
@ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn return conn
} }
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type StructQueryIterator[T any] struct { type StructQueryIterator[T any] struct {
fieldPaths [][]int fieldPaths []fieldPath
rows pgx.Rows rows pgx.Rows
destType reflect.Type destType reflect.Type
closed chan struct{} closed chan struct{}
@ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
var destExample T var destExample T
destType := reflect.TypeOf(destExample) destType := reflect.TypeOf(destExample)
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
if err != nil {
return nil, oops.New(err, "failed to generate column names")
}
columns := make([]string, 0, len(columnNames)) compiled := compileQuery(query, destType)
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ") rows, err := conn.Query(ctx, compiled.query, args...)
query = strings.Replace(query, "$columns", columnNamesString, -1)
rows, err := conn.Query(ctx, query, args...)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
panic("query exceeded its deadline") panic("query exceeded its deadline")
@ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
} }
it := &StructQueryIterator[T]{ it := &StructQueryIterator[T]{
fieldPaths: fieldPaths, fieldPaths: compiled.fieldPaths,
rows: rows, rows: rows,
destType: destType, destType: compiled.destType,
closed: make(chan struct{}, 1), closed: make(chan struct{}, 1),
} }
@ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return it, nil return it, nil
} }
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { type compiledQuery struct {
var columnNames [][]string query string
var fieldPaths [][]int destType reflect.Type
fieldPaths []fieldPath
}
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
func compileQuery(query string, destType reflect.Type) compiledQuery {
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
hasColumnsPlaceholder := columnsMatch != nil
if hasColumnsPlaceholder {
// The presence of the $columns placeholder means that the destination type
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
}
var prefix []string
prefixText := columnsMatch[2]
if prefixText != "" {
prefix = []string{prefixText}
}
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ")
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
return compiledQuery{
query: query,
destType: destType,
fieldPaths: fieldPaths,
}
} else {
return compiledQuery{
query: query,
destType: destType,
}
}
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
var columnNames []columnName
var fieldPaths []fieldPath
if destType.Kind() == reflect.Ptr { if destType.Kind() == reflect.Ptr {
destType = destType.Elem() destType = destType.Elem()
} }
if destType.Kind() != reflect.Struct { if destType.Kind() != reflect.Struct {
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
} }
type AnonPrefix struct { type AnonPrefix struct {
@ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
columnNames = append(columnNames, fieldColumnNames) columnNames = append(columnNames, fieldColumnNames)
fieldPaths = append(fieldPaths, path) fieldPaths = append(fieldPaths, path)
} else if fieldType.Kind() == reflect.Struct { } else if fieldType.Kind() == reflect.Struct {
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
if err != nil {
return nil, nil, err
}
columnNames = append(columnNames, subCols...) columnNames = append(columnNames, subCols...)
fieldPaths = append(fieldPaths, subPaths...) fieldPaths = append(fieldPaths, subPaths...)
} else { } else {
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
} }
} }
} }
return columnNames, fieldPaths, nil return columnNames, fieldPaths
} }
/* /*

View File

@ -10,13 +10,90 @@ import (
func TestPaths(t *testing.T) { func TestPaths(t *testing.T) {
type CustomInt int type CustomInt int
type S2 struct {
B bool `db:"B"` // field 0
PB *bool `db:"PB"` // field 1
NoTag string // field 2
}
type S struct {
I int `db:"I"` // field 0
PI *int `db:"PI"` // field 1
CI CustomInt `db:"CI"` // field 2
PCI *CustomInt `db:"PCI"` // field 3
S2 `db:"S2"` // field 4 (embedded!)
PS2 *S2 `db:"PS2"` // field 5
NoTag int // field 6
}
type Nested struct {
S S `db:"S"` // field 0
PS *S `db:"PS"` // field 1
NoTag S // field 2
}
type Embedded struct {
NoTag S // field 0
Nested // field 1
}
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
assert.Equal(t, []columnName{
{"S", "I"}, {"S", "PI"},
{"S", "CI"}, {"S", "PCI"},
{"S", "S2", "B"}, {"S", "S2", "PB"},
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
{"PS", "I"}, {"PS", "PI"},
{"PS", "CI"}, {"PS", "PCI"},
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
}, names)
assert.Equal(t, []fieldPath{
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
}, paths)
assert.True(t, len(names) == len(paths))
testStruct := Embedded{}
for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
}
}
func TestCompileQuery(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
})
t.Run("complex structs", func(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag string
}
type S struct { type S struct {
I int `db:"I"` I int `db:"I"`
PI *int `db:"PI"` PI *int `db:"PI"`
CI CustomInt `db:"CI"` CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"` PCI *CustomInt `db:"PCI"`
B bool `db:"B"` S2 `db:"S2"` // embedded!
PB *bool `db:"PB"` PS2 *S2 `db:"PS2"`
NoTag int NoTag int
} }
@ -26,34 +103,48 @@ func TestPaths(t *testing.T) {
NoTag S NoTag S
} }
type Embedded struct { type Dest struct {
NoTag S NoTag S
Nested Nested
} }
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
if assert.Nil(t, err) { assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
assert.Equal(t, []string{ })
"S.I", "S.PI", t.Run("int", func(t *testing.T) {
"S.CI", "S.PCI", type Dest int
"S.B", "S.PB",
"PS.I", "PS.PI", // There should be no error here because we do not need to extract columns from
"PS.CI", "PS.PCI", // the destination type. There may be errors down the line in value iteration, but
"PS.B", "PS.PB", // that is always the case if the Go types don't match the query.
}, names) compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
assert.Equal(t, [][]int{ assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, })
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, t.Run("just one table", func(t *testing.T) {
}, paths) type Dest struct {
assert.True(t, len(names) == len(paths)) Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
} }
testStruct := Embedded{} // The prefix is necessary because otherwise we would have to provide a struct with
for i, path := range paths { // a db tag in order to provide the query with the `greeblies.` prefix in the
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) // final query. This comes up a lot when we do a JOIN to help with a condition, but
assert.True(t, val.IsValid()) // don't actually care about any of the data we joined to.
assert.True(t, strings.Contains(names[i], field.Name)) compiled := compileQuery(
} "SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
reflect.TypeOf(Dest{}),
)
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
})
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
type Dest int
assert.Panics(t, func() {
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
})
})
} }
func TestQueryBuilder(t *testing.T) { func TestQueryBuilder(t *testing.T) {

View File

@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID) userIds = append(userIds, u.User.ID)
} }
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
} }
for _, ul := range userLinks { for _, link := range userLinks {
link := ul.(*models.Link)
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
} }
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE auth_user.id = $1 WHERE auth_user.id = $1
`, `,
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
} }
} }
user := u.(*userQuery).User user := u.User
whatHappened := "" whatHappened := ""
if action == ApprovalQueueActionApprove { if action == ApprovalQueueActionApprove {
@ -337,7 +337,7 @@ type UnapprovedPost struct {
} }
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch unapproved posts") return nil, oops.New(err, "failed to fetch unapproved posts")
} }
var res []*UnapprovedPost
for _, iresult := range it {
res = append(res, iresult.(*UnapprovedPost))
}
return res, nil return res, nil
} }
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
type unapprovedUser struct { type unapprovedUser struct {
ID int `db:"id"` ID int `db:"id"`
} }
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
if err != nil { if err != nil {
return nil, oops.New(err, "failed to fetch unapproved users") return nil, oops.New(err, "failed to fetch unapproved users")
} }
ownerIDs := make([]int, 0, len(it)) ownerIDs := make([]int, 0, len(uids))
for _, uid := range it { for _, uid := range uids {
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) ownerIDs = append(ownerIDs, uid.ID)
} }
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{

View File

@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM FROM
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
} }
} else { } else {
canonicalUsername = userResult.(*userQuery).User.Username canonicalUsername = userResult.User.Username
} }
} }

View File

@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE LOWER(username) = LOWER($1) WHERE LOWER(username) = LOWER($1)
`, `,
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
} }
} }
user := &userRow.(*userQuery).User user := &userRow.User
success, err := tryLogin(c, user, password) success, err := tryLogin(c, user, password)
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
type userQuery struct { type userQuery struct {
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
} }
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM auth_user
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
} }
} }
if userRow != nil { if userRow != nil {
user = &userRow.(*userQuery).User user = &userRow.User
} }
if user != nil { if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token") c.Perf.StartBlock("SQL", "Fetching existing token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM handmade_onetimetoken FROM handmade_onetimetoken
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
} }
} }
var resetToken *models.OneTimeToken
if tokenRow != nil {
resetToken = tokenRow.(*models.OneTimeToken)
}
now := time.Now() now := time.Now()
if resetToken != nil { if resetToken != nil {
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil { if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token") c.Perf.StartBlock("SQL", "Creating new token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
` `
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if err != nil { if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
} }
resetToken = tokenRow.(*models.OneTimeToken)
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
if err != nil { if err != nil {
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"auth_user"` User models.User `db:"auth_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
} }
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
` `
SELECT $columns SELECT $columns
FROM auth_user FROM auth_user
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
return result return result
} }
} }
if row != nil { if data != nil {
data := row.(*userAndTokenQuery)
result.User = &data.User result.User = &data.User
result.OneTimeToken = data.OneTimeToken result.OneTimeToken = data.OneTimeToken
if result.OneTimeToken != nil { if result.OneTimeToken != nil {