Start improvements to db API

Allowing for querying of scalars with db.Query, and a choice of prefix
in $columns to remove the need for single-field structs in many queries.
This commit is contained in:
Ben Visness 2022-03-27 17:20:35 -05:00
parent e74b3273cb
commit bd6c95203c
5 changed files with 217 additions and 92 deletions

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"git.handmade.network/hmn/hmn/src/config"
@ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
return conn
}
type columnName []string
// A path to a particular field in query's destination type. Each index in the slice
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
type fieldPath []int
type StructQueryIterator[T any] struct {
fieldPaths [][]int
fieldPaths []fieldPath
rows pgx.Rows
destType reflect.Type
closed chan struct{}
@ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
var destExample T
destType := reflect.TypeOf(destExample)
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
if err != nil {
return nil, oops.New(err, "failed to generate column names")
}
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
compiled := compileQuery(query, destType)
columnNamesString := strings.Join(columns, ", ")
query = strings.Replace(query, "$columns", columnNamesString, -1)
rows, err := conn.Query(ctx, query, args...)
rows, err := conn.Query(ctx, compiled.query, args...)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
panic("query exceeded its deadline")
@ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
}
it := &StructQueryIterator[T]{
fieldPaths: fieldPaths,
fieldPaths: compiled.fieldPaths,
rows: rows,
destType: destType,
destType: compiled.destType,
closed: make(chan struct{}, 1),
}
@ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
return it, nil
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
var columnNames [][]string
var fieldPaths [][]int
type compiledQuery struct {
query string
destType reflect.Type
fieldPaths []fieldPath
}
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
func compileQuery(query string, destType reflect.Type) compiledQuery {
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
hasColumnsPlaceholder := columnsMatch != nil
if hasColumnsPlaceholder {
// The presence of the $columns placeholder means that the destination type
// must be a struct, and we will plonk that struct's fields into the query.
if destType.Kind() != reflect.Struct {
panic("$columns can only be used when querying into some kind of struct")
}
var prefix []string
prefixText := columnsMatch[2]
if prefixText != "" {
prefix = []string{prefixText}
}
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
columns := make([]string, 0, len(columnNames))
for _, strSlice := range columnNames {
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
fullName := strSlice[len(strSlice)-1]
if tableName != "" {
fullName = tableName + "." + fullName
}
columns = append(columns, fullName)
}
columnNamesString := strings.Join(columns, ", ")
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
return compiledQuery{
query: query,
destType: destType,
fieldPaths: fieldPaths,
}
} else {
return compiledQuery{
query: query,
destType: destType,
}
}
}
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
var columnNames []columnName
var fieldPaths []fieldPath
if destType.Kind() == reflect.Ptr {
destType = destType.Elem()
}
if destType.Kind() != reflect.Struct {
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
}
type AnonPrefix struct {
@ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
columnNames = append(columnNames, fieldColumnNames)
fieldPaths = append(fieldPaths, path)
} else if fieldType.Kind() == reflect.Struct {
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
if err != nil {
return nil, nil, err
}
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
columnNames = append(columnNames, subCols...)
fieldPaths = append(fieldPaths, subPaths...)
} else {
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
}
}
}
return columnNames, fieldPaths, nil
return columnNames, fieldPaths
}
/*

View File

@ -10,52 +10,143 @@ import (
func TestPaths(t *testing.T) {
type CustomInt int
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
B bool `db:"B"`
PB *bool `db:"PB"`
type S2 struct {
B bool `db:"B"` // field 0
PB *bool `db:"PB"` // field 1
NoTag int
NoTag string // field 2
}
type S struct {
I int `db:"I"` // field 0
PI *int `db:"PI"` // field 1
CI CustomInt `db:"CI"` // field 2
PCI *CustomInt `db:"PCI"` // field 3
S2 `db:"S2"` // field 4 (embedded!)
PS2 *S2 `db:"PS2"` // field 5
NoTag int // field 6
}
type Nested struct {
S S `db:"S"`
PS *S `db:"PS"`
S S `db:"S"` // field 0
PS *S `db:"PS"` // field 1
NoTag S
NoTag S // field 2
}
type Embedded struct {
NoTag S
Nested
NoTag S // field 0
Nested // field 1
}
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
if assert.Nil(t, err) {
assert.Equal(t, []string{
"S.I", "S.PI",
"S.CI", "S.PCI",
"S.B", "S.PB",
"PS.I", "PS.PI",
"PS.CI", "PS.PCI",
"PS.B", "PS.PB",
}, names)
assert.Equal(t, [][]int{
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
}, paths)
assert.True(t, len(names) == len(paths))
}
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
assert.Equal(t, []columnName{
{"S", "I"}, {"S", "PI"},
{"S", "CI"}, {"S", "PCI"},
{"S", "S2", "B"}, {"S", "S2", "PB"},
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
{"PS", "I"}, {"PS", "PI"},
{"PS", "CI"}, {"PS", "PCI"},
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
}, names)
assert.Equal(t, []fieldPath{
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
}, paths)
assert.True(t, len(names) == len(paths))
testStruct := Embedded{}
for i, path := range paths {
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
assert.True(t, val.IsValid())
assert.True(t, strings.Contains(names[i], field.Name))
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
}
}
func TestCompileQuery(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
})
t.Run("complex structs", func(t *testing.T) {
type CustomInt int
type S2 struct {
B bool `db:"B"`
PB *bool `db:"PB"`
NoTag string
}
type S struct {
I int `db:"I"`
PI *int `db:"PI"`
CI CustomInt `db:"CI"`
PCI *CustomInt `db:"PCI"`
S2 `db:"S2"` // embedded!
PS2 *S2 `db:"PS2"`
NoTag int
}
type Nested struct {
S S `db:"S"`
PS *S `db:"PS"`
NoTag S
}
type Dest struct {
NoTag S
Nested
}
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
})
t.Run("int", func(t *testing.T) {
type Dest int
// There should be no error here because we do not need to extract columns from
// the destination type. There may be errors down the line in value iteration, but
// that is always the case if the Go types don't match the query.
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
})
t.Run("just one table", func(t *testing.T) {
type Dest struct {
Foo int `db:"foo"`
Bar bool `db:"bar"`
Nope string // no tag
}
// The prefix is necessary because otherwise we would have to provide a struct with
// a db tag in order to provide the query with the `greeblies.` prefix in the
// final query. This comes up a lot when we do a JOIN to help with a condition, but
// don't actually care about any of the data we joined to.
compiled := compileQuery(
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
reflect.TypeOf(Dest{}),
)
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
})
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
type Dest int
assert.Panics(t, func() {
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
})
})
}
func TestQueryBuilder(t *testing.T) {
t.Run("happy time", func(t *testing.T) {
var qb QueryBuilder

View File

@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
userIds = append(userIds, u.User.ID)
}
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
}
for _, ul := range userLinks {
link := ul.(*models.Link)
for _, link := range userLinks {
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
}
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE auth_user.id = $1
`,
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
}
}
user := u.(*userQuery).User
user := u.User
whatHappened := ""
if action == ApprovalQueueActionApprove {
@ -337,7 +337,7 @@ type UnapprovedPost struct {
}
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
if err != nil {
return nil, oops.New(err, "failed to fetch unapproved posts")
}
var res []*UnapprovedPost
for _, iresult := range it {
res = append(res, iresult.(*UnapprovedPost))
}
return res, nil
}
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
type unapprovedUser struct {
ID int `db:"id"`
}
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
if err != nil {
return nil, oops.New(err, "failed to fetch unapproved users")
}
ownerIDs := make([]int, 0, len(it))
for _, uid := range it {
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
ownerIDs := make([]int, 0, len(uids))
for _, uid := range uids {
ownerIDs = append(ownerIDs, uid.ID)
}
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{

View File

@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
}
} else {
canonicalUsername = userResult.(*userQuery).User.Username
canonicalUsername = userResult.User.Username
}
}

View File

@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
FROM
auth_user
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
WHERE LOWER(username) = LOWER($1)
`,
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
}
}
user := &userRow.(*userQuery).User
user := &userRow.User
success, err := tryLogin(c, user, password)
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
type userQuery struct {
User models.User `db:"auth_user"`
}
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
}
}
if userRow != nil {
user = &userRow.(*userQuery).User
user = &userRow.User
}
if user != nil {
c.Perf.StartBlock("SQL", "Fetching existing token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
`
SELECT $columns
FROM handmade_onetimetoken
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
}
}
var resetToken *models.OneTimeToken
if tokenRow != nil {
resetToken = tokenRow.(*models.OneTimeToken)
}
now := time.Now()
if resetToken != nil {
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if resetToken == nil {
c.Perf.StartBlock("SQL", "Creating new token")
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
`
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
VALUES ($1, $2, $3, $4, $5)
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
if err != nil {
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
}
resetToken = tokenRow.(*models.OneTimeToken)
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
if err != nil {
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
User models.User `db:"auth_user"`
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
}
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
`
SELECT $columns
FROM auth_user
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
return result
}
}
if row != nil {
data := row.(*userAndTokenQuery)
if data != nil {
result.User = &data.User
result.OneTimeToken = data.OneTimeToken
if result.OneTimeToken != nil {