Start improvements to db API
Allowing for querying of scalars with db.Query, and a choice of prefix in $columns to remove the need for single-field structs in many queries.
This commit is contained in:
		
							parent
							
								
									e74b3273cb
								
							
						
					
					
						commit
						bd6c95203c
					
				
							
								
								
									
										103
									
								
								src/db/db.go
								
								
								
								
							
							
						
						
									
										103
									
								
								src/db/db.go
								
								
								
								
							|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"git.handmade.network/hmn/hmn/src/config" | 	"git.handmade.network/hmn/hmn/src/config" | ||||||
|  | @ -95,8 +96,14 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool { | ||||||
| 	return conn | 	return conn | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type columnName []string | ||||||
|  | 
 | ||||||
|  | // A path to a particular field in query's destination type. Each index in the slice
 | ||||||
|  | // corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
 | ||||||
|  | type fieldPath []int | ||||||
|  | 
 | ||||||
| type StructQueryIterator[T any] struct { | type StructQueryIterator[T any] struct { | ||||||
| 	fieldPaths [][]int | 	fieldPaths []fieldPath | ||||||
| 	rows       pgx.Rows | 	rows       pgx.Rows | ||||||
| 	destType   reflect.Type | 	destType   reflect.Type | ||||||
| 	closed     chan struct{} | 	closed     chan struct{} | ||||||
|  | @ -243,25 +250,10 @@ func Query[T any](ctx context.Context, conn ConnOrTx, query string, args ...inte | ||||||
| func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { | func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (*StructQueryIterator[T], error) { | ||||||
| 	var destExample T | 	var destExample T | ||||||
| 	destType := reflect.TypeOf(destExample) | 	destType := reflect.TypeOf(destExample) | ||||||
| 	columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, oops.New(err, "failed to generate column names") |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	columns := make([]string, 0, len(columnNames)) | 	compiled := compileQuery(query, destType) | ||||||
| 	for _, strSlice := range columnNames { |  | ||||||
| 		tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") |  | ||||||
| 		fullName := strSlice[len(strSlice)-1] |  | ||||||
| 		if tableName != "" { |  | ||||||
| 			fullName = tableName + "." + fullName |  | ||||||
| 		} |  | ||||||
| 		columns = append(columns, fullName) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	columnNamesString := strings.Join(columns, ", ") | 	rows, err := conn.Query(ctx, compiled.query, args...) | ||||||
| 	query = strings.Replace(query, "$columns", columnNamesString, -1) |  | ||||||
| 
 |  | ||||||
| 	rows, err := conn.Query(ctx, query, args...) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if errors.Is(err, context.DeadlineExceeded) { | 		if errors.Is(err, context.DeadlineExceeded) { | ||||||
| 			panic("query exceeded its deadline") | 			panic("query exceeded its deadline") | ||||||
|  | @ -270,9 +262,9 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	it := &StructQueryIterator[T]{ | 	it := &StructQueryIterator[T]{ | ||||||
| 		fieldPaths: fieldPaths, | 		fieldPaths: compiled.fieldPaths, | ||||||
| 		rows:       rows, | 		rows:       rows, | ||||||
| 		destType:   destType, | 		destType:   compiled.destType, | ||||||
| 		closed:     make(chan struct{}, 1), | 		closed:     make(chan struct{}, 1), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -293,16 +285,70 @@ func QueryIterator[T any](ctx context.Context, conn ConnOrTx, query string, args | ||||||
| 	return it, nil | 	return it, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) { | type compiledQuery struct { | ||||||
| 	var columnNames [][]string | 	query      string | ||||||
| 	var fieldPaths [][]int | 	destType   reflect.Type | ||||||
|  | 	fieldPaths []fieldPath | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`) | ||||||
|  | 
 | ||||||
|  | func compileQuery(query string, destType reflect.Type) compiledQuery { | ||||||
|  | 	columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query) | ||||||
|  | 	hasColumnsPlaceholder := columnsMatch != nil | ||||||
|  | 
 | ||||||
|  | 	if hasColumnsPlaceholder { | ||||||
|  | 		// The presence of the $columns placeholder means that the destination type
 | ||||||
|  | 		// must be a struct, and we will plonk that struct's fields into the query.
 | ||||||
|  | 
 | ||||||
|  | 		if destType.Kind() != reflect.Struct { | ||||||
|  | 			panic("$columns can only be used when querying into some kind of struct") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var prefix []string | ||||||
|  | 		prefixText := columnsMatch[2] | ||||||
|  | 		if prefixText != "" { | ||||||
|  | 			prefix = []string{prefixText} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix) | ||||||
|  | 
 | ||||||
|  | 		columns := make([]string, 0, len(columnNames)) | ||||||
|  | 		for _, strSlice := range columnNames { | ||||||
|  | 			tableName := strings.Join(strSlice[0:len(strSlice)-1], "_") | ||||||
|  | 			fullName := strSlice[len(strSlice)-1] | ||||||
|  | 			if tableName != "" { | ||||||
|  | 				fullName = tableName + "." + fullName | ||||||
|  | 			} | ||||||
|  | 			columns = append(columns, fullName) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		columnNamesString := strings.Join(columns, ", ") | ||||||
|  | 		query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString) | ||||||
|  | 
 | ||||||
|  | 		return compiledQuery{ | ||||||
|  | 			query:      query, | ||||||
|  | 			destType:   destType, | ||||||
|  | 			fieldPaths: fieldPaths, | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		return compiledQuery{ | ||||||
|  | 			query:    query, | ||||||
|  | 			destType: destType, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) { | ||||||
|  | 	var columnNames []columnName | ||||||
|  | 	var fieldPaths []fieldPath | ||||||
| 
 | 
 | ||||||
| 	if destType.Kind() == reflect.Ptr { | 	if destType.Kind() == reflect.Ptr { | ||||||
| 		destType = destType.Elem() | 		destType = destType.Elem() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if destType.Kind() != reflect.Struct { | 	if destType.Kind() != reflect.Struct { | ||||||
| 		return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix) | 		panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	type AnonPrefix struct { | 	type AnonPrefix struct { | ||||||
|  | @ -349,19 +395,16 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str | ||||||
| 				columnNames = append(columnNames, fieldColumnNames) | 				columnNames = append(columnNames, fieldColumnNames) | ||||||
| 				fieldPaths = append(fieldPaths, path) | 				fieldPaths = append(fieldPaths, path) | ||||||
| 			} else if fieldType.Kind() == reflect.Struct { | 			} else if fieldType.Kind() == reflect.Struct { | ||||||
| 				subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) | 				subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames) | ||||||
| 				if err != nil { |  | ||||||
| 					return nil, nil, err |  | ||||||
| 				} |  | ||||||
| 				columnNames = append(columnNames, subCols...) | 				columnNames = append(columnNames, subCols...) | ||||||
| 				fieldPaths = append(fieldPaths, subPaths...) | 				fieldPaths = append(fieldPaths, subPaths...) | ||||||
| 			} else { | 			} else { | ||||||
| 				return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type) | 				panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return columnNames, fieldPaths, nil | 	return columnNames, fieldPaths | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* | /* | ||||||
|  |  | ||||||
|  | @ -10,13 +10,90 @@ import ( | ||||||
| 
 | 
 | ||||||
| func TestPaths(t *testing.T) { | func TestPaths(t *testing.T) { | ||||||
| 	type CustomInt int | 	type CustomInt int | ||||||
|  | 	type S2 struct { | ||||||
|  | 		B  bool  `db:"B"`  // field 0
 | ||||||
|  | 		PB *bool `db:"PB"` // field 1
 | ||||||
|  | 
 | ||||||
|  | 		NoTag string // field 2
 | ||||||
|  | 	} | ||||||
|  | 	type S struct { | ||||||
|  | 		I   int        `db:"I"`   // field 0
 | ||||||
|  | 		PI  *int       `db:"PI"`  // field 1
 | ||||||
|  | 		CI  CustomInt  `db:"CI"`  // field 2
 | ||||||
|  | 		PCI *CustomInt `db:"PCI"` // field 3
 | ||||||
|  | 		S2  `db:"S2"`  // field 4 (embedded!)
 | ||||||
|  | 		PS2 *S2        `db:"PS2"` // field 5
 | ||||||
|  | 
 | ||||||
|  | 		NoTag int // field 6
 | ||||||
|  | 	} | ||||||
|  | 	type Nested struct { | ||||||
|  | 		S  S  `db:"S"`  // field 0
 | ||||||
|  | 		PS *S `db:"PS"` // field 1
 | ||||||
|  | 
 | ||||||
|  | 		NoTag S // field 2
 | ||||||
|  | 	} | ||||||
|  | 	type Embedded struct { | ||||||
|  | 		NoTag  S // field 0
 | ||||||
|  | 		Nested   // field 1
 | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil) | ||||||
|  | 	assert.Equal(t, []columnName{ | ||||||
|  | 		{"S", "I"}, {"S", "PI"}, | ||||||
|  | 		{"S", "CI"}, {"S", "PCI"}, | ||||||
|  | 		{"S", "S2", "B"}, {"S", "S2", "PB"}, | ||||||
|  | 		{"S", "PS2", "B"}, {"S", "PS2", "PB"}, | ||||||
|  | 		{"PS", "I"}, {"PS", "PI"}, | ||||||
|  | 		{"PS", "CI"}, {"PS", "PCI"}, | ||||||
|  | 		{"PS", "S2", "B"}, {"PS", "S2", "PB"}, | ||||||
|  | 		{"PS", "PS2", "B"}, {"PS", "PS2", "PB"}, | ||||||
|  | 	}, names) | ||||||
|  | 	assert.Equal(t, []fieldPath{ | ||||||
|  | 		{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
 | ||||||
|  | 		{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
 | ||||||
|  | 		{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
 | ||||||
|  | 		{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
 | ||||||
|  | 		{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
 | ||||||
|  | 		{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
 | ||||||
|  | 		{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
 | ||||||
|  | 		{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
 | ||||||
|  | 	}, paths) | ||||||
|  | 	assert.True(t, len(names) == len(paths)) | ||||||
|  | 
 | ||||||
|  | 	testStruct := Embedded{} | ||||||
|  | 	for i, path := range paths { | ||||||
|  | 		val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) | ||||||
|  | 		assert.True(t, val.IsValid()) | ||||||
|  | 		assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name)) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCompileQuery(t *testing.T) { | ||||||
|  | 	t.Run("simple struct", func(t *testing.T) { | ||||||
|  | 		type Dest struct { | ||||||
|  | 			Foo  int    `db:"foo"` | ||||||
|  | 			Bar  bool   `db:"bar"` | ||||||
|  | 			Nope string // no tag
 | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) | ||||||
|  | 		assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query) | ||||||
|  | 	}) | ||||||
|  | 	t.Run("complex structs", func(t *testing.T) { | ||||||
|  | 		type CustomInt int | ||||||
|  | 		type S2 struct { | ||||||
|  | 			B  bool  `db:"B"` | ||||||
|  | 			PB *bool `db:"PB"` | ||||||
|  | 
 | ||||||
|  | 			NoTag string | ||||||
|  | 		} | ||||||
| 		type S struct { | 		type S struct { | ||||||
| 			I   int        `db:"I"` | 			I   int        `db:"I"` | ||||||
| 			PI  *int       `db:"PI"` | 			PI  *int       `db:"PI"` | ||||||
| 			CI  CustomInt  `db:"CI"` | 			CI  CustomInt  `db:"CI"` | ||||||
| 			PCI *CustomInt `db:"PCI"` | 			PCI *CustomInt `db:"PCI"` | ||||||
| 		B   bool       `db:"B"` | 			S2  `db:"S2"`  // embedded!
 | ||||||
| 		PB  *bool      `db:"PB"` | 			PS2 *S2        `db:"PS2"` | ||||||
| 
 | 
 | ||||||
| 			NoTag int | 			NoTag int | ||||||
| 		} | 		} | ||||||
|  | @ -26,34 +103,48 @@ func TestPaths(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 			NoTag S | 			NoTag S | ||||||
| 		} | 		} | ||||||
| 	type Embedded struct { | 		type Dest struct { | ||||||
| 			NoTag S | 			NoTag S | ||||||
| 			Nested | 			Nested | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "") | 		compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{})) | ||||||
| 	if assert.Nil(t, err) { | 		assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query) | ||||||
| 		assert.Equal(t, []string{ | 	}) | ||||||
| 			"S.I", "S.PI", | 	t.Run("int", func(t *testing.T) { | ||||||
| 			"S.CI", "S.PCI", | 		type Dest int | ||||||
| 			"S.B", "S.PB", | 
 | ||||||
| 			"PS.I", "PS.PI", | 		// There should be no error here because we do not need to extract columns from
 | ||||||
| 			"PS.CI", "PS.PCI", | 		// the destination type. There may be errors down the line in value iteration, but
 | ||||||
| 			"PS.B", "PS.PB", | 		// that is always the case if the Go types don't match the query.
 | ||||||
| 		}, names) | 		compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0))) | ||||||
| 		assert.Equal(t, [][]int{ | 		assert.Equal(t, "SELECT id FROM greeblies", compiled.query) | ||||||
| 			{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5}, | 	}) | ||||||
| 			{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5}, | 	t.Run("just one table", func(t *testing.T) { | ||||||
| 		}, paths) | 		type Dest struct { | ||||||
| 		assert.True(t, len(names) == len(paths)) | 			Foo  int    `db:"foo"` | ||||||
|  | 			Bar  bool   `db:"bar"` | ||||||
|  | 			Nope string // no tag
 | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	testStruct := Embedded{} | 		// The prefix is necessary because otherwise we would have to provide a struct with
 | ||||||
| 	for i, path := range paths { | 		// a db tag in order to provide the query with the `greeblies.` prefix in the
 | ||||||
| 		val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path) | 		// final query. This comes up a lot when we do a JOIN to help with a condition, but
 | ||||||
| 		assert.True(t, val.IsValid()) | 		// don't actually care about any of the data we joined to.
 | ||||||
| 		assert.True(t, strings.Contains(names[i], field.Name)) | 		compiled := compileQuery( | ||||||
| 	} | 			"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props", | ||||||
|  | 			reflect.TypeOf(Dest{}), | ||||||
|  | 		) | ||||||
|  | 		assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("using $columns without a struct is not allowed", func(t *testing.T) { | ||||||
|  | 		type Dest int | ||||||
|  | 
 | ||||||
|  | 		assert.Panics(t, func() { | ||||||
|  | 			compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0))) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestQueryBuilder(t *testing.T) { | func TestQueryBuilder(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { | ||||||
| 		userIds = append(userIds, u.User.ID) | 		userIds = append(userIds, u.User.ID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	userLinks, err := db.Query(c.Context(), c.Conn, models.Link{}, | 	userLinks, err := db.Query[models.Link](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 		FROM | 		FROM | ||||||
|  | @ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData { | ||||||
| 		return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) | 		return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links")) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, ul := range userLinks { | 	for _, link := range userLinks { | ||||||
| 		link := ul.(*models.Link) |  | ||||||
| 		userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] | 		userData := unapprovedUsers[userIDToDataIdx[*link.UserID]] | ||||||
| 		userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) | 		userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link)) | ||||||
| 	} | 	} | ||||||
|  | @ -260,10 +259,11 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { | ||||||
| 	type userQuery struct { | 	type userQuery struct { | ||||||
| 		User models.User `db:"auth_user"` | 		User models.User `db:"auth_user"` | ||||||
| 	} | 	} | ||||||
| 	u, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, | 	u, err := db.QueryOne[userQuery](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 			FROM auth_user | 		FROM | ||||||
|  | 			auth_user | ||||||
| 			LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id | 			LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id | ||||||
| 		WHERE auth_user.id = $1 | 		WHERE auth_user.id = $1 | ||||||
| 		`, | 		`, | ||||||
|  | @ -276,7 +276,7 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData { | ||||||
| 			return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) | 			return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user")) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	user := u.(*userQuery).User | 	user := u.User | ||||||
| 
 | 
 | ||||||
| 	whatHappened := "" | 	whatHappened := "" | ||||||
| 	if action == ApprovalQueueActionApprove { | 	if action == ApprovalQueueActionApprove { | ||||||
|  | @ -337,7 +337,7 @@ type UnapprovedPost struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { | func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { | ||||||
| 	it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{}, | 	res, err := db.Query[UnapprovedPost](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 		FROM | 		FROM | ||||||
|  | @ -358,10 +358,6 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, oops.New(err, "failed to fetch unapproved posts") | 		return nil, oops.New(err, "failed to fetch unapproved posts") | ||||||
| 	} | 	} | ||||||
| 	var res []*UnapprovedPost |  | ||||||
| 	for _, iresult := range it { |  | ||||||
| 		res = append(res, iresult.(*UnapprovedPost)) |  | ||||||
| 	} |  | ||||||
| 	return res, nil | 	return res, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -375,7 +371,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { | ||||||
| 	type unapprovedUser struct { | 	type unapprovedUser struct { | ||||||
| 		ID int `db:"id"` | 		ID int `db:"id"` | ||||||
| 	} | 	} | ||||||
| 	it, err := db.Query(c.Context(), c.Conn, unapprovedUser{}, | 	uids, err := db.Query[unapprovedUser](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 		FROM | 		FROM | ||||||
|  | @ -388,9 +384,9 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, oops.New(err, "failed to fetch unapproved users") | 		return nil, oops.New(err, "failed to fetch unapproved users") | ||||||
| 	} | 	} | ||||||
| 	ownerIDs := make([]int, 0, len(it)) | 	ownerIDs := make([]int, 0, len(uids)) | ||||||
| 	for _, uid := range it { | 	for _, uid := range uids { | ||||||
| 		ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID) | 		ownerIDs = append(ownerIDs, uid.ID) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ | 	projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{ | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ func APICheckUsername(c *RequestContext) ResponseData { | ||||||
| 		type userQuery struct { | 		type userQuery struct { | ||||||
| 			User models.User `db:"auth_user"` | 			User models.User `db:"auth_user"` | ||||||
| 		} | 		} | ||||||
| 		userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, | 		userResult, err := db.QueryOne[userQuery](c.Context(), c.Conn, | ||||||
| 			` | 			` | ||||||
| 			SELECT $columns | 			SELECT $columns | ||||||
| 			FROM | 			FROM | ||||||
|  | @ -43,7 +43,7 @@ func APICheckUsername(c *RequestContext) ResponseData { | ||||||
| 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) | 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername)) | ||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			canonicalUsername = userResult.(*userQuery).User.Username | 			canonicalUsername = userResult.User.Username | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -78,10 +78,11 @@ func Login(c *RequestContext) ResponseData { | ||||||
| 	type userQuery struct { | 	type userQuery struct { | ||||||
| 		User models.User `db:"auth_user"` | 		User models.User `db:"auth_user"` | ||||||
| 	} | 	} | ||||||
| 	userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, | 	userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 			FROM auth_user | 		FROM | ||||||
|  | 			auth_user | ||||||
| 			LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id | 			LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id | ||||||
| 		WHERE LOWER(username) = LOWER($1) | 		WHERE LOWER(username) = LOWER($1) | ||||||
| 		`, | 		`, | ||||||
|  | @ -94,7 +95,7 @@ func Login(c *RequestContext) ResponseData { | ||||||
| 			return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) | 			return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	user := &userRow.(*userQuery).User | 	user := &userRow.User | ||||||
| 
 | 
 | ||||||
| 	success, err := tryLogin(c, user, password) | 	success, err := tryLogin(c, user, password) | ||||||
| 
 | 
 | ||||||
|  | @ -460,7 +461,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { | ||||||
| 	type userQuery struct { | 	type userQuery struct { | ||||||
| 		User models.User `db:"auth_user"` | 		User models.User `db:"auth_user"` | ||||||
| 	} | 	} | ||||||
| 	userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{}, | 	userRow, err := db.QueryOne[userQuery](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 		FROM auth_user | 		FROM auth_user | ||||||
|  | @ -479,12 +480,12 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if userRow != nil { | 	if userRow != nil { | ||||||
| 		user = &userRow.(*userQuery).User | 		user = &userRow.User | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if user != nil { | 	if user != nil { | ||||||
| 		c.Perf.StartBlock("SQL", "Fetching existing token") | 		c.Perf.StartBlock("SQL", "Fetching existing token") | ||||||
| 		tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, | 		resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, | ||||||
| 			` | 			` | ||||||
| 			SELECT $columns | 			SELECT $columns | ||||||
| 			FROM handmade_onetimetoken | 			FROM handmade_onetimetoken | ||||||
|  | @ -501,10 +502,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { | ||||||
| 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) | 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user")) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		var resetToken *models.OneTimeToken |  | ||||||
| 		if tokenRow != nil { |  | ||||||
| 			resetToken = tokenRow.(*models.OneTimeToken) |  | ||||||
| 		} |  | ||||||
| 		now := time.Now() | 		now := time.Now() | ||||||
| 
 | 
 | ||||||
| 		if resetToken != nil { | 		if resetToken != nil { | ||||||
|  | @ -527,7 +524,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { | ||||||
| 
 | 
 | ||||||
| 		if resetToken == nil { | 		if resetToken == nil { | ||||||
| 			c.Perf.StartBlock("SQL", "Creating new token") | 			c.Perf.StartBlock("SQL", "Creating new token") | ||||||
| 			tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{}, | 			resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn, | ||||||
| 				` | 				` | ||||||
| 				INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) | 				INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id) | ||||||
| 				VALUES ($1, $2, $3, $4, $5) | 				VALUES ($1, $2, $3, $4, $5) | ||||||
|  | @ -543,7 +540,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData { | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) | 				return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken")) | ||||||
| 			} | 			} | ||||||
| 			resetToken = tokenRow.(*models.OneTimeToken) |  | ||||||
| 
 | 
 | ||||||
| 			err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) | 			err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
|  | @ -787,7 +783,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, | ||||||
| 		User         models.User          `db:"auth_user"` | 		User         models.User          `db:"auth_user"` | ||||||
| 		OneTimeToken *models.OneTimeToken `db:"onetimetoken"` | 		OneTimeToken *models.OneTimeToken `db:"onetimetoken"` | ||||||
| 	} | 	} | ||||||
| 	row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{}, | 	data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn, | ||||||
| 		` | 		` | ||||||
| 		SELECT $columns | 		SELECT $columns | ||||||
| 		FROM auth_user | 		FROM auth_user | ||||||
|  | @ -807,8 +803,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string, | ||||||
| 			return result | 			return result | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if row != nil { | 	if data != nil { | ||||||
| 		data := row.(*userAndTokenQuery) |  | ||||||
| 		result.User = &data.User | 		result.User = &data.User | ||||||
| 		result.OneTimeToken = data.OneTimeToken | 		result.OneTimeToken = data.OneTimeToken | ||||||
| 		if result.OneTimeToken != nil { | 		if result.OneTimeToken != nil { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue