Start improvements to db API
Allowing for querying of scalars with db.Query, and a choice of prefix in $columns to remove the need for single-field structs in many queries.
This commit is contained in:
parent
e74b3273cb
commit
bd6c95203c
103
src/db/db.go
103
src/db/db.go
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
|
@ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
|||
return conn
|
||||
}
|
||||
|
||||
type columnName []string
|
||||
|
||||
// A path to a particular field in query's destination type. Each index in the slice
|
||||
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
|
||||
type fieldPath []int
|
||||
|
||||
type StructQueryIterator[T any] struct {
|
||||
fieldPaths [][]int
|
||||
fieldPaths []fieldPath
|
||||
rows pgx.Rows
|
||||
destType reflect.Type
|
||||
closed chan struct{}
|
||||
|
@ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
|
|||
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
|
||||
var destExample T
|
||||
destType := reflect.TypeOf(destExample)
|
||||
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
|
||||
if err != nil {
|
||||
return nil, oops.New(err, "failed to generate column names")
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(columnNames))
|
||||
for _, strSlice := range columnNames {
|
||||
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
||||
fullName := strSlice[len(strSlice)-1]
|
||||
if tableName != "" {
|
||||
fullName = tableName + "." + fullName
|
||||
}
|
||||
columns = append(columns, fullName)
|
||||
}
|
||||
compiled := compileQuery(query, destType)
|
||||
|
||||
columnNamesString := strings.Join(columns, ", ")
|
||||
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
||||
|
||||
rows, err := conn.Query(ctx, query, args...)
|
||||
rows, err := conn.Query(ctx, compiled.query, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
panic("query exceeded its deadline")
|
||||
|
@ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
|
|||
}
|
||||
|
||||
it := &StructQueryIterator[T]{
|
||||
fieldPaths: fieldPaths,
|
||||
fieldPaths: compiled.fieldPaths,
|
||||
rows: rows,
|
||||
destType: destType,
|
||||
destType: compiled.destType,
|
||||
closed: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
|
@ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
|
|||
return it, nil
|
||||
}
|
||||
|
||||
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
|
||||
var columnNames [][]string
|
||||
var fieldPaths [][]int
|
||||
type compiledQuery struct {
|
||||
query string
|
||||
destType reflect.Type
|
||||
fieldPaths []fieldPath
|
||||
}
|
||||
|
||||
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
|
||||
|
||||
func compileQuery(query string, destType reflect.Type) compiledQuery {
|
||||
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
|
||||
hasColumnsPlaceholder := columnsMatch != nil
|
||||
|
||||
if hasColumnsPlaceholder {
|
||||
// The presence of the $columns placeholder means that the destination type
|
||||
// must be a struct, and we will plonk that struct's fields into the query.
|
||||
|
||||
if destType.Kind() != reflect.Struct {
|
||||
panic("$columns can only be used when querying into some kind of struct")
|
||||
}
|
||||
|
||||
var prefix []string
|
||||
prefixText := columnsMatch[2]
|
||||
if prefixText != "" {
|
||||
prefix = []string{prefixText}
|
||||
}
|
||||
|
||||
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
|
||||
|
||||
columns := make([]string, 0, len(columnNames))
|
||||
for _, strSlice := range columnNames {
|
||||
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
||||
fullName := strSlice[len(strSlice)-1]
|
||||
if tableName != "" {
|
||||
fullName = tableName + "." + fullName
|
||||
}
|
||||
columns = append(columns, fullName)
|
||||
}
|
||||
|
||||
columnNamesString := strings.Join(columns, ", ")
|
||||
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
|
||||
|
||||
return compiledQuery{
|
||||
query: query,
|
||||
destType: destType,
|
||||
fieldPaths: fieldPaths,
|
||||
}
|
||||
} else {
|
||||
return compiledQuery{
|
||||
query: query,
|
||||
destType: destType,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
|
||||
var columnNames []columnName
|
||||
var fieldPaths []fieldPath
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
destType = destType.Elem()
|
||||
}
|
||||
|
||||
if destType.Kind() != reflect.Struct {
|
||||
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
|
||||
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
|
||||
}
|
||||
|
||||
type AnonPrefix struct {
|
||||
|
@ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
|
|||
columnNames = append(columnNames, fieldColumnNames)
|
||||
fieldPaths = append(fieldPaths, path)
|
||||
} else if fieldType.Kind() == reflect.Struct {
|
||||
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
||||
columnNames = append(columnNames, subCols...)
|
||||
fieldPaths = append(fieldPaths, subPaths...)
|
||||
} else {
|
||||
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
|
||||
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return columnNames, fieldPaths, nil
|
||||
return columnNames, fieldPaths
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -10,52 +10,143 @@ import (
|
|||
|
||||
func TestPaths(t *testing.T) {
|
||||
type CustomInt int
|
||||
type S struct {
|
||||
I int `db:"I"`
|
||||
PI *int `db:"PI"`
|
||||
CI CustomInt `db:"CI"`
|
||||
PCI *CustomInt `db:"PCI"`
|
||||
B bool `db:"B"`
|
||||
PB *bool `db:"PB"`
|
||||
type S2 struct {
|
||||
B bool `db:"B"` // field 0
|
||||
PB *bool `db:"PB"` // field 1
|
||||
|
||||
NoTag int
|
||||
NoTag string // field 2
|
||||
}
|
||||
type S struct {
|
||||
I int `db:"I"` // field 0
|
||||
PI *int `db:"PI"` // field 1
|
||||
CI CustomInt `db:"CI"` // field 2
|
||||
PCI *CustomInt `db:"PCI"` // field 3
|
||||
S2 `db:"S2"` // field 4 (embedded!)
|
||||
PS2 *S2 `db:"PS2"` // field 5
|
||||
|
||||
NoTag int // field 6
|
||||
}
|
||||
type Nested struct {
|
||||
S S `db:"S"`
|
||||
PS *S `db:"PS"`
|
||||
S S `db:"S"` // field 0
|
||||
PS *S `db:"PS"` // field 1
|
||||
|
||||
NoTag S
|
||||
NoTag S // field 2
|
||||
}
|
||||
type Embedded struct {
|
||||
NoTag S
|
||||
Nested
|
||||
NoTag S // field 0
|
||||
Nested // field 1
|
||||
}
|
||||
|
||||
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, []string{
|
||||
"S.I", "S.PI",
|
||||
"S.CI", "S.PCI",
|
||||
"S.B", "S.PB",
|
||||
"PS.I", "PS.PI",
|
||||
"PS.CI", "PS.PCI",
|
||||
"PS.B", "PS.PB",
|
||||
}, names)
|
||||
assert.Equal(t, [][]int{
|
||||
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
|
||||
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
|
||||
}, paths)
|
||||
assert.True(t, len(names) == len(paths))
|
||||
}
|
||||
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
|
||||
assert.Equal(t, []columnName{
|
||||
{"S", "I"}, {"S", "PI"},
|
||||
{"S", "CI"}, {"S", "PCI"},
|
||||
{"S", "S2", "B"}, {"S", "S2", "PB"},
|
||||
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
|
||||
{"PS", "I"}, {"PS", "PI"},
|
||||
{"PS", "CI"}, {"PS", "PCI"},
|
||||
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
|
||||
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
|
||||
}, names)
|
||||
assert.Equal(t, []fieldPath{
|
||||
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
|
||||
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
|
||||
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
|
||||
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
|
||||
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
|
||||
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
|
||||
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
|
||||
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
|
||||
}, paths)
|
||||
assert.True(t, len(names) == len(paths))
|
||||
|
||||
testStruct := Embedded{}
|
||||
for i, path := range paths {
|
||||
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
||||
assert.True(t, val.IsValid())
|
||||
assert.True(t, strings.Contains(names[i], field.Name))
|
||||
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Run("simple struct", func(t *testing.T) {
|
||||
type Dest struct {
|
||||
Foo int `db:"foo"`
|
||||
Bar bool `db:"bar"`
|
||||
Nope string // no tag
|
||||
}
|
||||
|
||||
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
|
||||
})
|
||||
t.Run("complex structs", func(t *testing.T) {
|
||||
type CustomInt int
|
||||
type S2 struct {
|
||||
B bool `db:"B"`
|
||||
PB *bool `db:"PB"`
|
||||
|
||||
NoTag string
|
||||
}
|
||||
type S struct {
|
||||
I int `db:"I"`
|
||||
PI *int `db:"PI"`
|
||||
CI CustomInt `db:"CI"`
|
||||
PCI *CustomInt `db:"PCI"`
|
||||
S2 `db:"S2"` // embedded!
|
||||
PS2 *S2 `db:"PS2"`
|
||||
|
||||
NoTag int
|
||||
}
|
||||
type Nested struct {
|
||||
S S `db:"S"`
|
||||
PS *S `db:"PS"`
|
||||
|
||||
NoTag S
|
||||
}
|
||||
type Dest struct {
|
||||
NoTag S
|
||||
Nested
|
||||
}
|
||||
|
||||
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
|
||||
})
|
||||
t.Run("int", func(t *testing.T) {
|
||||
type Dest int
|
||||
|
||||
// There should be no error here because we do not need to extract columns from
|
||||
// the destination type. There may be errors down the line in value iteration, but
|
||||
// that is always the case if the Go types don't match the query.
|
||||
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
|
||||
})
|
||||
t.Run("just one table", func(t *testing.T) {
|
||||
type Dest struct {
|
||||
Foo int `db:"foo"`
|
||||
Bar bool `db:"bar"`
|
||||
Nope string // no tag
|
||||
}
|
||||
|
||||
// The prefix is necessary because otherwise we would have to provide a struct with
|
||||
// a db tag in order to provide the query with the `greeblies.` prefix in the
|
||||
// final query. This comes up a lot when we do a JOIN to help with a condition, but
|
||||
// don't actually care about any of the data we joined to.
|
||||
compiled := compileQuery(
|
||||
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
|
||||
reflect.TypeOf(Dest{}),
|
||||
)
|
||||
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
|
||||
})
|
||||
|
||||
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
|
||||
type Dest int
|
||||
|
||||
assert.Panics(t, func() {
|
||||
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryBuilder(t *testing.T) {
|
||||
t.Run("happy time", func(t *testing.T) {
|
||||
var qb QueryBuilder
|
||||
|
|
|
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
|||
userIds = append(userIds, u.User.ID)
|
||||
}
|
||||
|
||||
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
||||
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
||||
}
|
||||
|
||||
for _, ul := range userLinks {
|
||||
link := ul.(*models.Link)
|
||||
for _, link := range userLinks {
|
||||
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
||||
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
||||
}
|
||||
|
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
type userQuery struct {
|
||||
User models.User `db:"auth_user"`
|
||||
}
|
||||
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
||||
u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM auth_user
|
||||
FROM
|
||||
auth_user
|
||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||
WHERE auth_user.id = $1
|
||||
`,
|
||||
|
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
||||
}
|
||||
}
|
||||
user := u.(*userQuery).User
|
||||
user := u.User
|
||||
|
||||
whatHappened := ""
|
||||
if action == ApprovalQueueActionApprove {
|
||||
|
@ -337,7 +337,7 @@ type UnapprovedPost struct {
|
|||
}
|
||||
|
||||
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
|
||||
res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
|||
if err != nil {
|
||||
return nil, oops.New(err, "failed to fetch unapproved posts")
|
||||
}
|
||||
var res []*UnapprovedPost
|
||||
for _, iresult := range it {
|
||||
res = append(res, iresult.(*UnapprovedPost))
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
|
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
|||
type unapprovedUser struct {
|
||||
ID int `db:"id"`
|
||||
}
|
||||
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
|
||||
uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
|||
if err != nil {
|
||||
return nil, oops.New(err, "failed to fetch unapproved users")
|
||||
}
|
||||
ownerIDs := make([]int, 0, len(it))
|
||||
for _, uid := range it {
|
||||
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
|
||||
ownerIDs := make([]int, 0, len(uids))
|
||||
for _, uid := range uids {
|
||||
ownerIDs = append(ownerIDs, uid.ID)
|
||||
}
|
||||
|
||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||
|
|
|
@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
|||
type userQuery struct {
|
||||
User models.User `db:"auth_user"`
|
||||
}
|
||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
||||
userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
|
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
||||
}
|
||||
} else {
|
||||
canonicalUsername = userResult.(*userQuery).User.Username
|
||||
canonicalUsername = userResult.User.Username
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
|
|||
type userQuery struct {
|
||||
User models.User `db:"auth_user"`
|
||||
}
|
||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
||||
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM auth_user
|
||||
FROM
|
||||
auth_user
|
||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||
WHERE LOWER(username) = LOWER($1)
|
||||
`,
|
||||
|
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||
}
|
||||
}
|
||||
user := &userRow.(*userQuery).User
|
||||
user := &userRow.User
|
||||
|
||||
success, err := tryLogin(c, user, password)
|
||||
|
||||
|
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
type userQuery struct {
|
||||
User models.User `db:"auth_user"`
|
||||
}
|
||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
||||
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM auth_user
|
||||
|
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
}
|
||||
}
|
||||
if userRow != nil {
|
||||
user = &userRow.(*userQuery).User
|
||||
user = &userRow.User
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
c.Perf.StartBlock("SQL", "Fetching existing token")
|
||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
||||
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM handmade_onetimetoken
|
||||
|
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
||||
}
|
||||
}
|
||||
var resetToken *models.OneTimeToken
|
||||
if tokenRow != nil {
|
||||
resetToken = tokenRow.(*models.OneTimeToken)
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
if resetToken != nil {
|
||||
|
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
|
||||
if resetToken == nil {
|
||||
c.Perf.StartBlock("SQL", "Creating new token")
|
||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
||||
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||
`
|
||||
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
|
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
|||
if err != nil {
|
||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
||||
}
|
||||
resetToken = tokenRow.(*models.OneTimeToken)
|
||||
|
||||
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
||||
if err != nil {
|
||||
|
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
|||
User models.User `db:"auth_user"`
|
||||
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
||||
}
|
||||
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
|
||||
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
|
||||
`
|
||||
SELECT $columns
|
||||
FROM auth_user
|
||||
|
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
|||
return result
|
||||
}
|
||||
}
|
||||
if row != nil {
|
||||
data := row.(*userAndTokenQuery)
|
||||
if data != nil {
|
||||
result.User = &data.User
|
||||
result.OneTimeToken = data.OneTimeToken
|
||||
if result.OneTimeToken != nil {
|
||||
|
|
Loading…
Reference in New Issue