diff --git a/src/db/db.go b/src/db/db.go index 55b54b90..2048a427 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "regexp" "strings" "git.handmade.network/hmn/hmn/src/config" @@ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { return conn } +type columnName []string + +// A path to a particular field in query's destination type. Each index in the slice +// corresponds to a field index for use with Field on a reflect.Type or reflect.Value. +type fieldPath []int + type StructQueryIterator[T any] struct { - fieldPaths [][]int + fieldPaths []fieldPath rows pgx.Rows destType reflect.Type closed chan struct{} @@ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { var destExample T destType := reflect.TypeOf(destExample) - columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil) - if err != nil { - return nil, oops.New(err, "failed to generate column names") - } - columns := make([]string, 0, len(columnNames)) - for _, strSlice := range columnNames { - tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") - fullName := strSlice[len(strSlice)-1] - if tableName != "" { - fullName = tableName + "." + fullName - } - columns = append(columns, fullName) - } + compiled := compileQuery(query, destType) - columnNamesString := strings.Join(columns, ", ") - query = strings.Replace(query, "$columns", columnNamesString, -1) - - rows, err := conn.Query(ctx, query, args...) + rows, err := conn.Query(ctx, compiled.query, args...) if err != nil { if errors.Is(err, context.DeadlineExceeded) { panic("query exceeded its deadline") @@ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args } it := &StructQueryIterator[T]{ - fieldPaths: fieldPaths, + fieldPaths: compiled.fieldPaths, rows: rows, - destType: destType, + destType: compiled.destType, closed: make(chan struct{}, 1), } @@ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args return it, nil } -func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { - var columnNames [][]string - var fieldPaths [][]int +type compiledQuery struct { + query string + destType reflect.Type + fieldPaths []fieldPath +} + +var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`) + +func compileQuery(query string, destType reflect.Type) compiledQuery { + columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query) + hasColumnsPlaceholder := columnsMatch != nil + + if hasColumnsPlaceholder { + // The presence of the $columns placeholder means that the destination type + // must be a struct, and we will plonk that struct's fields into the query. + + if destType.Kind() != reflect.Struct { + panic("$columns can only be used when querying into some kind of struct") + } + + var prefix []string + prefixText := columnsMatch[2] + if prefixText != "" { + prefix = []string{prefixText} + } + + columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix) + + columns := make([]string, 0, len(columnNames)) + for _, strSlice := range columnNames { + tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") + fullName := strSlice[len(strSlice)-1] + if tableName != "" { + fullName = tableName + "." + fullName + } + columns = append(columns, fullName) + } + + columnNamesString := strings.Join(columns, ", ") + query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString) + + return compiledQuery{ + query: query, + destType: destType, + fieldPaths: fieldPaths, + } + } else { + return compiledQuery{ + query: query, + destType: destType, + } + } +} + +func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) { + var columnNames []columnName + var fieldPaths []fieldPath if destType.Kind() == reflect.Ptr { destType = destType.Elem() } if destType.Kind() != reflect.Struct { - return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) + panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)) } type AnonPrefix struct { @@ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str columnNames = append(columnNames, fieldColumnNames) fieldPaths = append(fieldPaths, path) } else if fieldType.Kind() == reflect.Struct { - subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) - if err != nil { - return nil, nil, err - } + subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) columnNames = append(columnNames, subCols...) fieldPaths = append(fieldPaths, subPaths...) } else { - return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) + panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)) } } } - return columnNames, fieldPaths, nil + return columnNames, fieldPaths } /* diff --git a/src/db/db_test.go b/src/db/db_test.go index 41b9491d..edc647ed 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -10,52 +10,143 @@ import ( func TestPaths(t *testing.T) { type CustomInt int - type S struct { - I int `db:"I"` - PI *int `db:"PI"` - CI CustomInt `db:"CI"` - PCI *CustomInt `db:"PCI"` - B bool `db:"B"` - PB *bool `db:"PB"` + type S2 struct { + B bool `db:"B"` // field 0 + PB *bool `db:"PB"` // field 1 - NoTag int + NoTag string // field 2 + } + type S struct { + I int `db:"I"` // field 0 + PI *int `db:"PI"` // field 1 + CI CustomInt `db:"CI"` // field 2 + PCI *CustomInt `db:"PCI"` // field 3 + S2 `db:"S2"` // field 4 (embedded!) + PS2 *S2 `db:"PS2"` // field 5 + + NoTag int // field 6 } type Nested struct { - S S `db:"S"` - PS *S `db:"PS"` + S S `db:"S"` // field 0 + PS *S `db:"PS"` // field 1 - NoTag S + NoTag S // field 2 } type Embedded struct { - NoTag S - Nested + NoTag S // field 0 + Nested // field 1 } - names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") - if assert.Nil(t, err) { - assert.Equal(t, []string{ - "S.I", "S.PI", - "S.CI", "S.PCI", - "S.B", "S.PB", - "PS.I", "PS.PI", - "PS.CI", "PS.PCI", - "PS.B", "PS.PB", - }, names) - assert.Equal(t, [][]int{ - {1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, - {1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, - }, paths) - assert.True(t, len(names) == len(paths)) - } + names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil) + assert.Equal(t, []columnName{ + {"S", "I"}, {"S", "PI"}, + {"S", "CI"}, {"S", "PCI"}, + {"S", "S2", "B"}, {"S", "S2", "PB"}, + {"S", "PS2", "B"}, {"S", "PS2", "PB"}, + {"PS", "I"}, {"PS", "PI"}, + {"PS", "CI"}, {"PS", "PCI"}, + {"PS", "S2", "B"}, {"PS", "S2", "PB"}, + {"PS", "PS2", "B"}, {"PS", "PS2", "PB"}, + }, names) + assert.Equal(t, []fieldPath{ + {1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI + {1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI + {1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB + {1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB + {1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI + {1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI + {1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB + {1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB + }, paths) + assert.True(t, len(names) == len(paths)) testStruct := Embedded{} for i, path := range paths { val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) assert.True(t, val.IsValid()) - assert.True(t, strings.Contains(names[i], field.Name)) + assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name)) } } +func TestCompileQuery(t *testing.T) { + t.Run("simple struct", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query) + }) + t.Run("complex structs", func(t *testing.T) { + type CustomInt int + type S2 struct { + B bool `db:"B"` + PB *bool `db:"PB"` + + NoTag string + } + type S struct { + I int `db:"I"` + PI *int `db:"PI"` + CI CustomInt `db:"CI"` + PCI *CustomInt `db:"PCI"` + S2 `db:"S2"` // embedded! + PS2 *S2 `db:"PS2"` + + NoTag int + } + type Nested struct { + S S `db:"S"` + PS *S `db:"PS"` + + NoTag S + } + type Dest struct { + NoTag S + Nested + } + + compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) + assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query) + }) + t.Run("int", func(t *testing.T) { + type Dest int + + // There should be no error here because we do not need to extract columns from + // the destination type. There may be errors down the line in value iteration, but + // that is always the case if the Go types don't match the query. + compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0))) + assert.Equal(t, "SELECT id FROM greeblies", compiled.query) + }) + t.Run("just one table", func(t *testing.T) { + type Dest struct { + Foo int `db:"foo"` + Bar bool `db:"bar"` + Nope string // no tag + } + + // The prefix is necessary because otherwise we would have to provide a struct with + // a db tag in order to provide the query with the `greeblies.` prefix in the + // final query. This comes up a lot when we do a JOIN to help with a condition, but + // don't actually care about any of the data we joined to. + compiled := compileQuery( + "SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props", + reflect.TypeOf(Dest{}), + ) + assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query) + }) + + t.Run("using $columns without a struct is not allowed", func(t *testing.T) { + type Dest int + + assert.Panics(t, func() { + compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0))) + }) + }) +} + func TestQueryBuilder(t *testing.T) { t.Run("happy time", func(t *testing.T) { var qb QueryBuilder diff --git a/src/website/admin.go b/src/website/admin.go index 0bdacd62..94ca1fd7 100644 --- a/src/website/admin.go +++ b/src/website/admin.go @@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { userIds = append(userIds, u.User.ID) } - userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, + userLinks, err := db.Query[models.Link](c.Context(), c.Conn, ` SELECT $columns FROM @@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) } - for _, ul := range userLinks { - link := ul.(*models.Link) + for _, link := range userLinks { userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) } @@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { type userQuery struct { User models.User `db:"auth_user"` } - u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + u, err := db.QueryOne[userQuery](c.Context(), c.Conn, ` SELECT $columns - FROM auth_user + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE auth_user.id = $1 `, @@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) } } - user := u.(*userQuery).User + user := u.User whatHappened := "" if action == ApprovalQueueActionApprove { @@ -337,7 +337,7 @@ type UnapprovedPost struct { } func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { - it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, + res, err := db.Query[UnapprovedPost](c.Context(), c.Conn, ` SELECT $columns FROM @@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved posts") } - var res []*UnapprovedPost - for _, iresult := range it { - res = append(res, iresult.(*UnapprovedPost)) - } return res, nil } @@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { type unapprovedUser struct { ID int `db:"id"` } - it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, + uids, err := db.Query[unapprovedUser](c.Context(), c.Conn, ` SELECT $columns FROM @@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { if err != nil { return nil, oops.New(err, "failed to fetch unapproved users") } - ownerIDs := make([]int, 0, len(it)) - for _, uid := range it { - ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) + ownerIDs := make([]int, 0, len(uids)) + for _, uid := range uids { + ownerIDs = append(ownerIDs, uid.ID) } projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ diff --git a/src/website/api.go b/src/website/api.go index f1c5b33d..8d39c22e 100644 --- a/src/website/api.go +++ b/src/website/api.go @@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData { type userQuery struct { User models.User `db:"auth_user"` } - userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn, ` SELECT $columns FROM @@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) } } else { - canonicalUsername = userResult.(*userQuery).User.Username + canonicalUsername = userResult.User.Username } } diff --git a/src/website/auth.go b/src/website/auth.go index 991f4034..4b0681f3 100644 --- a/src/website/auth.go +++ b/src/website/auth.go @@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData { type userQuery struct { User models.User `db:"auth_user"` } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn, ` SELECT $columns - FROM auth_user + FROM + auth_user LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id WHERE LOWER(username) = LOWER($1) `, @@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } } - user := &userRow.(*userQuery).User + user := &userRow.User success, err := tryLogin(c, user, password) @@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { type userQuery struct { User models.User `db:"auth_user"` } - userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, + userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn, ` SELECT $columns FROM auth_user @@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { } } if userRow != nil { - user = &userRow.(*userQuery).User + user = &userRow.User } if user != nil { c.Perf.StartBlock("SQL", "Fetching existing token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` SELECT $columns FROM handmade_onetimetoken @@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) } } - var resetToken *models.OneTimeToken - if tokenRow != nil { - resetToken = tokenRow.(*models.OneTimeToken) - } now := time.Now() if resetToken != nil { @@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if resetToken == nil { c.Perf.StartBlock("SQL", "Creating new token") - tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, + resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, ` INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) VALUES ($1, $2, $3, $4, $5) @@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { if err != nil { return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) } - resetToken = tokenRow.(*models.OneTimeToken) err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) if err != nil { @@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, User models.User `db:"auth_user"` OneTimeToken *models.OneTimeToken `db:"onetimetoken"` } - row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, + data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, ` SELECT $columns FROM auth_user @@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, return result } } - if row != nil { - data := row.(*userAndTokenQuery) + if data != nil { result.User = &data.User result.OneTimeToken = data.OneTimeToken if result.OneTimeToken != nil {