Rework DB API
This takes advantage of generics, and generally clears up a lot of inconsistencies and quality-of-life issues. Start of db rework: clean up, start generics, improve tests Write some nice aspirational package docs Rework and document the db API Tests still pass, at least...now for everything else Update all callsites of db functions Finish converting all callsites Not too bad actually! Centralizing access into the helpers makes a big difference. wtf it works
This commit is contained in:
parent
6004149417
commit
a147cfa325
40
go.mod
40
go.mod
|
@ -1,10 +1,8 @@
|
||||||
module git.handmade.network/hmn/hmn
|
module git.handmade.network/hmn/hmn
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
|
||||||
github.com/Masterminds/semver v1.5.0 // indirect
|
|
||||||
github.com/Masterminds/sprig v2.22.0+incompatible
|
github.com/Masterminds/sprig v2.22.0+incompatible
|
||||||
github.com/alecthomas/chroma v0.9.2
|
github.com/alecthomas/chroma v0.9.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.8.1
|
github.com/aws/aws-sdk-go-v2 v1.8.1
|
||||||
|
@ -16,13 +14,10 @@ require (
|
||||||
github.com/go-stack/stack v1.8.0
|
github.com/go-stack/stack v1.8.0
|
||||||
github.com/google/uuid v1.2.0
|
github.com/google/uuid v1.2.0
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/huandu/xstrings v1.3.2 // indirect
|
|
||||||
github.com/imdario/mergo v0.3.12 // indirect
|
|
||||||
github.com/jackc/pgconn v1.8.0
|
github.com/jackc/pgconn v1.8.0
|
||||||
github.com/jackc/pgtype v1.6.2
|
github.com/jackc/pgtype v1.6.2
|
||||||
github.com/jackc/pgx/v4 v4.10.1
|
github.com/jackc/pgx/v4 v4.10.1
|
||||||
github.com/jpillora/backoff v1.0.0
|
github.com/jpillora/backoff v1.0.0
|
||||||
github.com/mitchellh/copystructure v1.1.1 // indirect
|
|
||||||
github.com/rs/zerolog v1.21.0
|
github.com/rs/zerolog v1.21.0
|
||||||
github.com/spf13/cobra v1.1.3
|
github.com/spf13/cobra v1.1.3
|
||||||
github.com/stretchr/testify v1.7.0
|
github.com/stretchr/testify v1.7.0
|
||||||
|
@ -35,6 +30,39 @@ require (
|
||||||
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
|
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||||
|
github.com/Masterminds/semver v1.5.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect
|
||||||
|
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.4.0 // indirect
|
||||||
|
github.com/huandu/xstrings v1.3.2 // indirect
|
||||||
|
github.com/imdario/mergo v0.3.12 // indirect
|
||||||
|
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
|
||||||
|
github.com/jackc/pgio v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.6 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||||
|
github.com/jackc/puddle v1.1.3 // indirect
|
||||||
|
github.com/mitchellh/copystructure v1.1.1 // indirect
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
|
||||||
|
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect
|
||||||
|
golang.org/x/text v0.3.6 // indirect
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||||
|
)
|
||||||
|
|
||||||
replace (
|
replace (
|
||||||
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
|
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
|
||||||
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e
|
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
|
||||||
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
||||||
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
||||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
|
||||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||||
|
@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47
|
||||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
|
||||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||||
|
|
|
@ -55,7 +55,7 @@ func addCreateProjectCommand(projectCommand *cobra.Command) {
|
||||||
}
|
}
|
||||||
hmn := p.Project
|
hmn := p.Project
|
||||||
|
|
||||||
newProjectID, err := db.QueryInt(ctx, tx,
|
newProjectID, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_project (
|
INSERT INTO handmade_project (
|
||||||
slug,
|
slug,
|
||||||
|
|
|
@ -210,7 +210,7 @@ func init() {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -218,7 +218,7 @@ func init() {
|
||||||
var parentId *int
|
var parentId *int
|
||||||
if parentSlug == "" {
|
if parentSlug == "" {
|
||||||
// Select the root subforum
|
// Select the root subforum
|
||||||
id, err := db.QueryInt(ctx, tx,
|
id, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`,
|
`SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`,
|
||||||
projectId,
|
projectId,
|
||||||
)
|
)
|
||||||
|
@ -228,7 +228,7 @@ func init() {
|
||||||
parentId = &id
|
parentId = &id
|
||||||
} else {
|
} else {
|
||||||
// Select the parent
|
// Select the parent
|
||||||
id, err := db.QueryInt(ctx, tx,
|
id, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
||||||
parentSlug, projectId,
|
parentSlug, projectId,
|
||||||
)
|
)
|
||||||
|
@ -238,7 +238,7 @@ func init() {
|
||||||
parentId = &id
|
parentId = &id
|
||||||
}
|
}
|
||||||
|
|
||||||
newId, err := db.QueryInt(ctx, tx,
|
newId, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id)
|
INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -289,12 +289,12 @@ func init() {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subforumId, err := db.QueryInt(ctx, tx,
|
subforumId, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
||||||
subforumSlug, projectId,
|
subforumSlug, projectId,
|
||||||
)
|
)
|
||||||
|
|
|
@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch and return the new record
|
// Fetch and return the new record
|
||||||
iasset, err := db.QueryOne(ctx, dbConn, models.Asset{},
|
asset, err := db.QueryOne[models.Asset](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_asset
|
FROM handmade_asset
|
||||||
|
@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
|
||||||
return nil, oops.New(err, "failed to fetch newly-created asset")
|
return nil, oops.New(err, "failed to fetch newly-created asset")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iasset.(*models.Asset), nil
|
return asset, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ func makeCSRFToken() string {
|
||||||
var ErrNoSession = errors.New("no session found")
|
var ErrNoSession = errors.New("no session found")
|
||||||
|
|
||||||
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
|
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
|
||||||
row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id)
|
sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.NotFound) {
|
if errors.Is(err, db.NotFound) {
|
||||||
return nil, ErrNoSession
|
return nil, ErrNoSession
|
||||||
|
@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses
|
||||||
return nil, oops.New(err, "failed to get session")
|
return nil, oops.New(err, "failed to get session")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sess := row.(*models.Session)
|
|
||||||
|
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
607
src/db/db.go
607
src/db/db.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/src/config"
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
|
@ -20,46 +21,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Values of these kinds are ok to query even if they are not directly understood by pgtype.
|
A general error to be used when no results are found. This is the error returned
|
||||||
This is common for custom types like:
|
by QueryOne, and can generally be used by other database helpers that fetch a single
|
||||||
|
result but find nothing.
|
||||||
type ThreadType int
|
|
||||||
*/
|
*/
|
||||||
var queryableKinds = []reflect.Kind{
|
var NotFound = errors.New("not found")
|
||||||
reflect.Int,
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Checks if we are able to handle a particular type in a database query. This applies only to
|
|
||||||
primitive types and not structs, since the database only returns individual primitive types
|
|
||||||
and it is our job to stitch them back together into structs later.
|
|
||||||
*/
|
|
||||||
func typeIsQueryable(t reflect.Type) bool {
|
|
||||||
_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
|
|
||||||
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
|
|
||||||
|
|
||||||
if isRecognizedByPgtype {
|
|
||||||
return true
|
|
||||||
} else if t == reflect.TypeOf(uuid.UUID{}) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// pgtype doesn't recognize it, but maybe it's a primitive type we can deal with
|
|
||||||
k := t.Kind()
|
|
||||||
for _, qk := range queryableKinds {
|
|
||||||
if k == qk {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// This interface should match both a direct pgx connection or a pgx transaction.
|
// This interface should match both a direct pgx connection or a pgx transaction.
|
||||||
type ConnOrTx interface {
|
type ConnOrTx interface {
|
||||||
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||||
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||||
Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error)
|
Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
|
||||||
|
|
||||||
// Both raw database connections and transactions in pgx can begin/commit
|
// Both raw database connections and transactions in pgx can begin/commit
|
||||||
// transactions. For database connections it does the obvious thing; for
|
// transactions. For database connections it does the obvious thing; for
|
||||||
|
@ -70,6 +42,8 @@ type ConnOrTx interface {
|
||||||
|
|
||||||
var connInfo = pgtype.NewConnInfo()
|
var connInfo = pgtype.NewConnInfo()
|
||||||
|
|
||||||
|
// Creates a new connection to the HMN database.
|
||||||
|
// This connection is not safe for concurrent use.
|
||||||
func NewConn() *pgx.Conn {
|
func NewConn() *pgx.Conn {
|
||||||
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,6 +53,8 @@ func NewConn() *pgx.Conn {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a connection pool for the HMN database.
|
||||||
|
// The resulting pool is safe for concurrent use.
|
||||||
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
||||||
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
||||||
|
|
||||||
|
@ -95,144 +71,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type StructQueryIterator struct {
|
/*
|
||||||
fieldPaths [][]int
|
Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred.
|
||||||
rows pgx.Rows
|
|
||||||
destType reflect.Type
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) Next() (interface{}, bool) {
|
Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection.
|
||||||
hasNext := it.rows.Next()
|
|
||||||
if !hasNext {
|
|
||||||
it.Close()
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
result := reflect.New(it.destType)
|
This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar.
|
||||||
|
*/
|
||||||
vals, err := it.rows.Values()
|
func Query[T any](
|
||||||
if err != nil {
|
ctx context.Context,
|
||||||
panic(err)
|
conn ConnOrTx,
|
||||||
}
|
query string,
|
||||||
|
args ...any,
|
||||||
// Better logging of panics in this confusing reflection process
|
) ([]*T, error) {
|
||||||
var currentField reflect.StructField
|
it, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
var currentValue reflect.Value
|
|
||||||
var currentIdx int
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
if currentValue.IsValid() {
|
|
||||||
logging.Error().
|
|
||||||
Int("index", currentIdx).
|
|
||||||
Str("field name", currentField.Name).
|
|
||||||
Stringer("field type", currentField.Type).
|
|
||||||
Interface("value", currentValue.Interface()).
|
|
||||||
Stringer("value type", currentValue.Type()).
|
|
||||||
Msg("panic in iterator")
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentField.Name != "" {
|
|
||||||
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
|
|
||||||
} else {
|
|
||||||
panic(r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i, val := range vals {
|
|
||||||
currentIdx = i
|
|
||||||
if val == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var field reflect.Value
|
|
||||||
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
|
|
||||||
if field.Kind() == reflect.Ptr {
|
|
||||||
field.Set(reflect.New(field.Type().Elem()))
|
|
||||||
field = field.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
|
|
||||||
// Regardless, we know it's not nil, so we can get at the contents.
|
|
||||||
valReflected := reflect.ValueOf(val)
|
|
||||||
if valReflected.Kind() == reflect.Ptr {
|
|
||||||
valReflected = valReflected.Elem()
|
|
||||||
}
|
|
||||||
currentValue = valReflected
|
|
||||||
|
|
||||||
switch field.Kind() {
|
|
||||||
case reflect.Int:
|
|
||||||
field.SetInt(valReflected.Int())
|
|
||||||
default:
|
|
||||||
field.Set(valReflected)
|
|
||||||
}
|
|
||||||
|
|
||||||
currentField = reflect.StructField{}
|
|
||||||
currentValue = reflect.Value{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.Interface(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) Close() {
|
|
||||||
it.rows.Close()
|
|
||||||
select {
|
|
||||||
case it.closed <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) ToSlice() []interface{} {
|
|
||||||
defer it.Close()
|
|
||||||
var result []interface{}
|
|
||||||
for {
|
|
||||||
row, ok := it.Next()
|
|
||||||
if !ok {
|
|
||||||
err := it.rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
panic(oops.New(err, "error while iterating through db results"))
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
result = append(result, row)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
|
|
||||||
if len(path) < 1 {
|
|
||||||
panic(oops.New(nil, "can't follow an empty path"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
|
|
||||||
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// more informative panic recovery
|
|
||||||
var field reflect.StructField
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
val := structPtrVal
|
|
||||||
for _, i := range path {
|
|
||||||
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
|
|
||||||
if val.IsNil() {
|
|
||||||
val.Set(reflect.New(val.Type().Elem()))
|
|
||||||
}
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
field = val.Type().Field(i)
|
|
||||||
val = val.Field(i)
|
|
||||||
}
|
|
||||||
return val, field
|
|
||||||
}
|
|
||||||
|
|
||||||
func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
|
||||||
it, err := QueryIterator(ctx, conn, destExample, query, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -240,27 +92,99 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) {
|
/*
|
||||||
destType := reflect.TypeOf(destExample)
|
Identical to Query, but returns only the first result row. If there are no
|
||||||
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
|
rows in the result set, returns NotFound.
|
||||||
|
*/
|
||||||
|
func QueryOne[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (*T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to generate column names")
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
return nil, NotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
columns := make([]string, 0, len(columnNames))
|
return result, nil
|
||||||
for _, strSlice := range columnNames {
|
}
|
||||||
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
|
||||||
fullName := strSlice[len(strSlice)-1]
|
/*
|
||||||
if tableName != "" {
|
Identical to Query, but returns concrete values instead of pointers. More convenient
|
||||||
fullName = tableName + "." + fullName
|
for primitive types.
|
||||||
|
*/
|
||||||
|
func QueryScalar[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) ([]T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
columns = append(columns, fullName)
|
defer rows.Close()
|
||||||
|
|
||||||
|
var result []T
|
||||||
|
for {
|
||||||
|
val, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, *val)
|
||||||
}
|
}
|
||||||
|
|
||||||
columnNamesString := strings.Join(columns, ", ")
|
return result, nil
|
||||||
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
}
|
||||||
|
|
||||||
rows, err := conn.Query(ctx, query, args...)
|
/*
|
||||||
|
Identical to QueryScalar, but returns only the first result value. If there are
|
||||||
|
no rows in the result set, returns NotFound.
|
||||||
|
*/
|
||||||
|
func QueryOneScalar[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
var zero T
|
||||||
|
return zero, NotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return *result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use.
|
||||||
|
*/
|
||||||
|
func QueryIterator[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (*Iterator[T], error) {
|
||||||
|
var destExample T
|
||||||
|
destType := reflect.TypeOf(destExample)
|
||||||
|
|
||||||
|
compiled := compileQuery(query, destType)
|
||||||
|
|
||||||
|
rows, err := conn.Query(ctx, compiled.query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
panic("query exceeded its deadline")
|
panic("query exceeded its deadline")
|
||||||
|
@ -268,10 +192,11 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
it := &StructQueryIterator{
|
it := &Iterator[T]{
|
||||||
fieldPaths: fieldPaths,
|
fieldPaths: compiled.fieldPaths,
|
||||||
rows: rows,
|
rows: rows,
|
||||||
destType: destType,
|
destType: compiled.destType,
|
||||||
|
destTypeIsScalar: typeIsQueryable(compiled.destType),
|
||||||
closed: make(chan struct{}, 1),
|
closed: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,16 +217,72 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
|
||||||
return it, nil
|
return it, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
|
// TODO: QueryFunc?
|
||||||
var columnNames [][]string
|
|
||||||
var fieldPaths [][]int
|
type compiledQuery struct {
|
||||||
|
query string
|
||||||
|
destType reflect.Type
|
||||||
|
fieldPaths []fieldPath
|
||||||
|
}
|
||||||
|
|
||||||
|
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
|
||||||
|
|
||||||
|
func compileQuery(query string, destType reflect.Type) compiledQuery {
|
||||||
|
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
|
||||||
|
hasColumnsPlaceholder := columnsMatch != nil
|
||||||
|
|
||||||
|
if hasColumnsPlaceholder {
|
||||||
|
// The presence of the $columns placeholder means that the destination type
|
||||||
|
// must be a struct, and we will plonk that struct's fields into the query.
|
||||||
|
|
||||||
|
if destType.Kind() != reflect.Struct {
|
||||||
|
panic("$columns can only be used when querying into a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefix []string
|
||||||
|
prefixText := columnsMatch[2]
|
||||||
|
if prefixText != "" {
|
||||||
|
prefix = []string{prefixText}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(columnNames))
|
||||||
|
for _, strSlice := range columnNames {
|
||||||
|
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
||||||
|
fullName := strSlice[len(strSlice)-1]
|
||||||
|
if tableName != "" {
|
||||||
|
fullName = tableName + "." + fullName
|
||||||
|
}
|
||||||
|
columns = append(columns, fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNamesString := strings.Join(columns, ", ")
|
||||||
|
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
|
||||||
|
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
fieldPaths: fieldPaths,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
|
||||||
|
var columnNames []columnName
|
||||||
|
var fieldPaths []fieldPath
|
||||||
|
|
||||||
if destType.Kind() == reflect.Ptr {
|
if destType.Kind() == reflect.Ptr {
|
||||||
destType = destType.Elem()
|
destType = destType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if destType.Kind() != reflect.Struct {
|
if destType.Kind() != reflect.Struct {
|
||||||
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
|
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnonPrefix struct {
|
type AnonPrefix struct {
|
||||||
|
@ -348,108 +329,214 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
|
||||||
columnNames = append(columnNames, fieldColumnNames)
|
columnNames = append(columnNames, fieldColumnNames)
|
||||||
fieldPaths = append(fieldPaths, path)
|
fieldPaths = append(fieldPaths, path)
|
||||||
} else if fieldType.Kind() == reflect.Struct {
|
} else if fieldType.Kind() == reflect.Struct {
|
||||||
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
columnNames = append(columnNames, subCols...)
|
columnNames = append(columnNames, subCols...)
|
||||||
fieldPaths = append(fieldPaths, subPaths...)
|
fieldPaths = append(fieldPaths, subPaths...)
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
|
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columnNames, fieldPaths, nil
|
return columnNames, fieldPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
A general error to be used when no results are found. This is the error returned
|
Values of these kinds are ok to query even if they are not directly understood by pgtype.
|
||||||
by QueryOne, and can generally be used by other database helpers that fetch a single
|
This is common for custom types like:
|
||||||
result but find nothing.
|
|
||||||
|
type ThreadType int
|
||||||
*/
|
*/
|
||||||
var NotFound = errors.New("not found")
|
var queryableKinds = []reflect.Kind{
|
||||||
|
reflect.Int,
|
||||||
func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) {
|
|
||||||
rows, err := QueryIterator(ctx, conn, destExample, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
result, hasRow := rows.Next()
|
|
||||||
if !hasRow {
|
|
||||||
return nil, NotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) {
|
/*
|
||||||
rows, err := conn.Query(ctx, query, args...)
|
Checks if we are able to handle a particular type in a database query. This applies only to
|
||||||
if err != nil {
|
primitive types and not structs, since the database only returns individual primitive types
|
||||||
return nil, err
|
and it is our job to stitch them back together into structs later.
|
||||||
}
|
*/
|
||||||
defer rows.Close()
|
func typeIsQueryable(t reflect.Type) bool {
|
||||||
|
_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
|
||||||
|
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
|
||||||
|
|
||||||
if rows.Next() {
|
if isRecognizedByPgtype {
|
||||||
vals, err := rows.Values()
|
return true
|
||||||
|
} else if t == reflect.TypeOf(uuid.UUID{}) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// pgtype doesn't recognize it, but maybe it's a primitive type we can deal with
|
||||||
|
k := t.Kind()
|
||||||
|
for _, qk := range queryableKinds {
|
||||||
|
if k == qk {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type columnName []string
|
||||||
|
|
||||||
|
// A path to a particular field in query's destination type. Each index in the slice
|
||||||
|
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
|
||||||
|
type fieldPath []int
|
||||||
|
|
||||||
|
type Iterator[T any] struct {
|
||||||
|
fieldPaths []fieldPath
|
||||||
|
rows pgx.Rows
|
||||||
|
destType reflect.Type
|
||||||
|
destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function.
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *Iterator[T]) Next() (*T, bool) {
|
||||||
|
// TODO(ben): What happens if this panics? Does it leak resources? Do we need
|
||||||
|
// to put a recover() here and close the rows?
|
||||||
|
|
||||||
|
hasNext := it.rows.Next()
|
||||||
|
if !hasNext {
|
||||||
|
it.Close()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
result := reflect.New(it.destType)
|
||||||
|
|
||||||
|
vals, err := it.rows.Values()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if it.destTypeIsScalar {
|
||||||
|
// This type can be directly queried, meaning pgx recognizes it, it's
|
||||||
|
// a simple scalar thing, and we can just take the easy way out.
|
||||||
if len(vals) != 1 {
|
if len(vals) != 1 {
|
||||||
return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals))
|
panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals)))
|
||||||
|
}
|
||||||
|
setValueFromDB(result.Elem(), reflect.ValueOf(vals[0]))
|
||||||
|
return result.Interface().(*T), true
|
||||||
|
} else {
|
||||||
|
var currentField reflect.StructField
|
||||||
|
var currentValue reflect.Value
|
||||||
|
var currentIdx int
|
||||||
|
|
||||||
|
// Better logging of panics in this confusing reflection process
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if currentValue.IsValid() {
|
||||||
|
logging.Error().
|
||||||
|
Int("index", currentIdx).
|
||||||
|
Str("field name", currentField.Name).
|
||||||
|
Stringer("field type", currentField.Type).
|
||||||
|
Interface("value", currentValue.Interface()).
|
||||||
|
Stringer("value type", currentValue.Type()).
|
||||||
|
Msg("panic in iterator")
|
||||||
}
|
}
|
||||||
|
|
||||||
return vals[0], nil
|
if currentField.Name != "" {
|
||||||
|
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
|
||||||
|
} else {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i, val := range vals {
|
||||||
|
currentIdx = i
|
||||||
|
if val == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, NotFound
|
var field reflect.Value
|
||||||
|
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
|
||||||
|
if field.Kind() == reflect.Ptr {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field = field.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
|
||||||
|
// Regardless, we know it's not nil, so we can get at the contents.
|
||||||
|
valReflected := reflect.ValueOf(val)
|
||||||
|
if valReflected.Kind() == reflect.Ptr {
|
||||||
|
valReflected = valReflected.Elem()
|
||||||
|
}
|
||||||
|
currentValue = valReflected
|
||||||
|
|
||||||
|
setValueFromDB(field, valReflected)
|
||||||
|
|
||||||
|
currentField = reflect.StructField{}
|
||||||
|
currentValue = reflect.Value{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Interface().(*T), true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) {
|
func setValueFromDB(dest reflect.Value, value reflect.Value) {
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
switch dest.Kind() {
|
||||||
if err != nil {
|
case reflect.Int:
|
||||||
return "", err
|
dest.SetInt(value.Int())
|
||||||
}
|
|
||||||
|
|
||||||
switch r := result.(type) {
|
|
||||||
case string:
|
|
||||||
return r, nil
|
|
||||||
default:
|
default:
|
||||||
return "", oops.New(nil, "QueryString got a non-string result: %v", result)
|
dest.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) {
|
func (it *Iterator[any]) Close() {
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
it.rows.Close()
|
||||||
if err != nil {
|
select {
|
||||||
return 0, err
|
case it.closed <- struct{}{}:
|
||||||
}
|
|
||||||
|
|
||||||
switch r := result.(type) {
|
|
||||||
case int:
|
|
||||||
return r, nil
|
|
||||||
case int32:
|
|
||||||
return int(r), nil
|
|
||||||
case int64:
|
|
||||||
return int(r), nil
|
|
||||||
default:
|
default:
|
||||||
return 0, oops.New(nil, "QueryInt got a non-int result: %v", result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) {
|
/*
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
Pulls all the remaining values into a slice, and closes the iterator.
|
||||||
|
*/
|
||||||
|
func (it *Iterator[T]) ToSlice() []*T {
|
||||||
|
defer it.Close()
|
||||||
|
var result []*T
|
||||||
|
for {
|
||||||
|
row, ok := it.Next()
|
||||||
|
if !ok {
|
||||||
|
err := it.rows.Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
panic(oops.New(err, "error while iterating through db results"))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, row)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
|
||||||
|
if len(path) < 1 {
|
||||||
|
panic(oops.New(nil, "can't follow an empty path"))
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r := result.(type) {
|
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
|
||||||
case bool:
|
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
|
||||||
return r, nil
|
|
||||||
default:
|
|
||||||
return false, oops.New(nil, "QueryBool got a non-bool result: %v", result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// more informative panic recovery
|
||||||
|
var field reflect.StructField
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
val := structPtrVal
|
||||||
|
for _, i := range path {
|
||||||
|
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
|
||||||
|
if val.IsNil() {
|
||||||
|
val.Set(reflect.New(val.Type().Elem()))
|
||||||
|
}
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
field = val.Type().Field(i)
|
||||||
|
val = val.Field(i)
|
||||||
|
}
|
||||||
|
return val, field
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,13 +10,90 @@ import (
|
||||||
|
|
||||||
func TestPaths(t *testing.T) {
|
func TestPaths(t *testing.T) {
|
||||||
type CustomInt int
|
type CustomInt int
|
||||||
|
type S2 struct {
|
||||||
|
B bool `db:"B"` // field 0
|
||||||
|
PB *bool `db:"PB"` // field 1
|
||||||
|
|
||||||
|
NoTag string // field 2
|
||||||
|
}
|
||||||
|
type S struct {
|
||||||
|
I int `db:"I"` // field 0
|
||||||
|
PI *int `db:"PI"` // field 1
|
||||||
|
CI CustomInt `db:"CI"` // field 2
|
||||||
|
PCI *CustomInt `db:"PCI"` // field 3
|
||||||
|
S2 `db:"S2"` // field 4 (embedded!)
|
||||||
|
PS2 *S2 `db:"PS2"` // field 5
|
||||||
|
|
||||||
|
NoTag int // field 6
|
||||||
|
}
|
||||||
|
type Nested struct {
|
||||||
|
S S `db:"S"` // field 0
|
||||||
|
PS *S `db:"PS"` // field 1
|
||||||
|
|
||||||
|
NoTag S // field 2
|
||||||
|
}
|
||||||
|
type Embedded struct {
|
||||||
|
NoTag S // field 0
|
||||||
|
Nested // field 1
|
||||||
|
}
|
||||||
|
|
||||||
|
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
|
||||||
|
assert.Equal(t, []columnName{
|
||||||
|
{"S", "I"}, {"S", "PI"},
|
||||||
|
{"S", "CI"}, {"S", "PCI"},
|
||||||
|
{"S", "S2", "B"}, {"S", "S2", "PB"},
|
||||||
|
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
|
||||||
|
{"PS", "I"}, {"PS", "PI"},
|
||||||
|
{"PS", "CI"}, {"PS", "PCI"},
|
||||||
|
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
|
||||||
|
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
|
||||||
|
}, names)
|
||||||
|
assert.Equal(t, []fieldPath{
|
||||||
|
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
|
||||||
|
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
|
||||||
|
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
|
||||||
|
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
|
||||||
|
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
|
||||||
|
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
|
||||||
|
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
|
||||||
|
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
|
||||||
|
}, paths)
|
||||||
|
assert.True(t, len(names) == len(paths))
|
||||||
|
|
||||||
|
testStruct := Embedded{}
|
||||||
|
for i, path := range paths {
|
||||||
|
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
||||||
|
assert.True(t, val.IsValid())
|
||||||
|
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileQuery(t *testing.T) {
|
||||||
|
t.Run("simple struct", func(t *testing.T) {
|
||||||
|
type Dest struct {
|
||||||
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
|
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
|
||||||
|
})
|
||||||
|
t.Run("complex structs", func(t *testing.T) {
|
||||||
|
type CustomInt int
|
||||||
|
type S2 struct {
|
||||||
|
B bool `db:"B"`
|
||||||
|
PB *bool `db:"PB"`
|
||||||
|
|
||||||
|
NoTag string
|
||||||
|
}
|
||||||
type S struct {
|
type S struct {
|
||||||
I int `db:"I"`
|
I int `db:"I"`
|
||||||
PI *int `db:"PI"`
|
PI *int `db:"PI"`
|
||||||
CI CustomInt `db:"CI"`
|
CI CustomInt `db:"CI"`
|
||||||
PCI *CustomInt `db:"PCI"`
|
PCI *CustomInt `db:"PCI"`
|
||||||
B bool `db:"B"`
|
S2 `db:"S2"` // embedded!
|
||||||
PB *bool `db:"PB"`
|
PS2 *S2 `db:"PS2"`
|
||||||
|
|
||||||
NoTag int
|
NoTag int
|
||||||
}
|
}
|
||||||
|
@ -26,34 +103,48 @@ func TestPaths(t *testing.T) {
|
||||||
|
|
||||||
NoTag S
|
NoTag S
|
||||||
}
|
}
|
||||||
type Embedded struct {
|
type Dest struct {
|
||||||
NoTag S
|
NoTag S
|
||||||
Nested
|
Nested
|
||||||
}
|
}
|
||||||
|
|
||||||
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
if assert.Nil(t, err) {
|
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
|
||||||
assert.Equal(t, []string{
|
})
|
||||||
"S.I", "S.PI",
|
t.Run("int", func(t *testing.T) {
|
||||||
"S.CI", "S.PCI",
|
type Dest int
|
||||||
"S.B", "S.PB",
|
|
||||||
"PS.I", "PS.PI",
|
// There should be no error here because we do not need to extract columns from
|
||||||
"PS.CI", "PS.PCI",
|
// the destination type. There may be errors down the line in value iteration, but
|
||||||
"PS.B", "PS.PB",
|
// that is always the case if the Go types don't match the query.
|
||||||
}, names)
|
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
assert.Equal(t, [][]int{
|
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
|
||||||
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
|
})
|
||||||
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
|
t.Run("just one table", func(t *testing.T) {
|
||||||
}, paths)
|
type Dest struct {
|
||||||
assert.True(t, len(names) == len(paths))
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
}
|
}
|
||||||
|
|
||||||
testStruct := Embedded{}
|
// The prefix is necessary because otherwise we would have to provide a struct with
|
||||||
for i, path := range paths {
|
// a db tag in order to provide the query with the `greeblies.` prefix in the
|
||||||
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
// final query. This comes up a lot when we do a JOIN to help with a condition, but
|
||||||
assert.True(t, val.IsValid())
|
// don't actually care about any of the data we joined to.
|
||||||
assert.True(t, strings.Contains(names[i], field.Name))
|
compiled := compileQuery(
|
||||||
}
|
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
|
||||||
|
reflect.TypeOf(Dest{}),
|
||||||
|
)
|
||||||
|
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
|
||||||
|
type Dest int
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryBuilder(t *testing.T) {
|
func TestQueryBuilder(t *testing.T) {
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*
|
||||||
|
This package contains lowish-level APIs for making database queries to our Postgres database. It streamlines the process of mapping query results to Go types, while allowing you to write arbitrary SQL queries.
|
||||||
|
|
||||||
|
The primary functions are Query and QueryIterator. See the package and function examples for detailed usage.
|
||||||
|
|
||||||
|
Query syntax
|
||||||
|
|
||||||
|
This package allows a few small extensions to SQL syntax to streamline the interaction between Go and Postgres.
|
||||||
|
|
||||||
|
Arguments can be provided using placeholders like $1, $2, etc. All arguments will be safely escaped and mapped from their Go type to the correct Postgres type. (This is a direct proxy to pgx.)
|
||||||
|
|
||||||
|
projectIDs, err := db.Query[int](ctx, conn,
|
||||||
|
`
|
||||||
|
SELECT id
|
||||||
|
FROM handmade_project
|
||||||
|
WHERE
|
||||||
|
slug = ANY($1)
|
||||||
|
AND hidden = $2
|
||||||
|
`,
|
||||||
|
[]string{"4coder", "metadesk"},
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
(This also demonstrates a useful tip: if you want to use a slice in your query, use Postgres arrays instead of IN.)
|
||||||
|
|
||||||
|
When querying individual fields, you can simply select the field like so:
|
||||||
|
|
||||||
|
ids, err := db.Query[int](ctx, conn, `SELECT id FROM handmade_project`)
|
||||||
|
|
||||||
|
To query multiple columns at once, you may use a struct type with `db:"column_name"` tags, and the special $columns placeholder:
|
||||||
|
|
||||||
|
type Project struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Slug string `db:"slug"`
|
||||||
|
DateCreated time.Time `db:"date_created"`
|
||||||
|
}
|
||||||
|
projects, err := db.Query[Project](ctx, conn, `SELECT $columns FROM ...`)
|
||||||
|
// Resulting query:
|
||||||
|
// SELECT id, slug, date_created FROM ...
|
||||||
|
|
||||||
|
Sometimes a table name prefix is required on each column to disambiguate between column names, especially when performing a JOIN. In those situations, you can include the prefix in the $columns placeholder like $columns{prefix}:
|
||||||
|
|
||||||
|
type Project struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Slug string `db:"slug"`
|
||||||
|
DateCreated time.Time `db:"date_created"`
|
||||||
|
}
|
||||||
|
orphanedProjects, err := db.Query[Project](ctx, conn, `
|
||||||
|
SELECT $columns{projects}
|
||||||
|
FROM
|
||||||
|
handmade_project AS projects
|
||||||
|
LEFT JOIN handmade_user_projects AS uproj
|
||||||
|
WHERE
|
||||||
|
uproj.user_id IS NULL
|
||||||
|
`)
|
||||||
|
// Resulting query:
|
||||||
|
// SELECT projects.id, projects.slug, projects.date_created FROM ...
|
||||||
|
*/
|
||||||
|
package db
|
|
@ -7,7 +7,7 @@ import (
|
||||||
|
|
||||||
type QueryBuilder struct {
|
type QueryBuilder struct {
|
||||||
sql strings.Builder
|
sql strings.Builder
|
||||||
args []interface{}
|
args []any
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number.
|
||||||
foo ARG1 bar ARG2 baz $?
|
foo ARG1 bar ARG2 baz $?
|
||||||
foo ARG1 bar ARG2 baz ARG3
|
foo ARG1 bar ARG2 baz ARG3
|
||||||
*/
|
*/
|
||||||
func (qb *QueryBuilder) Add(sql string, args ...interface{}) {
|
func (qb *QueryBuilder) Add(sql string, args ...any) {
|
||||||
numPlaceholders := strings.Count(sql, "$?")
|
numPlaceholders := strings.Count(sql, "$?")
|
||||||
if numPlaceholders != len(args) {
|
if numPlaceholders != len(args) {
|
||||||
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
|
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
|
||||||
|
@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string {
|
||||||
return qb.sql.String()
|
return qb.sql.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qb *QueryBuilder) Args() []interface{} {
|
func (qb *QueryBuilder) Args() []any {
|
||||||
return qb.args
|
return qb.args
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,12 +90,9 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type profileResult struct {
|
hmnUser, err := db.QueryOne[models.User](ctx, bot.dbConn,
|
||||||
HMNUser models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
handmade_discorduser AS duser
|
handmade_discorduser AS duser
|
||||||
JOIN auth_user ON duser.hmn_user_id = auth_user.id
|
JOIN auth_user ON duser.hmn_user_id = auth_user.id
|
||||||
|
@ -122,16 +119,15 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := ires.(*profileResult)
|
|
||||||
|
|
||||||
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
|
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
|
||||||
OwnerIDs: []int{res.HMNUser.ID},
|
OwnerIDs: []int{hmnUser.ID},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects")
|
logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
url := hmnurl.BuildUserProfile(res.HMNUser.Username)
|
url := hmnurl.BuildUserProfile(hmnUser.Username)
|
||||||
msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url)
|
msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url)
|
||||||
if len(projectsAndStuff) > 0 {
|
if len(projectsAndStuff) > 0 {
|
||||||
projectNoun := "projects"
|
projectNoun := "projects"
|
||||||
|
|
|
@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error {
|
||||||
// an old one or starting a new one.
|
// an old one or starting a new one.
|
||||||
|
|
||||||
shouldResume := true
|
shouldResume := true
|
||||||
isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`)
|
session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.NotFound) {
|
if errors.Is(err, db.NotFound) {
|
||||||
// No session yet! Just identify and get on with it
|
// No session yet! Just identify and get on with it
|
||||||
|
@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error {
|
||||||
|
|
||||||
if shouldResume {
|
if shouldResume {
|
||||||
// Reconnect to the previous session
|
// Reconnect to the previous session
|
||||||
session := isession.(*models.DiscordSession)
|
|
||||||
|
|
||||||
err := bot.sendGatewayMessage(ctx, GatewayMessage{
|
err := bot.sendGatewayMessage(ctx, GatewayMessage{
|
||||||
Opcode: OpcodeResume,
|
Opcode: OpcodeResume,
|
||||||
Data: Resume{
|
Data: Resume{
|
||||||
|
@ -356,7 +354,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
}
|
}
|
||||||
bot.didAckHeartbeat = false
|
bot.didAckHeartbeat = false
|
||||||
|
|
||||||
latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`)
|
latestSequenceNumber, err := db.QueryOneScalar[int](ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to fetch latest sequence number from the db")
|
log.Error().Err(err).Msg("failed to fetch latest sequence number from the db")
|
||||||
return false
|
return false
|
||||||
|
@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, `
|
msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, `
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM discord_outgoingmessages
|
FROM discord_outgoingmessages
|
||||||
ORDER BY id ASC
|
ORDER BY id ASC
|
||||||
|
@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, imsg := range msgs {
|
for _, msg := range msgs {
|
||||||
msg := imsg.(*models.DiscordOutgoingMessage)
|
|
||||||
if time.Now().After(msg.ExpiresAt) {
|
if time.Now().After(msg.ExpiresAt) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,12 +73,9 @@ func RunHistoryWatcher(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{
|
||||||
func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
log := logging.ExtractLogger(ctx)
|
log := logging.ExtractLogger(ctx)
|
||||||
|
|
||||||
type query struct {
|
messagesWithoutContent, err := db.Query[models.DiscordMessage](ctx, dbConn,
|
||||||
Message models.DiscordMessage `db:"msg"`
|
|
||||||
}
|
|
||||||
imessagesWithoutContent, err := db.Query(ctx, dbConn, query{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{msg}
|
||||||
FROM
|
FROM
|
||||||
handmade_discordmessage AS msg
|
handmade_discordmessage AS msg
|
||||||
JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users
|
JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users
|
||||||
|
@ -95,10 +92,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(imessagesWithoutContent) > 0 {
|
if len(messagesWithoutContent) > 0 {
|
||||||
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent))
|
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent))
|
||||||
msgloop:
|
msgloop:
|
||||||
for _, imsg := range imessagesWithoutContent {
|
for _, msg := range messagesWithoutContent {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info().Msg("Scrape was canceled")
|
log.Info().Msg("Scrape was canceled")
|
||||||
|
@ -106,8 +103,6 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := imsg.(*query).Message
|
|
||||||
|
|
||||||
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
|
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
|
||||||
if errors.Is(err, NotFound) {
|
if errors.Is(err, NotFound) {
|
||||||
// This message has apparently been deleted; delete it from our database
|
// This message has apparently been deleted; delete it from our database
|
||||||
|
|
|
@ -165,7 +165,7 @@ func InternMessage(
|
||||||
dbConn db.ConnOrTx,
|
dbConn db.ConnOrTx,
|
||||||
msg *Message,
|
msg *Message,
|
||||||
) error {
|
) error {
|
||||||
_, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{},
|
_, err := db.QueryOne[models.DiscordMessage](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessage
|
FROM handmade_discordmessage
|
||||||
|
@ -219,7 +219,7 @@ type InternedMessage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
|
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
|
||||||
result, err := db.QueryOne(ctx, dbConn, InternedMessage{},
|
interned, err := db.QueryOne[InternedMessage](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -235,8 +235,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
interned := result.(*InternedMessage)
|
|
||||||
return interned, nil
|
return interned, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,7 +281,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
|
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
|
||||||
isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{},
|
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_snippet
|
FROM handmade_snippet
|
||||||
|
@ -294,10 +292,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In
|
||||||
if err != nil && !errors.Is(err, db.NotFound) {
|
if err != nil && !errors.Is(err, db.NotFound) {
|
||||||
return oops.New(err, "failed to fetch snippet for discord message")
|
return oops.New(err, "failed to fetch snippet for discord message")
|
||||||
}
|
}
|
||||||
var snippet *models.Snippet
|
|
||||||
if !errors.Is(err, db.NotFound) {
|
|
||||||
snippet = isnippet.(*models.Snippet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE(asaf): Also deletes the following through a db cascade:
|
// NOTE(asaf): Also deletes the following through a db cascade:
|
||||||
// * handmade_discordmessageattachment
|
// * handmade_discordmessageattachment
|
||||||
|
@ -367,7 +361,7 @@ func SaveMessageContents(
|
||||||
return oops.New(err, "failed to create or update message contents")
|
return oops.New(err, "failed to create or update message contents")
|
||||||
}
|
}
|
||||||
|
|
||||||
icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{},
|
content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -380,7 +374,7 @@ func SaveMessageContents(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to fetch message contents")
|
return oops.New(err, "failed to fetch message contents")
|
||||||
}
|
}
|
||||||
interned.MessageContent = icontent.(*models.DiscordMessageContent)
|
interned.MessageContent = content
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save attachments
|
// Save attachments
|
||||||
|
@ -395,7 +389,7 @@ func SaveMessageContents(
|
||||||
|
|
||||||
// Save / delete embeds
|
// Save / delete embeds
|
||||||
if msg.OriginalHasFields("embeds") {
|
if msg.OriginalHasFields("embeds") {
|
||||||
numSavedEmbeds, err := db.QueryInt(ctx, dbConn,
|
numSavedEmbeds, err := db.QueryOneScalar[int](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
|
@ -472,7 +466,7 @@ func saveAttachment(
|
||||||
hmnUserID int,
|
hmnUserID int,
|
||||||
discordMessageID string,
|
discordMessageID string,
|
||||||
) (*models.DiscordMessageAttachment, error) {
|
) (*models.DiscordMessageAttachment, error) {
|
||||||
iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
|
existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -481,7 +475,7 @@ func saveAttachment(
|
||||||
attachment.ID,
|
attachment.ID,
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return iexisting.(*models.DiscordMessageAttachment), nil
|
return existing, nil
|
||||||
} else if errors.Is(err, db.NotFound) {
|
} else if errors.Is(err, db.NotFound) {
|
||||||
// this is fine, just create it
|
// this is fine, just create it
|
||||||
} else {
|
} else {
|
||||||
|
@ -534,7 +528,7 @@ func saveAttachment(
|
||||||
return nil, oops.New(err, "failed to save Discord attachment data")
|
return nil, oops.New(err, "failed to save Discord attachment data")
|
||||||
}
|
}
|
||||||
|
|
||||||
iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
|
discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -546,7 +540,7 @@ func saveAttachment(
|
||||||
return nil, oops.New(err, "failed to fetch new Discord attachment data")
|
return nil, oops.New(err, "failed to fetch new Discord attachment data")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iDiscordAttachment.(*models.DiscordMessageAttachment), nil
|
return discordAttachment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
|
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
|
||||||
|
@ -636,7 +630,7 @@ func saveEmbed(
|
||||||
return nil, oops.New(err, "failed to insert new embed")
|
return nil, oops.New(err, "failed to insert new embed")
|
||||||
}
|
}
|
||||||
|
|
||||||
iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{},
|
discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
|
@ -648,11 +642,11 @@ func saveEmbed(
|
||||||
return nil, oops.New(err, "failed to fetch new Discord embed data")
|
return nil, oops.New(err, "failed to fetch new Discord embed data")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iDiscordEmbed.(*models.DiscordMessageEmbed), nil
|
return discordEmbed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
|
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
|
||||||
iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{},
|
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_snippet
|
FROM handmade_snippet
|
||||||
|
@ -669,7 +663,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return iresult.(*models.Snippet), nil
|
return snippet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -805,12 +799,9 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
|
||||||
projectIDs[i] = p.Project.ID
|
projectIDs[i] = p.Project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type tagsRow struct {
|
userTags, err := db.Query[models.Tag](ctx, tx,
|
||||||
Tag models.Tag `db:"tags"`
|
|
||||||
}
|
|
||||||
iUserTags, err := db.Query(ctx, tx, tagsRow{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{tags}
|
||||||
FROM
|
FROM
|
||||||
tags
|
tags
|
||||||
JOIN handmade_project AS project ON project.tag = tags.id
|
JOIN handmade_project AS project ON project.tag = tags.id
|
||||||
|
@ -823,8 +814,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
|
||||||
return oops.New(err, "failed to fetch tags for user projects")
|
return oops.New(err, "failed to fetch tags for user projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, itag := range iUserTags {
|
for _, tag := range userTags {
|
||||||
tag := itag.(*tagsRow).Tag
|
|
||||||
allTags = append(allTags, tag.ID)
|
allTags = append(allTags, tag.ID)
|
||||||
for _, messageTag := range messageTags {
|
for _, messageTag := range messageTags {
|
||||||
if strings.EqualFold(tag.Text, messageTag) {
|
if strings.EqualFold(tag.Text, messageTag) {
|
||||||
|
@ -890,7 +880,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\
|
||||||
|
|
||||||
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
|
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
|
||||||
// Check attachments
|
// Check attachments
|
||||||
attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{},
|
attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -901,13 +891,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, oops.New(err, "failed to fetch message attachments")
|
return nil, nil, oops.New(err, "failed to fetch message attachments")
|
||||||
}
|
}
|
||||||
for _, iattachment := range attachments {
|
for _, attachment := range attachments {
|
||||||
attachment := iattachment.(*models.DiscordMessageAttachment)
|
|
||||||
return &attachment.AssetID, nil, nil
|
return &attachment.AssetID, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check embeds
|
// Check embeds
|
||||||
embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{},
|
embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
|
@ -918,8 +907,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, oops.New(err, "failed to fetch discord embeds")
|
return nil, nil, oops.New(err, "failed to fetch discord embeds")
|
||||||
}
|
}
|
||||||
for _, iembed := range embeds {
|
for _, embed := range embeds {
|
||||||
embed := iembed.(*models.DiscordMessageEmbed)
|
|
||||||
if embed.VideoID != nil {
|
if embed.VideoID != nil {
|
||||||
return embed.VideoID, nil, nil
|
return embed.VideoID, nil, nil
|
||||||
} else if embed.ImageID != nil {
|
} else if embed.ImageID != nil {
|
||||||
|
|
|
@ -140,15 +140,15 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the query
|
// Do the query
|
||||||
iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...)
|
projectRows, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch projects")
|
return nil, oops.New(err, "failed to fetch projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch project owners to do permission checks
|
// Fetch project owners to do permission checks
|
||||||
projectIds := make([]int, len(iprojects))
|
projectIds := make([]int, len(projectRows))
|
||||||
for i, iproject := range iprojects {
|
for i, p := range projectRows {
|
||||||
projectIds[i] = iproject.(*projectRow).Project.ID
|
projectIds[i] = p.Project.ID
|
||||||
}
|
}
|
||||||
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
|
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -156,8 +156,7 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
var res []ProjectAndStuff
|
var res []ProjectAndStuff
|
||||||
for i, iproject := range iprojects {
|
for i, p := range projectRows {
|
||||||
row := iproject.(*projectRow)
|
|
||||||
owners := projectOwners[i].Owners
|
owners := projectOwners[i].Owners
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -191,10 +190,10 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
projectGenerallyVisible := true &&
|
projectGenerallyVisible := true &&
|
||||||
row.Project.Lifecycle.In(models.VisibleProjectLifecycles) &&
|
p.Project.Lifecycle.In(models.VisibleProjectLifecycles) &&
|
||||||
!row.Project.Hidden &&
|
!p.Project.Hidden &&
|
||||||
(!row.Project.Personal || allOwnersApproved || row.Project.IsHMN())
|
(!p.Project.Personal || allOwnersApproved || p.Project.IsHMN())
|
||||||
if row.Project.IsHMN() {
|
if p.Project.IsHMN() {
|
||||||
projectGenerallyVisible = true // hard override
|
projectGenerallyVisible = true // hard override
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,11 +204,11 @@ func FetchProjects(
|
||||||
|
|
||||||
if projectVisible {
|
if projectVisible {
|
||||||
res = append(res, ProjectAndStuff{
|
res = append(res, ProjectAndStuff{
|
||||||
Project: row.Project,
|
Project: p.Project,
|
||||||
LogoLightAsset: row.LogoLightAsset,
|
LogoLightAsset: p.LogoLightAsset,
|
||||||
LogoDarkAsset: row.LogoDarkAsset,
|
LogoDarkAsset: p.LogoDarkAsset,
|
||||||
Owners: owners,
|
Owners: owners,
|
||||||
Tag: row.Tag,
|
Tag: p.Tag,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners(
|
||||||
UserID int `db:"user_id"`
|
UserID int `db:"user_id"`
|
||||||
ProjectID int `db:"project_id"`
|
ProjectID int `db:"project_id"`
|
||||||
}
|
}
|
||||||
iuserprojects, err := db.Query(ctx, tx, userProject{},
|
userProjects, err := db.Query[userProject](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_user_projects
|
FROM handmade_user_projects
|
||||||
|
@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners(
|
||||||
|
|
||||||
// Get the unique user IDs from this set and fetch the users from the db
|
// Get the unique user IDs from this set and fetch the users from the db
|
||||||
var userIds []int
|
var userIds []int
|
||||||
for _, iuserproject := range iuserprojects {
|
for _, userProject := range userProjects {
|
||||||
userProject := iuserproject.(*userProject)
|
|
||||||
|
|
||||||
addUserId := true
|
addUserId := true
|
||||||
for _, uid := range userIds {
|
for _, uid := range userIds {
|
||||||
if uid == userProject.UserID {
|
if uid == userProject.UserID {
|
||||||
|
@ -361,13 +358,11 @@ func FetchMultipleProjectsOwners(
|
||||||
userIds = append(userIds, userProject.UserID)
|
userIds = append(userIds, userProject.UserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
type userQuery struct {
|
users, err := db.Query[models.User](ctx, tx,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
iusers, err := db.Query(ctx, tx, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE
|
WHERE
|
||||||
auth_user.id = ANY($1)
|
auth_user.id = ANY($1)
|
||||||
|
@ -383,9 +378,7 @@ func FetchMultipleProjectsOwners(
|
||||||
for i, pid := range projectIds {
|
for i, pid := range projectIds {
|
||||||
res[i] = ProjectOwners{ProjectID: pid}
|
res[i] = ProjectOwners{ProjectID: pid}
|
||||||
}
|
}
|
||||||
for _, iuserproject := range iuserprojects {
|
for _, userProject := range userProjects {
|
||||||
userProject := iuserproject.(*userProject)
|
|
||||||
|
|
||||||
// Get a pointer to the existing record in the result
|
// Get a pointer to the existing record in the result
|
||||||
var projectOwners *ProjectOwners
|
var projectOwners *ProjectOwners
|
||||||
for i := range res {
|
for i := range res {
|
||||||
|
@ -396,10 +389,9 @@ func FetchMultipleProjectsOwners(
|
||||||
|
|
||||||
// Get the full user record we fetched
|
// Get the full user record we fetched
|
||||||
var user *models.User
|
var user *models.User
|
||||||
for _, iuser := range iusers {
|
for _, u := range users {
|
||||||
u := iuser.(*userQuery).User
|
|
||||||
if u.ID == userProject.UserID {
|
if u.ID == userProject.UserID {
|
||||||
user = &u
|
user = u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
@ -473,7 +465,7 @@ func SetProjectTag(
|
||||||
resultTag = p.Tag
|
resultTag = p.Tag
|
||||||
} else if p.Project.TagID == nil {
|
} else if p.Project.TagID == nil {
|
||||||
// Create a tag
|
// Create a tag
|
||||||
itag, err := db.QueryOne(ctx, tx, models.Tag{},
|
tag, err := db.QueryOne[models.Tag](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO tags (text) VALUES ($1)
|
INSERT INTO tags (text) VALUES ($1)
|
||||||
RETURNING $columns
|
RETURNING $columns
|
||||||
|
@ -483,7 +475,7 @@ func SetProjectTag(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to create new tag for project")
|
return nil, oops.New(err, "failed to create new tag for project")
|
||||||
}
|
}
|
||||||
resultTag = itag.(*models.Tag)
|
resultTag = tag
|
||||||
|
|
||||||
// Attach it to the project
|
// Attach it to the project
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
|
@ -499,7 +491,7 @@ func SetProjectTag(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Update the text of an existing one
|
// Update the text of an existing one
|
||||||
itag, err := db.QueryOne(ctx, tx, models.Tag{},
|
tag, err := db.QueryOne[models.Tag](ctx, tx,
|
||||||
`
|
`
|
||||||
UPDATE tags
|
UPDATE tags
|
||||||
SET text = $1
|
SET text = $1
|
||||||
|
@ -511,7 +503,7 @@ func SetProjectTag(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to update existing tag")
|
return nil, oops.New(err, "failed to update existing tag")
|
||||||
}
|
}
|
||||||
resultTag = itag.(*models.Tag)
|
resultTag = tag
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit(ctx)
|
err = tx.Commit(ctx)
|
||||||
|
|
|
@ -44,10 +44,7 @@ func FetchSnippets(
|
||||||
|
|
||||||
if len(q.Tags) > 0 {
|
if len(q.Tags) > 0 {
|
||||||
// Get snippet IDs with this tag, then use that in the main query
|
// Get snippet IDs with this tag, then use that in the main query
|
||||||
type snippetIDRow struct {
|
snippetIDs, err := db.QueryScalar[int](ctx, tx,
|
||||||
SnippetID int `db:"snippet_id"`
|
|
||||||
}
|
|
||||||
iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{},
|
|
||||||
`
|
`
|
||||||
SELECT DISTINCT snippet_id
|
SELECT DISTINCT snippet_id
|
||||||
FROM
|
FROM
|
||||||
|
@ -63,14 +60,11 @@ func FetchSnippets(
|
||||||
}
|
}
|
||||||
|
|
||||||
// special early-out: no snippets found for these tags at all
|
// special early-out: no snippets found for these tags at all
|
||||||
if len(iSnippetIDs) == 0 {
|
if len(snippetIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
q.IDs = make([]int, len(iSnippetIDs))
|
q.IDs = snippetIDs
|
||||||
for i := range iSnippetIDs {
|
|
||||||
q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var qb db.QueryBuilder
|
var qb db.QueryBuilder
|
||||||
|
@ -125,16 +119,14 @@ func FetchSnippets(
|
||||||
DiscordMessage *models.DiscordMessage `db:"discord_message"`
|
DiscordMessage *models.DiscordMessage `db:"discord_message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...)
|
results, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch threads")
|
return nil, oops.New(err, "failed to fetch threads")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not
|
result := make([]SnippetAndStuff, len(results)) // allocate extra space because why not
|
||||||
snippetIDs := make([]int, len(iresults))
|
snippetIDs := make([]int, len(results))
|
||||||
for i, iresult := range iresults {
|
for i, row := range results {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
result[i] = SnippetAndStuff{
|
result[i] = SnippetAndStuff{
|
||||||
Snippet: row.Snippet,
|
Snippet: row.Snippet,
|
||||||
Owner: row.Owner,
|
Owner: row.Owner,
|
||||||
|
@ -150,7 +142,7 @@ func FetchSnippets(
|
||||||
SnippetID int `db:"snippet_tags.snippet_id"`
|
SnippetID int `db:"snippet_tags.snippet_id"`
|
||||||
Tag *models.Tag `db:"tags"`
|
Tag *models.Tag `db:"tags"`
|
||||||
}
|
}
|
||||||
iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{},
|
snippetTags, err := db.Query[snippetTagRow](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -170,8 +162,7 @@ func FetchSnippets(
|
||||||
for i := range result {
|
for i := range result {
|
||||||
resultBySnippetId[result[i].Snippet.ID] = &result[i]
|
resultBySnippetId[result[i].Snippet.ID] = &result[i]
|
||||||
}
|
}
|
||||||
for _, iSnippetTag := range iSnippetTags {
|
for _, snippetTag := range snippetTags {
|
||||||
snippetTag := iSnippetTag.(*snippetTagRow)
|
|
||||||
item := resultBySnippetId[snippetTag.SnippetID]
|
item := resultBySnippetId[snippetTag.SnippetID]
|
||||||
item.Tags = append(item.Tags, snippetTag.Tag)
|
item.Tags = append(item.Tags, snippetTag.Tag)
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,18 +40,11 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T
|
||||||
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...)
|
tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch tags")
|
return nil, oops.New(err, "failed to fetch tags")
|
||||||
}
|
}
|
||||||
|
return tags, nil
|
||||||
res := make([]*models.Tag, len(itags))
|
|
||||||
for i, itag := range itags {
|
|
||||||
tag := itag.(*models.Tag)
|
|
||||||
res[i] = tag
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {
|
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {
|
||||||
|
|
|
@ -145,15 +145,13 @@ func FetchThreads(
|
||||||
ForumLastReadTime *time.Time `db:"slri.lastread"`
|
ForumLastReadTime *time.Time `db:"slri.lastread"`
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
|
rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch threads")
|
return nil, oops.New(err, "failed to fetch threads")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]ThreadAndStuff, len(iresults))
|
result := make([]ThreadAndStuff, len(rows))
|
||||||
for i, iresult := range iresults {
|
for i, row := range rows {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
hasRead := false
|
hasRead := false
|
||||||
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
|
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
|
||||||
hasRead = true
|
hasRead = true
|
||||||
|
@ -263,7 +261,7 @@ func CountThreads(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...)
|
count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, oops.New(err, "failed to fetch count of threads")
|
return 0, oops.New(err, "failed to fetch count of threads")
|
||||||
}
|
}
|
||||||
|
@ -405,15 +403,13 @@ func FetchPosts(
|
||||||
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
|
rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch posts")
|
return nil, oops.New(err, "failed to fetch posts")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]PostAndStuff, len(iresults))
|
result := make([]PostAndStuff, len(rows))
|
||||||
for i, iresult := range iresults {
|
for i, row := range rows {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
hasRead := false
|
hasRead := false
|
||||||
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
|
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
|
||||||
hasRead = true
|
hasRead = true
|
||||||
|
@ -595,7 +591,7 @@ func CountPosts(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...)
|
count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, oops.New(err, "failed to count posts")
|
return 0, oops.New(err, "failed to count posts")
|
||||||
}
|
}
|
||||||
|
@ -608,12 +604,9 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
type postResult struct {
|
authorID, err := db.QueryOneScalar[*int](ctx, connOrTx,
|
||||||
AuthorID *int `db:"post.author_id"`
|
|
||||||
}
|
|
||||||
iresult, err := db.QueryOne(ctx, connOrTx, postResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT post.author_id
|
||||||
FROM
|
FROM
|
||||||
handmade_post AS post
|
handmade_post AS post
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -629,9 +622,8 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
|
||||||
panic(oops.New(err, "failed to get author of post when checking permissions"))
|
panic(oops.New(err, "failed to get author of post when checking permissions"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result := iresult.(*postResult)
|
|
||||||
|
|
||||||
return result.AuthorID != nil && *result.AuthorID == user.ID
|
return authorID != nil && *authorID == user.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateNewPost(
|
func CreateNewPost(
|
||||||
|
@ -709,7 +701,7 @@ func DeletePost(
|
||||||
FirstPostID int `db:"first_id"`
|
FirstPostID int `db:"first_id"`
|
||||||
Deleted bool `db:"deleted"`
|
Deleted bool `db:"deleted"`
|
||||||
}
|
}
|
||||||
ti, err := db.QueryOne(ctx, tx, threadInfo{},
|
info, err := db.QueryOne[threadInfo](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -722,7 +714,6 @@ func DeletePost(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(oops.New(err, "failed to fetch thread info"))
|
panic(oops.New(err, "failed to fetch thread info"))
|
||||||
}
|
}
|
||||||
info := ti.(*threadInfo)
|
|
||||||
if info.Deleted {
|
if info.Deleted {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -848,12 +839,9 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
|
||||||
keys = append(keys, key)
|
keys = append(keys, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
type assetId struct {
|
assetIDs, err := db.QueryScalar[uuid.UUID](ctx, tx,
|
||||||
AssetID uuid.UUID `db:"id"`
|
|
||||||
}
|
|
||||||
assetResult, err := db.Query(ctx, tx, assetId{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT id
|
||||||
FROM handmade_asset
|
FROM handmade_asset
|
||||||
WHERE s3_key = ANY($1)
|
WHERE s3_key = ANY($1)
|
||||||
`,
|
`,
|
||||||
|
@ -865,8 +853,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
|
||||||
|
|
||||||
var values [][]interface{}
|
var values [][]interface{}
|
||||||
|
|
||||||
for _, asset := range assetResult {
|
for _, assetID := range assetIDs {
|
||||||
values = append(values, []interface{}{postId, asset.(*assetId).AssetID})
|
values = append(values, []interface{}{postId, assetID})
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
|
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
|
||||||
|
@ -886,7 +874,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more.
|
||||||
You should probably mark the thread as deleted in this case.
|
You should probably mark the thread as deleted in this case.
|
||||||
*/
|
*/
|
||||||
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
||||||
postsIter, err := db.Query(ctx, tx, models.Post{},
|
posts, err := db.Query[models.Post](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_post
|
FROM handmade_post
|
||||||
|
@ -901,9 +889,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var firstPost, lastPost *models.Post
|
var firstPost, lastPost *models.Post
|
||||||
for _, ipost := range postsIter {
|
for _, post := range posts {
|
||||||
post := ipost.(*models.Post)
|
|
||||||
|
|
||||||
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
|
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
|
||||||
firstPost = post
|
firstPost = post
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,12 +22,9 @@ type TwitchStreamer struct {
|
||||||
var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P<login>[^/]+)$`)
|
var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P<login>[^/]+)$`)
|
||||||
|
|
||||||
func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) {
|
func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) {
|
||||||
type linkResult struct {
|
dbStreamers, err := db.Query[models.Link](ctx, dbConn,
|
||||||
Link models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
streamers, err := db.Query(ctx, dbConn, linkResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{link}
|
||||||
FROM
|
FROM
|
||||||
handmade_links AS link
|
handmade_links AS link
|
||||||
LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id
|
LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id
|
||||||
|
@ -49,10 +46,8 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre
|
||||||
return nil, oops.New(err, "failed to fetch twitch links")
|
return nil, oops.New(err, "failed to fetch twitch links")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]TwitchStreamer, 0, len(streamers))
|
result := make([]TwitchStreamer, 0, len(dbStreamers))
|
||||||
for _, s := range streamers {
|
for _, dbStreamer := range dbStreamers {
|
||||||
dbStreamer := s.(*linkResult).Link
|
|
||||||
|
|
||||||
streamer := TwitchStreamer{
|
streamer := TwitchStreamer{
|
||||||
UserID: dbStreamer.UserID,
|
UserID: dbStreamer.UserID,
|
||||||
ProjectID: dbStreamer.ProjectID,
|
ProjectID: dbStreamer.ProjectID,
|
||||||
|
@ -81,7 +76,7 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) {
|
func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) {
|
||||||
links, err := db.Query(ctx, dbConn, models.Link{},
|
links, err := db.Query[models.Link](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -100,8 +95,7 @@ func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx,
|
||||||
result := make([]string, 0, len(links))
|
result := make([]string, 0, len(links))
|
||||||
|
|
||||||
for _, l := range links {
|
for _, l := range links {
|
||||||
url := l.(*models.Link).URL
|
match := twitchRegex.FindStringSubmatch(l.URL)
|
||||||
match := twitchRegex.FindStringSubmatch(url)
|
|
||||||
if match != nil {
|
if match != nil {
|
||||||
login := strings.ToLower(match[twitchRegex.SubexpIndex("login")])
|
login := strings.ToLower(match[twitchRegex.SubexpIndex("login")])
|
||||||
result = append(result, login)
|
result = append(result, login)
|
||||||
|
|
|
@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error {
|
||||||
// Port "jam snippets" to use a tag
|
// Port "jam snippets" to use a tag
|
||||||
//
|
//
|
||||||
|
|
||||||
jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
|
jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to create jam tag")
|
return oops.New(err, "failed to create jam tag")
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
type subforumRow struct {
|
subforums, err := db.Query[Subforum](ctx, conn,
|
||||||
Subforum Subforum `db:"sf"`
|
|
||||||
}
|
|
||||||
rowsSlice, err := db.Query(ctx, conn, subforumRow{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM handmade_subforum
|
||||||
handmade_subforum as sf
|
|
||||||
ORDER BY sort, id ASC
|
ORDER BY sort, id ASC
|
||||||
`,
|
`,
|
||||||
)
|
)
|
||||||
|
@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
panic(oops.New(err, "failed to fetch subforum tree"))
|
panic(oops.New(err, "failed to fetch subforum tree"))
|
||||||
}
|
}
|
||||||
|
|
||||||
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice))
|
sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums))
|
||||||
for _, row := range rowsSlice {
|
for _, sf := range subforums {
|
||||||
sf := row.(*subforumRow).Subforum
|
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf}
|
||||||
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range sfTreeMap {
|
for _, node := range sfTreeMap {
|
||||||
|
@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, row := range rowsSlice {
|
for _, cat := range subforums {
|
||||||
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
|
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
|
||||||
cat := row.(*subforumRow).Subforum
|
|
||||||
node := sfTreeMap[cat.ID]
|
node := sfTreeMap[cat.ID]
|
||||||
if node.Parent != nil {
|
if node.Parent != nil {
|
||||||
node.Parent.Children = append(node.Parent.Children, node)
|
node.Parent.Children = append(node.Parent.Children, node)
|
||||||
|
|
|
@ -440,7 +440,7 @@ func updateStreamStatusInDB(ctx context.Context, conn db.ConnOrTx, status *strea
|
||||||
inserted := false
|
inserted := false
|
||||||
if isStatusRelevant(status) {
|
if isStatusRelevant(status) {
|
||||||
log.Debug().Msg("Status relevant")
|
log.Debug().Msg("Status relevant")
|
||||||
_, err := db.QueryOne(ctx, conn, models.TwitchStream{},
|
_, err := db.QueryOne[models.TwitchStream](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM twitch_streams
|
FROM twitch_streams
|
||||||
|
|
|
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
userIds = append(userIds, u.User.ID)
|
userIds = append(userIds, u.User.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ul := range userLinks {
|
for _, link := range userLinks {
|
||||||
link := ul.(*models.Link)
|
|
||||||
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
||||||
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
||||||
}
|
}
|
||||||
|
@ -257,12 +256,9 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
return RejectRequest(c, "User id can't be parsed")
|
return RejectRequest(c, "User id can't be parsed")
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE auth_user.id = $1
|
WHERE auth_user.id = $1
|
||||||
|
@ -276,7 +272,6 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := u.(*userQuery).User
|
|
||||||
|
|
||||||
whatHappened := ""
|
whatHappened := ""
|
||||||
if action == ApprovalQueueActionApprove {
|
if action == ApprovalQueueActionApprove {
|
||||||
|
@ -337,7 +332,7 @@ type UnapprovedPost struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
|
posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -358,11 +353,7 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved posts")
|
return nil, oops.New(err, "failed to fetch unapproved posts")
|
||||||
}
|
}
|
||||||
var res []*UnapprovedPost
|
return posts, nil
|
||||||
for _, iresult := range it {
|
|
||||||
res = append(res, iresult.(*UnapprovedPost))
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UnapprovedProject struct {
|
type UnapprovedProject struct {
|
||||||
|
@ -372,12 +363,9 @@ type UnapprovedProject struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
type unapprovedUser struct {
|
ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn,
|
||||||
ID int `db:"id"`
|
|
||||||
}
|
|
||||||
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT id
|
||||||
FROM
|
FROM
|
||||||
auth_user AS u
|
auth_user AS u
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -388,10 +376,6 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved users")
|
return nil, oops.New(err, "failed to fetch unapproved users")
|
||||||
}
|
}
|
||||||
ownerIDs := make([]int, 0, len(it))
|
|
||||||
for _, uid := range it {
|
|
||||||
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||||
OwnerIDs: ownerIDs,
|
OwnerIDs: ownerIDs,
|
||||||
|
@ -406,7 +390,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
projectIDs = append(projectIDs, p.Project.ID)
|
projectIDs = append(projectIDs, p.Project.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
projectLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -425,8 +409,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
|
|
||||||
for idx, proj := range projects {
|
for idx, proj := range projects {
|
||||||
links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects.
|
links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects.
|
||||||
for _, l := range projectLinks {
|
for _, link := range projectLinks {
|
||||||
link := l.(*models.Link)
|
|
||||||
if *link.ProjectID == proj.Project.ID {
|
if *link.ProjectID == proj.Project.ID {
|
||||||
links = append(links, link)
|
links = append(links, link)
|
||||||
}
|
}
|
||||||
|
@ -455,7 +438,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int)
|
||||||
ThreadID int `db:"thread.id"`
|
ThreadID int `db:"thread.id"`
|
||||||
PostID int `db:"post.id"`
|
PostID int `db:"post.id"`
|
||||||
}
|
}
|
||||||
it, err := db.Query(ctx, tx, toDelete{},
|
rows, err := db.Query[toDelete](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -471,8 +454,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int)
|
||||||
return oops.New(err, "failed to fetch posts to delete for user")
|
return oops.New(err, "failed to fetch posts to delete for user")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, iResult := range it {
|
for _, row := range rows {
|
||||||
row := iResult.(*toDelete)
|
|
||||||
hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID)
|
hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID)
|
||||||
}
|
}
|
||||||
err = tx.Commit(ctx)
|
err = tx.Commit(ctx)
|
||||||
|
@ -489,9 +471,9 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
toDelete, err := db.Query(ctx, tx, models.Project{},
|
projectIDsToDelete, err := db.QueryScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT project.id
|
||||||
FROM
|
FROM
|
||||||
handmade_project AS project
|
handmade_project AS project
|
||||||
JOIN handmade_user_projects AS up ON up.project_id = project.id
|
JOIN handmade_user_projects AS up ON up.project_id = project.id
|
||||||
|
@ -504,17 +486,12 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in
|
||||||
return oops.New(err, "failed to fetch user's projects")
|
return oops.New(err, "failed to fetch user's projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
var projectIds []int
|
if len(projectIDsToDelete) > 0 {
|
||||||
for _, p := range toDelete {
|
|
||||||
projectIds = append(projectIds, p.(*models.Project).ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(projectIds) > 0 {
|
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
`
|
`
|
||||||
DELETE FROM handmade_project WHERE id = ANY($1)
|
DELETE FROM handmade_project WHERE id = ANY($1)
|
||||||
`,
|
`,
|
||||||
projectIds,
|
projectIDsToDelete,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to delete user's projects")
|
return oops.New(err, "failed to delete user's projects")
|
||||||
|
|
|
@ -19,12 +19,9 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
requestedUsername := usernameArgs[0]
|
requestedUsername := usernameArgs[0]
|
||||||
found = true
|
found = true
|
||||||
c.Perf.StartBlock("SQL", "Fetch user")
|
c.Perf.StartBlock("SQL", "Fetch user")
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
auth_user
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
|
@ -43,7 +40,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
canonicalUsername = userResult.(*userQuery).User.Username
|
canonicalUsername = user.Username
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,13 +75,11 @@ func Login(c *RequestContext) ResponseData {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE LOWER(username) = LOWER($1)
|
WHERE LOWER(username) = LOWER($1)
|
||||||
`,
|
`,
|
||||||
|
@ -94,7 +92,6 @@ func Login(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := &userRow.(*userQuery).User
|
|
||||||
|
|
||||||
success, err := tryLogin(c, user, password)
|
success, err := tryLogin(c, user, password)
|
||||||
|
|
||||||
|
@ -174,7 +171,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
|
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
|
||||||
userAlreadyExists := true
|
userAlreadyExists := true
|
||||||
_, err := db.QueryInt(c.Context(), c.Conn,
|
_, err := db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -195,7 +192,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
emailAlreadyExists := true
|
emailAlreadyExists := true
|
||||||
_, err = db.QueryInt(c.Context(), c.Conn,
|
_, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -454,16 +451,15 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return RejectRequest(c, "You must provide a username and an email address.")
|
return RejectRequest(c, "You must provide a username and an email address.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *models.User
|
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching user")
|
c.Perf.StartBlock("SQL", "Fetching user")
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE
|
WHERE
|
||||||
LOWER(username) = LOWER($1)
|
LOWER(username) = LOWER($1)
|
||||||
|
@ -478,13 +474,10 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if userRow != nil {
|
|
||||||
user = &userRow.(*userQuery).User
|
|
||||||
}
|
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
c.Perf.StartBlock("SQL", "Fetching existing token")
|
c.Perf.StartBlock("SQL", "Fetching existing token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_onetimetoken
|
FROM handmade_onetimetoken
|
||||||
|
@ -501,10 +494,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var resetToken *models.OneTimeToken
|
|
||||||
if tokenRow != nil {
|
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
|
||||||
}
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
if resetToken != nil {
|
if resetToken != nil {
|
||||||
|
@ -527,7 +516,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
if resetToken == nil {
|
if resetToken == nil {
|
||||||
c.Perf.StartBlock("SQL", "Creating new token")
|
c.Perf.StartBlock("SQL", "Creating new token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -543,7 +532,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
||||||
}
|
}
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
resetToken = newToken
|
||||||
|
|
||||||
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -787,7 +776,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
||||||
}
|
}
|
||||||
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
|
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -807,8 +796,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if row != nil {
|
if data != nil {
|
||||||
data := row.(*userAndTokenQuery)
|
|
||||||
result.User = &data.User
|
result.User = &data.User
|
||||||
result.OneTimeToken = data.OneTimeToken
|
result.OneTimeToken = data.OneTimeToken
|
||||||
if result.OneTimeToken != nil {
|
if result.OneTimeToken != nil {
|
||||||
|
|
|
@ -558,7 +558,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
||||||
res.ThreadID = threadId
|
res.ThreadID = threadId
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Verify that the thread exists")
|
c.Perf.StartBlock("SQL", "Verify that the thread exists")
|
||||||
threadExists, err := db.QueryBool(c.Context(), c.Conn,
|
threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*) > 0
|
SELECT COUNT(*) > 0
|
||||||
FROM handmade_thread
|
FROM handmade_thread
|
||||||
|
@ -586,7 +586,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
||||||
res.PostID = postId
|
res.PostID = postId
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Verify that the post exists")
|
c.Perf.StartBlock("SQL", "Verify that the post exists")
|
||||||
postExists, err := db.QueryBool(c.Context(), c.Conn,
|
postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*) > 0
|
SELECT COUNT(*) > 0
|
||||||
FROM handmade_post
|
FROM handmade_post
|
||||||
|
|
|
@ -104,7 +104,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c.Context())
|
defer tx.Rollback(c.Context())
|
||||||
|
|
||||||
iDiscordUser, err := db.QueryOne(c.Context(), tx, models.DiscordUser{},
|
discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discorduser
|
FROM handmade_discorduser
|
||||||
|
@ -119,7 +119,6 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
discordUser := iDiscordUser.(*models.DiscordUser)
|
|
||||||
|
|
||||||
_, err = tx.Exec(c.Context(),
|
_, err = tx.Exec(c.Context(),
|
||||||
`
|
`
|
||||||
|
@ -146,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||||
`SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`,
|
`SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`,
|
||||||
c.CurrentUser.ID,
|
c.CurrentUser.ID,
|
||||||
)
|
)
|
||||||
|
@ -157,14 +156,10 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
|
||||||
}
|
}
|
||||||
duser := iduser.(*models.DiscordUser)
|
|
||||||
|
|
||||||
type messageIdQuery struct {
|
msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||||
MessageID string `db:"msg.id"`
|
|
||||||
}
|
|
||||||
iMsgIDs, err := db.Query(c.Context(), c.Conn, messageIdQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT msg.id
|
||||||
FROM
|
FROM
|
||||||
handmade_discordmessage AS msg
|
handmade_discordmessage AS msg
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -178,10 +173,6 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var msgIDs []string
|
|
||||||
for _, imsgId := range iMsgIDs {
|
|
||||||
msgIDs = append(msgIDs, imsgId.(*messageIdQuery).MessageID)
|
|
||||||
}
|
|
||||||
for _, msgID := range msgIDs {
|
for _, msgID := range msgIDs {
|
||||||
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
|
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
|
||||||
if err != nil && !errors.Is(err, db.NotFound) {
|
if err != nil && !errors.Is(err, db.NotFound) {
|
||||||
|
|
|
@ -936,7 +936,7 @@ func addForumUrlsToPost(urlContext *hmnurl.UrlContext, p *templates.Post, subfor
|
||||||
// Takes a template post and adds information about how many posts the user has made
|
// Takes a template post and adds information about how many posts the user has made
|
||||||
// on the site.
|
// on the site.
|
||||||
func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) {
|
func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) {
|
||||||
numPosts, err := db.QueryInt(ctx, conn,
|
numPosts, err := db.QueryOneScalar[int](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
@ -956,7 +956,7 @@ func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.P
|
||||||
p.AuthorNumPosts = numPosts
|
p.AuthorNumPosts = numPosts
|
||||||
}
|
}
|
||||||
|
|
||||||
numProjects, err := db.QueryInt(ctx, conn,
|
numProjects, err := db.QueryOneScalar[int](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
|
|
@ -89,8 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
|
||||||
img.Seek(0, io.SeekStart)
|
img.Seek(0, io.SeekStart)
|
||||||
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
|
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
|
||||||
sha1sum := hasher.Sum(nil)
|
sha1sum := hasher.Sum(nil)
|
||||||
// TODO(db): Should use insert helper
|
imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn,
|
||||||
imageFile, err := db.QueryOne(c.Context(), dbConn, models.ImageFile{},
|
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height)
|
INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
@ -105,7 +104,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
|
||||||
}
|
}
|
||||||
|
|
||||||
return SaveImageFileResult{
|
return SaveImageFileResult{
|
||||||
ImageFile: imageFile.(*models.ImageFile),
|
ImageFile: imageFile,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,10 +30,9 @@ func ParseLinks(text string) []ParsedLink {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func LinksToText(links []interface{}) string {
|
func LinksToText(links []*models.Link) string {
|
||||||
linksText := ""
|
linksText := ""
|
||||||
for _, l := range links {
|
for _, link := range links {
|
||||||
link := l.(*models.Link)
|
|
||||||
linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name)
|
linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name)
|
||||||
}
|
}
|
||||||
return linksText
|
return linksText
|
||||||
|
|
|
@ -532,10 +532,11 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
Podcast models.Podcast `db:"podcast"`
|
Podcast models.Podcast `db:"podcast"`
|
||||||
ImageFilename string `db:"imagefile.file"`
|
ImageFilename string `db:"imagefile.file"`
|
||||||
}
|
}
|
||||||
podcastQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastQuery{},
|
podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcast AS podcast
|
FROM
|
||||||
|
handmade_podcast AS podcast
|
||||||
LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id
|
LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id
|
||||||
WHERE podcast.project_id = $1
|
WHERE podcast.project_id = $1
|
||||||
`,
|
`,
|
||||||
|
@ -549,18 +550,15 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
return result, oops.New(err, "failed to fetch podcast")
|
return result, oops.New(err, "failed to fetch podcast")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
podcast := podcastQueryResult.(*podcastQuery).Podcast
|
podcast := podcastQueryResult.Podcast
|
||||||
podcastImageFilename := podcastQueryResult.(*podcastQuery).ImageFilename
|
podcastImageFilename := podcastQueryResult.ImageFilename
|
||||||
result.Podcast = &podcast
|
result.Podcast = &podcast
|
||||||
result.ImageFile = podcastImageFilename
|
result.ImageFile = podcastImageFilename
|
||||||
|
|
||||||
if fetchEpisodes {
|
if fetchEpisodes {
|
||||||
type podcastEpisodeQuery struct {
|
|
||||||
Episode models.PodcastEpisode `db:"episode"`
|
|
||||||
}
|
|
||||||
if episodeGUID == "" {
|
if episodeGUID == "" {
|
||||||
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
|
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
|
||||||
podcastEpisodeQueryResult, err := db.Query(c.Context(), c.Conn, podcastEpisodeQuery{},
|
episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcastepisode AS episode
|
FROM handmade_podcastepisode AS episode
|
||||||
|
@ -573,16 +571,14 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, oops.New(err, "failed to fetch podcast episodes")
|
return result, oops.New(err, "failed to fetch podcast episodes")
|
||||||
}
|
}
|
||||||
for _, episodeRow := range podcastEpisodeQueryResult {
|
result.Episodes = episodes
|
||||||
result.Episodes = append(result.Episodes, &episodeRow.(*podcastEpisodeQuery).Episode)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
guid, err := uuid.Parse(episodeGUID)
|
guid, err := uuid.Parse(episodeGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
c.Perf.StartBlock("SQL", "Fetch podcast episode")
|
c.Perf.StartBlock("SQL", "Fetch podcast episode")
|
||||||
podcastEpisodeQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastEpisodeQuery{},
|
episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcastepisode AS episode
|
FROM handmade_podcastepisode AS episode
|
||||||
|
@ -599,8 +595,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
return result, oops.New(err, "failed to fetch podcast episode")
|
return result, oops.New(err, "failed to fetch podcast episode")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
episode := podcastEpisodeQueryResult.(*podcastEpisodeQuery).Episode
|
result.Episodes = append(result.Episodes, episode)
|
||||||
result.Episodes = append(result.Episodes, &episode)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,12 +187,9 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching screenshots")
|
c.Perf.StartBlock("SQL", "Fetching screenshots")
|
||||||
type screenshotQuery struct {
|
screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||||
Filename string `db:"screenshot.file"`
|
|
||||||
}
|
|
||||||
screenshotQueryResult, err := db.Query(c.Context(), c.Conn, screenshotQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT screenshot.file
|
||||||
FROM
|
FROM
|
||||||
handmade_imagefile AS screenshot
|
handmade_imagefile AS screenshot
|
||||||
INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id
|
INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id
|
||||||
|
@ -207,10 +204,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
c.Perf.EndBlock()
|
c.Perf.EndBlock()
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||||
type projectLinkQuery struct {
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
Link models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
projectLinkResult, err := db.Query(c.Context(), c.Conn, projectLinkQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -237,7 +231,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
Thread models.Thread `db:"thread"`
|
Thread models.Thread `db:"thread"`
|
||||||
Author models.User `db:"author"`
|
Author models.User `db:"author"`
|
||||||
}
|
}
|
||||||
postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{},
|
posts, err := db.Query[postQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -318,21 +312,21 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, screenshot := range screenshotQueryResult {
|
for _, screenshotFilename := range screenshotFilenames {
|
||||||
templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshot.(*screenshotQuery).Filename))
|
templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshotFilename))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, link := range projectLinkResult {
|
for _, link := range projectLinks {
|
||||||
templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(&link.(*projectLinkQuery).Link))
|
templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(link))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, post := range postQueryResult {
|
for _, post := range posts {
|
||||||
templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem(
|
templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem(
|
||||||
c.UrlContext,
|
c.UrlContext,
|
||||||
lineageBuilder,
|
lineageBuilder,
|
||||||
&post.(*postQuery).Post,
|
&post.Post,
|
||||||
&post.(*postQuery).Thread,
|
&post.Thread,
|
||||||
&post.(*postQuery).Author,
|
&post.Author,
|
||||||
c.Theme,
|
c.Theme,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
@ -498,7 +492,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||||
projectLinkResult, err := db.Query(c.Context(), c.Conn, models.Link{},
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -525,7 +519,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
||||||
c.Theme,
|
c.Theme,
|
||||||
)
|
)
|
||||||
|
|
||||||
projectSettings.LinksText = LinksToText(projectLinkResult)
|
projectSettings.LinksText = LinksToText(projectLinks)
|
||||||
|
|
||||||
var res ResponseData
|
var res ResponseData
|
||||||
res.MustWriteTemplate("project_edit.html", ProjectEditData{
|
res.MustWriteTemplate("project_edit.html", ProjectEditData{
|
||||||
|
@ -822,13 +816,11 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
owners, err := db.Query[models.User](ctx, tx,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
ownerRows, err := db.Query(ctx, tx, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE LOWER(username) = ANY ($1)
|
WHERE LOWER(username) = ANY ($1)
|
||||||
`,
|
`,
|
||||||
|
@ -849,7 +841,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
return oops.New(err, "Failed to delete project owners")
|
return oops.New(err, "Failed to delete project owners")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ownerRow := range ownerRows {
|
for _, owner := range owners {
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_user_projects
|
INSERT INTO handmade_user_projects
|
||||||
|
@ -857,7 +849,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2)
|
($1, $2)
|
||||||
`,
|
`,
|
||||||
ownerRow.(*userQuery).User.ID,
|
owner.ID,
|
||||||
payload.ProjectID,
|
payload.ProjectID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -548,13 +548,11 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
`,
|
`,
|
||||||
|
@ -568,7 +566,6 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User
|
||||||
return nil, nil, oops.New(err, "failed to get user for session")
|
return nil, nil, oops.New(err, "failed to get user for session")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := &userRow.(*userQuery).User
|
|
||||||
|
|
||||||
return user, session, nil
|
return user, session, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TwitchDebugPage(c *RequestContext) ResponseData {
|
func TwitchDebugPage(c *RequestContext) ResponseData {
|
||||||
streams, err := db.Query(c.Context(), c.Conn, models.TwitchStream{},
|
streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -83,8 +83,7 @@ func TwitchDebugPage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
html := ""
|
html := ""
|
||||||
for _, stream := range streams {
|
for _, s := range streams {
|
||||||
s := stream.(*models.TwitchStream)
|
|
||||||
html += fmt.Sprintf(`<a href="https://twitch.tv/%s">%s</a>%s<br />`, s.Login, s.Login, s.Title)
|
html += fmt.Sprintf(`<a href="https://twitch.tv/%s">%s</a>%s<br />`, s.Login, s.Login, s.Title)
|
||||||
}
|
}
|
||||||
var res ResponseData
|
var res ResponseData
|
||||||
|
|
|
@ -53,12 +53,9 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
profileUser = c.CurrentUser
|
profileUser = c.CurrentUser
|
||||||
} else {
|
} else {
|
||||||
c.Perf.StartBlock("SQL", "Fetch user")
|
c.Perf.StartBlock("SQL", "Fetch user")
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
auth_user
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
|
@ -75,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
profileUser = &userResult.(*userQuery).User
|
profileUser = user
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -87,10 +84,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetch user links")
|
c.Perf.StartBlock("SQL", "Fetch user links")
|
||||||
type userLinkQuery struct {
|
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
UserLink models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
userLinksSlice, err := db.Query(c.Context(), c.Conn, userLinkQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -104,9 +98,9 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username))
|
||||||
}
|
}
|
||||||
profileUserLinks := make([]templates.Link, 0, len(userLinksSlice))
|
profileUserLinks := make([]templates.Link, 0, len(userLinks))
|
||||||
for _, l := range userLinksSlice {
|
for _, l := range userLinks {
|
||||||
profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(&l.(*userLinkQuery).UserLink))
|
profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(l))
|
||||||
}
|
}
|
||||||
c.Perf.EndBlock()
|
c.Perf.EndBlock()
|
||||||
|
|
||||||
|
@ -231,7 +225,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
DiscordShowcaseBacklogUrl string
|
DiscordShowcaseBacklogUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
links, err := db.Query(c.Context(), c.Conn, models.Link{},
|
links, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_links
|
FROM handmade_links
|
||||||
|
@ -248,7 +242,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
var tduser *templates.DiscordUser
|
var tduser *templates.DiscordUser
|
||||||
var numUnsavedMessages int
|
var numUnsavedMessages int
|
||||||
iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discorduser
|
FROM handmade_discorduser
|
||||||
|
@ -261,11 +255,10 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))
|
||||||
} else {
|
} else {
|
||||||
duser := iduser.(*models.DiscordUser)
|
|
||||||
tmp := templates.DiscordUserToTemplate(duser)
|
tmp := templates.DiscordUserToTemplate(duser)
|
||||||
tduser = &tmp
|
tduser = &tmp
|
||||||
|
|
||||||
numUnsavedMessages, err = db.QueryInt(c.Context(), c.Conn,
|
numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
|
Loading…
Reference in New Issue