Start improvements to db API
Allowing for querying of scalars with db.Query, and a choice of prefix in $columns to remove the need for single-field structs in many queries.
This commit is contained in:
parent
e74b3273cb
commit
bd6c95203c
103
src/db/db.go
103
src/db/db.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/src/config"
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
|
@ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type columnName []string
|
||||||
|
|
||||||
|
// A path to a particular field in query's destination type. Each index in the slice
|
||||||
|
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
|
||||||
|
type fieldPath []int
|
||||||
|
|
||||||
type StructQueryIterator[T any] struct {
|
type StructQueryIterator[T any] struct {
|
||||||
fieldPaths [][]int
|
fieldPaths []fieldPath
|
||||||
rows pgx.Rows
|
rows pgx.Rows
|
||||||
destType reflect.Type
|
destType reflect.Type
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
@ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte
|
||||||
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
|
func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) {
|
||||||
var destExample T
|
var destExample T
|
||||||
destType := reflect.TypeOf(destExample)
|
destType := reflect.TypeOf(destExample)
|
||||||
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, oops.New(err, "failed to generate column names")
|
|
||||||
}
|
|
||||||
|
|
||||||
columns := make([]string, 0, len(columnNames))
|
compiled := compileQuery(query, destType)
|
||||||
for _, strSlice := range columnNames {
|
|
||||||
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
|
||||||
fullName := strSlice[len(strSlice)-1]
|
|
||||||
if tableName != "" {
|
|
||||||
fullName = tableName + "." + fullName
|
|
||||||
}
|
|
||||||
columns = append(columns, fullName)
|
|
||||||
}
|
|
||||||
|
|
||||||
columnNamesString := strings.Join(columns, ", ")
|
rows, err := conn.Query(ctx, compiled.query, args...)
|
||||||
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
|
||||||
|
|
||||||
rows, err := conn.Query(ctx, query, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
panic("query exceeded its deadline")
|
panic("query exceeded its deadline")
|
||||||
|
@ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
|
||||||
}
|
}
|
||||||
|
|
||||||
it := &StructQueryIterator[T]{
|
it := &StructQueryIterator[T]{
|
||||||
fieldPaths: fieldPaths,
|
fieldPaths: compiled.fieldPaths,
|
||||||
rows: rows,
|
rows: rows,
|
||||||
destType: destType,
|
destType: compiled.destType,
|
||||||
closed: make(chan struct{}, 1),
|
closed: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args
|
||||||
return it, nil
|
return it, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
|
type compiledQuery struct {
|
||||||
var columnNames [][]string
|
query string
|
||||||
var fieldPaths [][]int
|
destType reflect.Type
|
||||||
|
fieldPaths []fieldPath
|
||||||
|
}
|
||||||
|
|
||||||
|
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
|
||||||
|
|
||||||
|
func compileQuery(query string, destType reflect.Type) compiledQuery {
|
||||||
|
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
|
||||||
|
hasColumnsPlaceholder := columnsMatch != nil
|
||||||
|
|
||||||
|
if hasColumnsPlaceholder {
|
||||||
|
// The presence of the $columns placeholder means that the destination type
|
||||||
|
// must be a struct, and we will plonk that struct's fields into the query.
|
||||||
|
|
||||||
|
if destType.Kind() != reflect.Struct {
|
||||||
|
panic("$columns can only be used when querying into some kind of struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefix []string
|
||||||
|
prefixText := columnsMatch[2]
|
||||||
|
if prefixText != "" {
|
||||||
|
prefix = []string{prefixText}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(columnNames))
|
||||||
|
for _, strSlice := range columnNames {
|
||||||
|
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
||||||
|
fullName := strSlice[len(strSlice)-1]
|
||||||
|
if tableName != "" {
|
||||||
|
fullName = tableName + "." + fullName
|
||||||
|
}
|
||||||
|
columns = append(columns, fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNamesString := strings.Join(columns, ", ")
|
||||||
|
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
|
||||||
|
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
fieldPaths: fieldPaths,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
|
||||||
|
var columnNames []columnName
|
||||||
|
var fieldPaths []fieldPath
|
||||||
|
|
||||||
if destType.Kind() == reflect.Ptr {
|
if destType.Kind() == reflect.Ptr {
|
||||||
destType = destType.Elem()
|
destType = destType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if destType.Kind() != reflect.Struct {
|
if destType.Kind() != reflect.Struct {
|
||||||
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
|
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnonPrefix struct {
|
type AnonPrefix struct {
|
||||||
|
@ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
|
||||||
columnNames = append(columnNames, fieldColumnNames)
|
columnNames = append(columnNames, fieldColumnNames)
|
||||||
fieldPaths = append(fieldPaths, path)
|
fieldPaths = append(fieldPaths, path)
|
||||||
} else if fieldType.Kind() == reflect.Struct {
|
} else if fieldType.Kind() == reflect.Struct {
|
||||||
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
columnNames = append(columnNames, subCols...)
|
columnNames = append(columnNames, subCols...)
|
||||||
fieldPaths = append(fieldPaths, subPaths...)
|
fieldPaths = append(fieldPaths, subPaths...)
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
|
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columnNames, fieldPaths, nil
|
return columnNames, fieldPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -10,13 +10,90 @@ import (
|
||||||
|
|
||||||
func TestPaths(t *testing.T) {
|
func TestPaths(t *testing.T) {
|
||||||
type CustomInt int
|
type CustomInt int
|
||||||
|
type S2 struct {
|
||||||
|
B bool `db:"B"` // field 0
|
||||||
|
PB *bool `db:"PB"` // field 1
|
||||||
|
|
||||||
|
NoTag string // field 2
|
||||||
|
}
|
||||||
|
type S struct {
|
||||||
|
I int `db:"I"` // field 0
|
||||||
|
PI *int `db:"PI"` // field 1
|
||||||
|
CI CustomInt `db:"CI"` // field 2
|
||||||
|
PCI *CustomInt `db:"PCI"` // field 3
|
||||||
|
S2 `db:"S2"` // field 4 (embedded!)
|
||||||
|
PS2 *S2 `db:"PS2"` // field 5
|
||||||
|
|
||||||
|
NoTag int // field 6
|
||||||
|
}
|
||||||
|
type Nested struct {
|
||||||
|
S S `db:"S"` // field 0
|
||||||
|
PS *S `db:"PS"` // field 1
|
||||||
|
|
||||||
|
NoTag S // field 2
|
||||||
|
}
|
||||||
|
type Embedded struct {
|
||||||
|
NoTag S // field 0
|
||||||
|
Nested // field 1
|
||||||
|
}
|
||||||
|
|
||||||
|
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
|
||||||
|
assert.Equal(t, []columnName{
|
||||||
|
{"S", "I"}, {"S", "PI"},
|
||||||
|
{"S", "CI"}, {"S", "PCI"},
|
||||||
|
{"S", "S2", "B"}, {"S", "S2", "PB"},
|
||||||
|
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
|
||||||
|
{"PS", "I"}, {"PS", "PI"},
|
||||||
|
{"PS", "CI"}, {"PS", "PCI"},
|
||||||
|
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
|
||||||
|
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
|
||||||
|
}, names)
|
||||||
|
assert.Equal(t, []fieldPath{
|
||||||
|
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
|
||||||
|
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
|
||||||
|
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
|
||||||
|
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
|
||||||
|
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
|
||||||
|
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
|
||||||
|
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
|
||||||
|
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
|
||||||
|
}, paths)
|
||||||
|
assert.True(t, len(names) == len(paths))
|
||||||
|
|
||||||
|
testStruct := Embedded{}
|
||||||
|
for i, path := range paths {
|
||||||
|
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
||||||
|
assert.True(t, val.IsValid())
|
||||||
|
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileQuery(t *testing.T) {
|
||||||
|
t.Run("simple struct", func(t *testing.T) {
|
||||||
|
type Dest struct {
|
||||||
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
|
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
|
||||||
|
})
|
||||||
|
t.Run("complex structs", func(t *testing.T) {
|
||||||
|
type CustomInt int
|
||||||
|
type S2 struct {
|
||||||
|
B bool `db:"B"`
|
||||||
|
PB *bool `db:"PB"`
|
||||||
|
|
||||||
|
NoTag string
|
||||||
|
}
|
||||||
type S struct {
|
type S struct {
|
||||||
I int `db:"I"`
|
I int `db:"I"`
|
||||||
PI *int `db:"PI"`
|
PI *int `db:"PI"`
|
||||||
CI CustomInt `db:"CI"`
|
CI CustomInt `db:"CI"`
|
||||||
PCI *CustomInt `db:"PCI"`
|
PCI *CustomInt `db:"PCI"`
|
||||||
B bool `db:"B"`
|
S2 `db:"S2"` // embedded!
|
||||||
PB *bool `db:"PB"`
|
PS2 *S2 `db:"PS2"`
|
||||||
|
|
||||||
NoTag int
|
NoTag int
|
||||||
}
|
}
|
||||||
|
@ -26,34 +103,48 @@ func TestPaths(t *testing.T) {
|
||||||
|
|
||||||
NoTag S
|
NoTag S
|
||||||
}
|
}
|
||||||
type Embedded struct {
|
type Dest struct {
|
||||||
NoTag S
|
NoTag S
|
||||||
Nested
|
Nested
|
||||||
}
|
}
|
||||||
|
|
||||||
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
if assert.Nil(t, err) {
|
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
|
||||||
assert.Equal(t, []string{
|
})
|
||||||
"S.I", "S.PI",
|
t.Run("int", func(t *testing.T) {
|
||||||
"S.CI", "S.PCI",
|
type Dest int
|
||||||
"S.B", "S.PB",
|
|
||||||
"PS.I", "PS.PI",
|
// There should be no error here because we do not need to extract columns from
|
||||||
"PS.CI", "PS.PCI",
|
// the destination type. There may be errors down the line in value iteration, but
|
||||||
"PS.B", "PS.PB",
|
// that is always the case if the Go types don't match the query.
|
||||||
}, names)
|
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
assert.Equal(t, [][]int{
|
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
|
||||||
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
|
})
|
||||||
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
|
t.Run("just one table", func(t *testing.T) {
|
||||||
}, paths)
|
type Dest struct {
|
||||||
assert.True(t, len(names) == len(paths))
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
}
|
}
|
||||||
|
|
||||||
testStruct := Embedded{}
|
// The prefix is necessary because otherwise we would have to provide a struct with
|
||||||
for i, path := range paths {
|
// a db tag in order to provide the query with the `greeblies.` prefix in the
|
||||||
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
// final query. This comes up a lot when we do a JOIN to help with a condition, but
|
||||||
assert.True(t, val.IsValid())
|
// don't actually care about any of the data we joined to.
|
||||||
assert.True(t, strings.Contains(names[i], field.Name))
|
compiled := compileQuery(
|
||||||
}
|
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
|
||||||
|
reflect.TypeOf(Dest{}),
|
||||||
|
)
|
||||||
|
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
|
||||||
|
type Dest int
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryBuilder(t *testing.T) {
|
func TestQueryBuilder(t *testing.T) {
|
||||||
|
|
|
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
userIds = append(userIds, u.User.ID)
|
userIds = append(userIds, u.User.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ul := range userLinks {
|
for _, link := range userLinks {
|
||||||
link := ul.(*models.Link)
|
|
||||||
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
||||||
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
||||||
}
|
}
|
||||||
|
@ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
u, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE auth_user.id = $1
|
WHERE auth_user.id = $1
|
||||||
`,
|
`,
|
||||||
|
@ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := u.(*userQuery).User
|
user := u.User
|
||||||
|
|
||||||
whatHappened := ""
|
whatHappened := ""
|
||||||
if action == ApprovalQueueActionApprove {
|
if action == ApprovalQueueActionApprove {
|
||||||
|
@ -337,7 +337,7 @@ type UnapprovedPost struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
|
res, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved posts")
|
return nil, oops.New(err, "failed to fetch unapproved posts")
|
||||||
}
|
}
|
||||||
var res []*UnapprovedPost
|
|
||||||
for _, iresult := range it {
|
|
||||||
res = append(res, iresult.(*UnapprovedPost))
|
|
||||||
}
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
type unapprovedUser struct {
|
type unapprovedUser struct {
|
||||||
ID int `db:"id"`
|
ID int `db:"id"`
|
||||||
}
|
}
|
||||||
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
|
uids, err := db.Query[unapprovedUser](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved users")
|
return nil, oops.New(err, "failed to fetch unapproved users")
|
||||||
}
|
}
|
||||||
ownerIDs := make([]int, 0, len(it))
|
ownerIDs := make([]int, 0, len(uids))
|
||||||
for _, uid := range it {
|
for _, uid := range uids {
|
||||||
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
|
ownerIDs = append(ownerIDs, uid.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||||
|
|
|
@ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
canonicalUsername = userResult.(*userQuery).User.Username
|
canonicalUsername = userResult.User.Username
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData {
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE LOWER(username) = LOWER($1)
|
WHERE LOWER(username) = LOWER($1)
|
||||||
`,
|
`,
|
||||||
|
@ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := &userRow.(*userQuery).User
|
user := &userRow.User
|
||||||
|
|
||||||
success, err := tryLogin(c, user, password)
|
success, err := tryLogin(c, user, password)
|
||||||
|
|
||||||
|
@ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if userRow != nil {
|
if userRow != nil {
|
||||||
user = &userRow.(*userQuery).User
|
user = &userRow.User
|
||||||
}
|
}
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
c.Perf.StartBlock("SQL", "Fetching existing token")
|
c.Perf.StartBlock("SQL", "Fetching existing token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_onetimetoken
|
FROM handmade_onetimetoken
|
||||||
|
@ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var resetToken *models.OneTimeToken
|
|
||||||
if tokenRow != nil {
|
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
|
||||||
}
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
if resetToken != nil {
|
if resetToken != nil {
|
||||||
|
@ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
if resetToken == nil {
|
if resetToken == nil {
|
||||||
c.Perf.StartBlock("SQL", "Creating new token")
|
c.Perf.StartBlock("SQL", "Creating new token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
||||||
}
|
}
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
|
||||||
|
|
||||||
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
||||||
}
|
}
|
||||||
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
|
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if row != nil {
|
if data != nil {
|
||||||
data := row.(*userAndTokenQuery)
|
|
||||||
result.User = &data.User
|
result.User = &data.User
|
||||||
result.OneTimeToken = data.OneTimeToken
|
result.OneTimeToken = data.OneTimeToken
|
||||||
if result.OneTimeToken != nil {
|
if result.OneTimeToken != nil {
|
||||||
|
|
Loading…
Reference in New Issue