Compare commits
5 Commits
Author | SHA1 | Date |
---|---|---|
Ben Visness | 8ff5f02916 | |
Ben Visness | 2229ac85d5 | |
Ben Visness | 97360a1998 | |
Ben Visness | a2917b98c0 | |
Ben Visness | b9a4cb2361 |
40
go.mod
40
go.mod
|
@ -1,10 +1,8 @@
|
||||||
module git.handmade.network/hmn/hmn
|
module git.handmade.network/hmn/hmn
|
||||||
|
|
||||||
go 1.16
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Masterminds/goutils v1.1.1 // indirect
|
|
||||||
github.com/Masterminds/semver v1.5.0 // indirect
|
|
||||||
github.com/Masterminds/sprig v2.22.0+incompatible
|
github.com/Masterminds/sprig v2.22.0+incompatible
|
||||||
github.com/alecthomas/chroma v0.9.2
|
github.com/alecthomas/chroma v0.9.2
|
||||||
github.com/aws/aws-sdk-go-v2 v1.8.1
|
github.com/aws/aws-sdk-go-v2 v1.8.1
|
||||||
|
@ -16,13 +14,10 @@ require (
|
||||||
github.com/go-stack/stack v1.8.0
|
github.com/go-stack/stack v1.8.0
|
||||||
github.com/google/uuid v1.2.0
|
github.com/google/uuid v1.2.0
|
||||||
github.com/gorilla/websocket v1.4.2
|
github.com/gorilla/websocket v1.4.2
|
||||||
github.com/huandu/xstrings v1.3.2 // indirect
|
|
||||||
github.com/imdario/mergo v0.3.12 // indirect
|
|
||||||
github.com/jackc/pgconn v1.8.0
|
github.com/jackc/pgconn v1.8.0
|
||||||
github.com/jackc/pgtype v1.6.2
|
github.com/jackc/pgtype v1.6.2
|
||||||
github.com/jackc/pgx/v4 v4.10.1
|
github.com/jackc/pgx/v4 v4.10.1
|
||||||
github.com/jpillora/backoff v1.0.0
|
github.com/jpillora/backoff v1.0.0
|
||||||
github.com/mitchellh/copystructure v1.1.1 // indirect
|
|
||||||
github.com/rs/zerolog v1.21.0
|
github.com/rs/zerolog v1.21.0
|
||||||
github.com/spf13/cobra v1.1.3
|
github.com/spf13/cobra v1.1.3
|
||||||
github.com/stretchr/testify v1.7.0
|
github.com/stretchr/testify v1.7.0
|
||||||
|
@ -35,6 +30,39 @@ require (
|
||||||
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
|
golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||||
|
github.com/Masterminds/semver v1.5.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.4.1 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.1 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.3.3 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.6.2 // indirect
|
||||||
|
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.4.0 // indirect
|
||||||
|
github.com/huandu/xstrings v1.3.2 // indirect
|
||||||
|
github.com/imdario/mergo v0.3.12 // indirect
|
||||||
|
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||||
|
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
|
||||||
|
github.com/jackc/pgio v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgproto3/v2 v2.0.6 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||||
|
github.com/jackc/puddle v1.1.3 // indirect
|
||||||
|
github.com/mitchellh/copystructure v1.1.1 // indirect
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 // indirect
|
||||||
|
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 // indirect
|
||||||
|
golang.org/x/text v0.3.6 // indirect
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||||
|
)
|
||||||
|
|
||||||
replace (
|
replace (
|
||||||
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
|
github.com/frustra/bbcode v0.0.0-20201127003707-6ef347fbe1c8 => github.com/HandmadeNetwork/bbcode v0.0.0-20210623031351-ec0e2e2e39d9
|
||||||
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e
|
github.com/yuin/goldmark v1.4.1 => github.com/HandmadeNetwork/goldmark v1.4.1-0.20210707024600-f7e596e26b5e
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -159,7 +159,6 @@ github.com/imdario/mergo v0.3.12 h1:b6R2BslTbIEToALKP7LxUvijTsNI9TAe80pLWN2g/HU=
|
||||||
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
|
||||||
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
||||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
|
||||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||||
|
@ -178,7 +177,6 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47
|
||||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
|
||||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||||
|
|
|
@ -55,7 +55,7 @@ func addCreateProjectCommand(projectCommand *cobra.Command) {
|
||||||
}
|
}
|
||||||
hmn := p.Project
|
hmn := p.Project
|
||||||
|
|
||||||
newProjectID, err := db.QueryInt(ctx, tx,
|
newProjectID, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_project (
|
INSERT INTO handmade_project (
|
||||||
slug,
|
slug,
|
||||||
|
|
|
@ -210,7 +210,7 @@ func init() {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -218,7 +218,7 @@ func init() {
|
||||||
var parentId *int
|
var parentId *int
|
||||||
if parentSlug == "" {
|
if parentSlug == "" {
|
||||||
// Select the root subforum
|
// Select the root subforum
|
||||||
id, err := db.QueryInt(ctx, tx,
|
id, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`,
|
`SELECT id FROM handmade_subforum WHERE parent_id IS NULL AND project_id = $1`,
|
||||||
projectId,
|
projectId,
|
||||||
)
|
)
|
||||||
|
@ -228,7 +228,7 @@ func init() {
|
||||||
parentId = &id
|
parentId = &id
|
||||||
} else {
|
} else {
|
||||||
// Select the parent
|
// Select the parent
|
||||||
id, err := db.QueryInt(ctx, tx,
|
id, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
||||||
parentSlug, projectId,
|
parentSlug, projectId,
|
||||||
)
|
)
|
||||||
|
@ -238,7 +238,7 @@ func init() {
|
||||||
parentId = &id
|
parentId = &id
|
||||||
}
|
}
|
||||||
|
|
||||||
newId, err := db.QueryInt(ctx, tx,
|
newId, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id)
|
INSERT INTO handmade_subforum (name, slug, blurb, parent_id, project_id)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -289,12 +289,12 @@ func init() {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
projectId, err := db.QueryInt(ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
projectId, err := db.QueryOneScalar[int](ctx, tx, `SELECT id FROM handmade_project WHERE slug = $1`, projectSlug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subforumId, err := db.QueryInt(ctx, tx,
|
subforumId, err := db.QueryOneScalar[int](ctx, tx,
|
||||||
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
`SELECT id FROM handmade_subforum WHERE slug = $1 AND project_id = $2`,
|
||||||
subforumSlug, projectId,
|
subforumSlug, projectId,
|
||||||
)
|
)
|
||||||
|
|
|
@ -140,7 +140,7 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch and return the new record
|
// Fetch and return the new record
|
||||||
iasset, err := db.QueryOne(ctx, dbConn, models.Asset{},
|
asset, err := db.QueryOne[models.Asset](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_asset
|
FROM handmade_asset
|
||||||
|
@ -152,5 +152,5 @@ func Create(ctx context.Context, dbConn db.ConnOrTx, in CreateInput) (*models.As
|
||||||
return nil, oops.New(err, "failed to fetch newly-created asset")
|
return nil, oops.New(err, "failed to fetch newly-created asset")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iasset.(*models.Asset), nil
|
return asset, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ func makeCSRFToken() string {
|
||||||
var ErrNoSession = errors.New("no session found")
|
var ErrNoSession = errors.New("no session found")
|
||||||
|
|
||||||
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
|
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
|
||||||
row, err := db.QueryOne(ctx, conn, models.Session{}, "SELECT $columns FROM sessions WHERE id = $1", id)
|
sess, err := db.QueryOne[models.Session](ctx, conn, "SELECT $columns FROM sessions WHERE id = $1", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.NotFound) {
|
if errors.Is(err, db.NotFound) {
|
||||||
return nil, ErrNoSession
|
return nil, ErrNoSession
|
||||||
|
@ -53,7 +53,6 @@ func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Ses
|
||||||
return nil, oops.New(err, "failed to get session")
|
return nil, oops.New(err, "failed to get session")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sess := row.(*models.Session)
|
|
||||||
|
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
617
src/db/db.go
617
src/db/db.go
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/src/config"
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
|
@ -20,46 +21,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Values of these kinds are ok to query even if they are not directly understood by pgtype.
|
A general error to be used when no results are found. This is the error returned
|
||||||
This is common for custom types like:
|
by QueryOne, and can generally be used by other database helpers that fetch a single
|
||||||
|
result but find nothing.
|
||||||
type ThreadType int
|
|
||||||
*/
|
*/
|
||||||
var queryableKinds = []reflect.Kind{
|
var NotFound = errors.New("not found")
|
||||||
reflect.Int,
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Checks if we are able to handle a particular type in a database query. This applies only to
|
|
||||||
primitive types and not structs, since the database only returns individual primitive types
|
|
||||||
and it is our job to stitch them back together into structs later.
|
|
||||||
*/
|
|
||||||
func typeIsQueryable(t reflect.Type) bool {
|
|
||||||
_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
|
|
||||||
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
|
|
||||||
|
|
||||||
if isRecognizedByPgtype {
|
|
||||||
return true
|
|
||||||
} else if t == reflect.TypeOf(uuid.UUID{}) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// pgtype doesn't recognize it, but maybe it's a primitive type we can deal with
|
|
||||||
k := t.Kind()
|
|
||||||
for _, qk := range queryableKinds {
|
|
||||||
if k == qk {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// This interface should match both a direct pgx connection or a pgx transaction.
|
// This interface should match both a direct pgx connection or a pgx transaction.
|
||||||
type ConnOrTx interface {
|
type ConnOrTx interface {
|
||||||
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||||
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||||
Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error)
|
Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
|
||||||
|
|
||||||
// Both raw database connections and transactions in pgx can begin/commit
|
// Both raw database connections and transactions in pgx can begin/commit
|
||||||
// transactions. For database connections it does the obvious thing; for
|
// transactions. For database connections it does the obvious thing; for
|
||||||
|
@ -70,6 +42,8 @@ type ConnOrTx interface {
|
||||||
|
|
||||||
var connInfo = pgtype.NewConnInfo()
|
var connInfo = pgtype.NewConnInfo()
|
||||||
|
|
||||||
|
// Creates a new connection to the HMN database.
|
||||||
|
// This connection is not safe for concurrent use.
|
||||||
func NewConn() *pgx.Conn {
|
func NewConn() *pgx.Conn {
|
||||||
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
conn, err := pgx.Connect(context.Background(), config.Config.Postgres.DSN())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,6 +53,8 @@ func NewConn() *pgx.Conn {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a connection pool for the HMN database.
|
||||||
|
// The resulting pool is safe for concurrent use.
|
||||||
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
||||||
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
cfg, err := pgxpool.ParseConfig(config.Config.Postgres.DSN())
|
||||||
|
|
||||||
|
@ -95,144 +71,20 @@ func NewConnPool(minConns, maxConns int32) *pgxpool.Pool {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
type StructQueryIterator struct {
|
/*
|
||||||
fieldPaths [][]int
|
Performs a SQL query and returns a slice of all the result rows. The query is just plain SQL, but make sure to read the package documentation for details. You must explicitly provide the type argument - this is how it knows what Go type to map the results to, and it cannot be inferred.
|
||||||
rows pgx.Rows
|
|
||||||
destType reflect.Type
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) Next() (interface{}, bool) {
|
Any SQL query may be performed, including INSERT and UPDATE - as long as it returns a result set, you can use this. If the query does not return a result set, or you simply do not care about the result set, call Exec directly on your pgx connection.
|
||||||
hasNext := it.rows.Next()
|
|
||||||
if !hasNext {
|
|
||||||
it.Close()
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
result := reflect.New(it.destType)
|
This function always returns pointers to the values. This is convenient for structs, but for other types, you may wish to use QueryScalar.
|
||||||
|
*/
|
||||||
vals, err := it.rows.Values()
|
func Query[T any](
|
||||||
if err != nil {
|
ctx context.Context,
|
||||||
panic(err)
|
conn ConnOrTx,
|
||||||
}
|
query string,
|
||||||
|
args ...any,
|
||||||
// Better logging of panics in this confusing reflection process
|
) ([]*T, error) {
|
||||||
var currentField reflect.StructField
|
it, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
var currentValue reflect.Value
|
|
||||||
var currentIdx int
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
if currentValue.IsValid() {
|
|
||||||
logging.Error().
|
|
||||||
Int("index", currentIdx).
|
|
||||||
Str("field name", currentField.Name).
|
|
||||||
Stringer("field type", currentField.Type).
|
|
||||||
Interface("value", currentValue.Interface()).
|
|
||||||
Stringer("value type", currentValue.Type()).
|
|
||||||
Msg("panic in iterator")
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentField.Name != "" {
|
|
||||||
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
|
|
||||||
} else {
|
|
||||||
panic(r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i, val := range vals {
|
|
||||||
currentIdx = i
|
|
||||||
if val == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var field reflect.Value
|
|
||||||
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
|
|
||||||
if field.Kind() == reflect.Ptr {
|
|
||||||
field.Set(reflect.New(field.Type().Elem()))
|
|
||||||
field = field.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
|
|
||||||
// Regardless, we know it's not nil, so we can get at the contents.
|
|
||||||
valReflected := reflect.ValueOf(val)
|
|
||||||
if valReflected.Kind() == reflect.Ptr {
|
|
||||||
valReflected = valReflected.Elem()
|
|
||||||
}
|
|
||||||
currentValue = valReflected
|
|
||||||
|
|
||||||
switch field.Kind() {
|
|
||||||
case reflect.Int:
|
|
||||||
field.SetInt(valReflected.Int())
|
|
||||||
default:
|
|
||||||
field.Set(valReflected)
|
|
||||||
}
|
|
||||||
|
|
||||||
currentField = reflect.StructField{}
|
|
||||||
currentValue = reflect.Value{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.Interface(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) Close() {
|
|
||||||
it.rows.Close()
|
|
||||||
select {
|
|
||||||
case it.closed <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (it *StructQueryIterator) ToSlice() []interface{} {
|
|
||||||
defer it.Close()
|
|
||||||
var result []interface{}
|
|
||||||
for {
|
|
||||||
row, ok := it.Next()
|
|
||||||
if !ok {
|
|
||||||
err := it.rows.Err()
|
|
||||||
if err != nil {
|
|
||||||
panic(oops.New(err, "error while iterating through db results"))
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
result = append(result, row)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
|
|
||||||
if len(path) < 1 {
|
|
||||||
panic(oops.New(nil, "can't follow an empty path"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
|
|
||||||
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// more informative panic recovery
|
|
||||||
var field reflect.StructField
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
val := structPtrVal
|
|
||||||
for _, i := range path {
|
|
||||||
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
|
|
||||||
if val.IsNil() {
|
|
||||||
val.Set(reflect.New(val.Type().Elem()))
|
|
||||||
}
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
field = val.Type().Field(i)
|
|
||||||
val = val.Field(i)
|
|
||||||
}
|
|
||||||
return val, field
|
|
||||||
}
|
|
||||||
|
|
||||||
func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
|
||||||
it, err := QueryIterator(ctx, conn, destExample, query, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -240,27 +92,99 @@ func Query(ctx context.Context, conn ConnOrTx, destExample interface{}, query st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (*StructQueryIterator, error) {
|
/*
|
||||||
destType := reflect.TypeOf(destExample)
|
Identical to Query, but returns only the first result row. If there are no
|
||||||
columnNames, fieldPaths, err := getColumnNamesAndPaths(destType, nil, nil)
|
rows in the result set, returns NotFound.
|
||||||
|
*/
|
||||||
|
func QueryOne[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (*T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to generate column names")
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
return nil, NotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
columns := make([]string, 0, len(columnNames))
|
return result, nil
|
||||||
for _, strSlice := range columnNames {
|
}
|
||||||
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
|
||||||
fullName := strSlice[len(strSlice)-1]
|
/*
|
||||||
if tableName != "" {
|
Identical to Query, but returns concrete values instead of pointers. More convenient
|
||||||
fullName = tableName + "." + fullName
|
for primitive types.
|
||||||
|
*/
|
||||||
|
func QueryScalar[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) ([]T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var result []T
|
||||||
|
for {
|
||||||
|
val, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
columns = append(columns, fullName)
|
result = append(result, *val)
|
||||||
}
|
}
|
||||||
|
|
||||||
columnNamesString := strings.Join(columns, ", ")
|
return result, nil
|
||||||
query = strings.Replace(query, "$columns", columnNamesString, -1)
|
}
|
||||||
|
|
||||||
rows, err := conn.Query(ctx, query, args...)
|
/*
|
||||||
|
Identical to QueryScalar, but returns only the first result value. If there are
|
||||||
|
no rows in the result set, returns NotFound.
|
||||||
|
*/
|
||||||
|
func QueryOneScalar[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (T, error) {
|
||||||
|
rows, err := QueryIterator[T](ctx, conn, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
result, hasRow := rows.Next()
|
||||||
|
if !hasRow {
|
||||||
|
var zero T
|
||||||
|
return zero, NotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return *result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Identical to Query, but returns the ResultIterator instead of automatically converting the results to a slice. The iterator must be closed after use.
|
||||||
|
*/
|
||||||
|
func QueryIterator[T any](
|
||||||
|
ctx context.Context,
|
||||||
|
conn ConnOrTx,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (*Iterator[T], error) {
|
||||||
|
var destExample T
|
||||||
|
destType := reflect.TypeOf(destExample)
|
||||||
|
|
||||||
|
compiled := compileQuery(query, destType)
|
||||||
|
|
||||||
|
rows, err := conn.Query(ctx, compiled.query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
panic("query exceeded its deadline")
|
panic("query exceeded its deadline")
|
||||||
|
@ -268,11 +192,12 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
it := &StructQueryIterator{
|
it := &Iterator[T]{
|
||||||
fieldPaths: fieldPaths,
|
fieldPaths: compiled.fieldPaths,
|
||||||
rows: rows,
|
rows: rows,
|
||||||
destType: destType,
|
destType: compiled.destType,
|
||||||
closed: make(chan struct{}, 1),
|
destTypeIsScalar: typeIsQueryable(compiled.destType),
|
||||||
|
closed: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold
|
// Ensure that iterators are closed if context is cancelled. Otherwise, iterators can hold
|
||||||
|
@ -292,16 +217,72 @@ func QueryIterator(ctx context.Context, conn ConnOrTx, destExample interface{},
|
||||||
return it, nil
|
return it, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names [][]string, paths [][]int, err error) {
|
// TODO: QueryFunc?
|
||||||
var columnNames [][]string
|
|
||||||
var fieldPaths [][]int
|
type compiledQuery struct {
|
||||||
|
query string
|
||||||
|
destType reflect.Type
|
||||||
|
fieldPaths []fieldPath
|
||||||
|
}
|
||||||
|
|
||||||
|
var reColumnsPlaceholder = regexp.MustCompile(`\$columns({(.*?)})?`)
|
||||||
|
|
||||||
|
func compileQuery(query string, destType reflect.Type) compiledQuery {
|
||||||
|
columnsMatch := reColumnsPlaceholder.FindStringSubmatch(query)
|
||||||
|
hasColumnsPlaceholder := columnsMatch != nil
|
||||||
|
|
||||||
|
if hasColumnsPlaceholder {
|
||||||
|
// The presence of the $columns placeholder means that the destination type
|
||||||
|
// must be a struct, and we will plonk that struct's fields into the query.
|
||||||
|
|
||||||
|
if destType.Kind() != reflect.Struct {
|
||||||
|
panic("$columns can only be used when querying into a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefix []string
|
||||||
|
prefixText := columnsMatch[2]
|
||||||
|
if prefixText != "" {
|
||||||
|
prefix = []string{prefixText}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNames, fieldPaths := getColumnNamesAndPaths(destType, nil, prefix)
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(columnNames))
|
||||||
|
for _, strSlice := range columnNames {
|
||||||
|
tableName := strings.Join(strSlice[0:len(strSlice)-1], "_")
|
||||||
|
fullName := strSlice[len(strSlice)-1]
|
||||||
|
if tableName != "" {
|
||||||
|
fullName = tableName + "." + fullName
|
||||||
|
}
|
||||||
|
columns = append(columns, fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
columnNamesString := strings.Join(columns, ", ")
|
||||||
|
query = reColumnsPlaceholder.ReplaceAllString(query, columnNamesString)
|
||||||
|
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
fieldPaths: fieldPaths,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return compiledQuery{
|
||||||
|
query: query,
|
||||||
|
destType: destType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []string) (names []columnName, paths []fieldPath) {
|
||||||
|
var columnNames []columnName
|
||||||
|
var fieldPaths []fieldPath
|
||||||
|
|
||||||
if destType.Kind() == reflect.Ptr {
|
if destType.Kind() == reflect.Ptr {
|
||||||
destType = destType.Elem()
|
destType = destType.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if destType.Kind() != reflect.Struct {
|
if destType.Kind() != reflect.Struct {
|
||||||
return nil, nil, oops.New(nil, "can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix)
|
panic(fmt.Errorf("can only get column names and paths from a struct, got type '%v' (at prefix '%v')", destType.Name(), prefix))
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnonPrefix struct {
|
type AnonPrefix struct {
|
||||||
|
@ -348,108 +329,214 @@ func getColumnNamesAndPaths(destType reflect.Type, pathSoFar []int, prefix []str
|
||||||
columnNames = append(columnNames, fieldColumnNames)
|
columnNames = append(columnNames, fieldColumnNames)
|
||||||
fieldPaths = append(fieldPaths, path)
|
fieldPaths = append(fieldPaths, path)
|
||||||
} else if fieldType.Kind() == reflect.Struct {
|
} else if fieldType.Kind() == reflect.Struct {
|
||||||
subCols, subPaths, err := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
subCols, subPaths := getColumnNamesAndPaths(fieldType, path, fieldColumnNames)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
columnNames = append(columnNames, subCols...)
|
columnNames = append(columnNames, subCols...)
|
||||||
fieldPaths = append(fieldPaths, subPaths...)
|
fieldPaths = append(fieldPaths, subPaths...)
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, oops.New(nil, "field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type)
|
panic(fmt.Errorf("field '%s' in type %s has invalid type '%s'", field.Name, destType, field.Type))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columnNames, fieldPaths, nil
|
return columnNames, fieldPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
A general error to be used when no results are found. This is the error returned
|
Values of these kinds are ok to query even if they are not directly understood by pgtype.
|
||||||
by QueryOne, and can generally be used by other database helpers that fetch a single
|
This is common for custom types like:
|
||||||
result but find nothing.
|
|
||||||
|
type ThreadType int
|
||||||
*/
|
*/
|
||||||
var NotFound = errors.New("not found")
|
var queryableKinds = []reflect.Kind{
|
||||||
|
reflect.Int,
|
||||||
func QueryOne(ctx context.Context, conn ConnOrTx, destExample interface{}, query string, args ...interface{}) (interface{}, error) {
|
|
||||||
rows, err := QueryIterator(ctx, conn, destExample, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
result, hasRow := rows.Next()
|
|
||||||
if !hasRow {
|
|
||||||
return nil, NotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryScalar(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (interface{}, error) {
|
/*
|
||||||
rows, err := conn.Query(ctx, query, args...)
|
Checks if we are able to handle a particular type in a database query. This applies only to
|
||||||
if err != nil {
|
primitive types and not structs, since the database only returns individual primitive types
|
||||||
return nil, err
|
and it is our job to stitch them back together into structs later.
|
||||||
|
*/
|
||||||
|
func typeIsQueryable(t reflect.Type) bool {
|
||||||
|
_, isRecognizedByPgtype := connInfo.DataTypeForValue(reflect.New(t).Elem().Interface()) // if pgtype recognizes it, we don't need to dig in further for more `db` tags
|
||||||
|
// NOTE: boy it would be nice if we didn't have to do reflect.New here, considering that pgtype is just doing reflection on the value anyway
|
||||||
|
|
||||||
|
if isRecognizedByPgtype {
|
||||||
|
return true
|
||||||
|
} else if t == reflect.TypeOf(uuid.UUID{}) {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Next() {
|
// pgtype doesn't recognize it, but maybe it's a primitive type we can deal with
|
||||||
vals, err := rows.Values()
|
k := t.Kind()
|
||||||
if err != nil {
|
for _, qk := range queryableKinds {
|
||||||
panic(err)
|
if k == qk {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type columnName []string
|
||||||
|
|
||||||
|
// A path to a particular field in query's destination type. Each index in the slice
|
||||||
|
// corresponds to a field index for use with Field on a reflect.Type or reflect.Value.
|
||||||
|
type fieldPath []int
|
||||||
|
|
||||||
|
type Iterator[T any] struct {
|
||||||
|
fieldPaths []fieldPath
|
||||||
|
rows pgx.Rows
|
||||||
|
destType reflect.Type
|
||||||
|
destTypeIsScalar bool // NOTE(ben): Make sure this gets set every time destType gets set, based on typeIsQueryable(destType). This is kinda fragile...but also contained to this file, so doesn't seem worth a lazy evaluation or a constructor function.
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *Iterator[T]) Next() (*T, bool) {
|
||||||
|
// TODO(ben): What happens if this panics? Does it leak resources? Do we need
|
||||||
|
// to put a recover() here and close the rows?
|
||||||
|
|
||||||
|
hasNext := it.rows.Next()
|
||||||
|
if !hasNext {
|
||||||
|
it.Close()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
result := reflect.New(it.destType)
|
||||||
|
|
||||||
|
vals, err := it.rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if it.destTypeIsScalar {
|
||||||
|
// This type can be directly queried, meaning pgx recognizes it, it's
|
||||||
|
// a simple scalar thing, and we can just take the easy way out.
|
||||||
if len(vals) != 1 {
|
if len(vals) != 1 {
|
||||||
return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals))
|
panic(fmt.Errorf("tried to query a scalar value, but got %v values in the row", len(vals)))
|
||||||
|
}
|
||||||
|
setValueFromDB(result.Elem(), reflect.ValueOf(vals[0]))
|
||||||
|
return result.Interface().(*T), true
|
||||||
|
} else {
|
||||||
|
var currentField reflect.StructField
|
||||||
|
var currentValue reflect.Value
|
||||||
|
var currentIdx int
|
||||||
|
|
||||||
|
// Better logging of panics in this confusing reflection process
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if currentValue.IsValid() {
|
||||||
|
logging.Error().
|
||||||
|
Int("index", currentIdx).
|
||||||
|
Str("field name", currentField.Name).
|
||||||
|
Stringer("field type", currentField.Type).
|
||||||
|
Interface("value", currentValue.Interface()).
|
||||||
|
Stringer("value type", currentValue.Type()).
|
||||||
|
Msg("panic in iterator")
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentField.Name != "" {
|
||||||
|
panic(fmt.Errorf("panic while processing field '%s': %v", currentField.Name, r))
|
||||||
|
} else {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i, val := range vals {
|
||||||
|
currentIdx = i
|
||||||
|
if val == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var field reflect.Value
|
||||||
|
field, currentField = followPathThroughStructs(result, it.fieldPaths[i])
|
||||||
|
if field.Kind() == reflect.Ptr {
|
||||||
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
|
field = field.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some actual values still come through as pointers (like net.IPNet). Dunno why.
|
||||||
|
// Regardless, we know it's not nil, so we can get at the contents.
|
||||||
|
valReflected := reflect.ValueOf(val)
|
||||||
|
if valReflected.Kind() == reflect.Ptr {
|
||||||
|
valReflected = valReflected.Elem()
|
||||||
|
}
|
||||||
|
currentValue = valReflected
|
||||||
|
|
||||||
|
setValueFromDB(field, valReflected)
|
||||||
|
|
||||||
|
currentField = reflect.StructField{}
|
||||||
|
currentValue = reflect.Value{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return vals[0], nil
|
return result.Interface().(*T), true
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, NotFound
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryString(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (string, error) {
|
func setValueFromDB(dest reflect.Value, value reflect.Value) {
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
switch dest.Kind() {
|
||||||
if err != nil {
|
case reflect.Int:
|
||||||
return "", err
|
dest.SetInt(value.Int())
|
||||||
}
|
|
||||||
|
|
||||||
switch r := result.(type) {
|
|
||||||
case string:
|
|
||||||
return r, nil
|
|
||||||
default:
|
default:
|
||||||
return "", oops.New(nil, "QueryString got a non-string result: %v", result)
|
dest.Set(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryInt(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (int, error) {
|
func (it *Iterator[any]) Close() {
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
it.rows.Close()
|
||||||
if err != nil {
|
select {
|
||||||
return 0, err
|
case it.closed <- struct{}{}:
|
||||||
}
|
|
||||||
|
|
||||||
switch r := result.(type) {
|
|
||||||
case int:
|
|
||||||
return r, nil
|
|
||||||
case int32:
|
|
||||||
return int(r), nil
|
|
||||||
case int64:
|
|
||||||
return int(r), nil
|
|
||||||
default:
|
default:
|
||||||
return 0, oops.New(nil, "QueryInt got a non-int result: %v", result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func QueryBool(ctx context.Context, conn ConnOrTx, query string, args ...interface{}) (bool, error) {
|
/*
|
||||||
result, err := QueryScalar(ctx, conn, query, args...)
|
Pulls all the remaining values into a slice, and closes the iterator.
|
||||||
if err != nil {
|
*/
|
||||||
return false, err
|
func (it *Iterator[T]) ToSlice() []*T {
|
||||||
|
defer it.Close()
|
||||||
|
var result []*T
|
||||||
|
for {
|
||||||
|
row, ok := it.Next()
|
||||||
|
if !ok {
|
||||||
|
err := it.rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
panic(oops.New(err, "error while iterating through db results"))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, row)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func followPathThroughStructs(structPtrVal reflect.Value, path []int) (reflect.Value, reflect.StructField) {
|
||||||
|
if len(path) < 1 {
|
||||||
|
panic(oops.New(nil, "can't follow an empty path"))
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r := result.(type) {
|
if structPtrVal.Kind() != reflect.Ptr || structPtrVal.Elem().Kind() != reflect.Struct {
|
||||||
case bool:
|
panic(oops.New(nil, "structPtrVal must be a pointer to a struct; got value of type %s", structPtrVal.Type()))
|
||||||
return r, nil
|
|
||||||
default:
|
|
||||||
return false, oops.New(nil, "QueryBool got a non-bool result: %v", result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// more informative panic recovery
|
||||||
|
var field reflect.StructField
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
panic(oops.New(nil, "panic at field '%s': %v", field.Name, r))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
val := structPtrVal
|
||||||
|
for _, i := range path {
|
||||||
|
if val.Kind() == reflect.Ptr && val.Type().Elem().Kind() == reflect.Struct {
|
||||||
|
if val.IsNil() {
|
||||||
|
val.Set(reflect.New(val.Type().Elem()))
|
||||||
|
}
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
field = val.Type().Field(i)
|
||||||
|
val = val.Field(i)
|
||||||
|
}
|
||||||
|
return val, field
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,52 +10,143 @@ import (
|
||||||
|
|
||||||
func TestPaths(t *testing.T) {
|
func TestPaths(t *testing.T) {
|
||||||
type CustomInt int
|
type CustomInt int
|
||||||
type S struct {
|
type S2 struct {
|
||||||
I int `db:"I"`
|
B bool `db:"B"` // field 0
|
||||||
PI *int `db:"PI"`
|
PB *bool `db:"PB"` // field 1
|
||||||
CI CustomInt `db:"CI"`
|
|
||||||
PCI *CustomInt `db:"PCI"`
|
|
||||||
B bool `db:"B"`
|
|
||||||
PB *bool `db:"PB"`
|
|
||||||
|
|
||||||
NoTag int
|
NoTag string // field 2
|
||||||
|
}
|
||||||
|
type S struct {
|
||||||
|
I int `db:"I"` // field 0
|
||||||
|
PI *int `db:"PI"` // field 1
|
||||||
|
CI CustomInt `db:"CI"` // field 2
|
||||||
|
PCI *CustomInt `db:"PCI"` // field 3
|
||||||
|
S2 `db:"S2"` // field 4 (embedded!)
|
||||||
|
PS2 *S2 `db:"PS2"` // field 5
|
||||||
|
|
||||||
|
NoTag int // field 6
|
||||||
}
|
}
|
||||||
type Nested struct {
|
type Nested struct {
|
||||||
S S `db:"S"`
|
S S `db:"S"` // field 0
|
||||||
PS *S `db:"PS"`
|
PS *S `db:"PS"` // field 1
|
||||||
|
|
||||||
NoTag S
|
NoTag S // field 2
|
||||||
}
|
}
|
||||||
type Embedded struct {
|
type Embedded struct {
|
||||||
NoTag S
|
NoTag S // field 0
|
||||||
Nested
|
Nested // field 1
|
||||||
}
|
}
|
||||||
|
|
||||||
names, paths, err := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, "")
|
names, paths := getColumnNamesAndPaths(reflect.TypeOf(Embedded{}), nil, nil)
|
||||||
if assert.Nil(t, err) {
|
assert.Equal(t, []columnName{
|
||||||
assert.Equal(t, []string{
|
{"S", "I"}, {"S", "PI"},
|
||||||
"S.I", "S.PI",
|
{"S", "CI"}, {"S", "PCI"},
|
||||||
"S.CI", "S.PCI",
|
{"S", "S2", "B"}, {"S", "S2", "PB"},
|
||||||
"S.B", "S.PB",
|
{"S", "PS2", "B"}, {"S", "PS2", "PB"},
|
||||||
"PS.I", "PS.PI",
|
{"PS", "I"}, {"PS", "PI"},
|
||||||
"PS.CI", "PS.PCI",
|
{"PS", "CI"}, {"PS", "PCI"},
|
||||||
"PS.B", "PS.PB",
|
{"PS", "S2", "B"}, {"PS", "S2", "PB"},
|
||||||
}, names)
|
{"PS", "PS2", "B"}, {"PS", "PS2", "PB"},
|
||||||
assert.Equal(t, [][]int{
|
}, names)
|
||||||
{1, 0, 0}, {1, 0, 1}, {1, 0, 2}, {1, 0, 3}, {1, 0, 4}, {1, 0, 5},
|
assert.Equal(t, []fieldPath{
|
||||||
{1, 1, 0}, {1, 1, 1}, {1, 1, 2}, {1, 1, 3}, {1, 1, 4}, {1, 1, 5},
|
{1, 0, 0}, {1, 0, 1}, // Nested.S.I, Nested.S.PI
|
||||||
}, paths)
|
{1, 0, 2}, {1, 0, 3}, // Nested.S.CI, Nested.S.PCI
|
||||||
assert.True(t, len(names) == len(paths))
|
{1, 0, 4, 0}, {1, 0, 4, 1}, // Nested.S.S2.B, Nested.S.S2.PB
|
||||||
}
|
{1, 0, 5, 0}, {1, 0, 5, 1}, // Nested.S.PS2.B, Nested.S.PS2.PB
|
||||||
|
{1, 1, 0}, {1, 1, 1}, // Nested.PS.I, Nested.PS.PI
|
||||||
|
{1, 1, 2}, {1, 1, 3}, // Nested.PS.CI, Nested.PS.PCI
|
||||||
|
{1, 1, 4, 0}, {1, 1, 4, 1}, // Nested.PS.S2.B, Nested.PS.S2.PB
|
||||||
|
{1, 1, 5, 0}, {1, 1, 5, 1}, // Nested.PS.PS2.B, Nested.PS.PS2.PB
|
||||||
|
}, paths)
|
||||||
|
assert.True(t, len(names) == len(paths))
|
||||||
|
|
||||||
testStruct := Embedded{}
|
testStruct := Embedded{}
|
||||||
for i, path := range paths {
|
for i, path := range paths {
|
||||||
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
val, field := followPathThroughStructs(reflect.ValueOf(&testStruct), path)
|
||||||
assert.True(t, val.IsValid())
|
assert.True(t, val.IsValid())
|
||||||
assert.True(t, strings.Contains(names[i], field.Name))
|
assert.True(t, strings.Contains(names[i][len(names[i])-1], field.Name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompileQuery(t *testing.T) {
|
||||||
|
t.Run("simple struct", func(t *testing.T) {
|
||||||
|
type Dest struct {
|
||||||
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
|
assert.Equal(t, "SELECT foo, bar FROM greeblies", compiled.query)
|
||||||
|
})
|
||||||
|
t.Run("complex structs", func(t *testing.T) {
|
||||||
|
type CustomInt int
|
||||||
|
type S2 struct {
|
||||||
|
B bool `db:"B"`
|
||||||
|
PB *bool `db:"PB"`
|
||||||
|
|
||||||
|
NoTag string
|
||||||
|
}
|
||||||
|
type S struct {
|
||||||
|
I int `db:"I"`
|
||||||
|
PI *int `db:"PI"`
|
||||||
|
CI CustomInt `db:"CI"`
|
||||||
|
PCI *CustomInt `db:"PCI"`
|
||||||
|
S2 `db:"S2"` // embedded!
|
||||||
|
PS2 *S2 `db:"PS2"`
|
||||||
|
|
||||||
|
NoTag int
|
||||||
|
}
|
||||||
|
type Nested struct {
|
||||||
|
S S `db:"S"`
|
||||||
|
PS *S `db:"PS"`
|
||||||
|
|
||||||
|
NoTag S
|
||||||
|
}
|
||||||
|
type Dest struct {
|
||||||
|
NoTag S
|
||||||
|
Nested
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest{}))
|
||||||
|
assert.Equal(t, "SELECT S.I, S.PI, S.CI, S.PCI, S_S2.B, S_S2.PB, S_PS2.B, S_PS2.PB, PS.I, PS.PI, PS.CI, PS.PCI, PS_S2.B, PS_S2.PB, PS_PS2.B, PS_PS2.PB FROM greeblies", compiled.query)
|
||||||
|
})
|
||||||
|
t.Run("int", func(t *testing.T) {
|
||||||
|
type Dest int
|
||||||
|
|
||||||
|
// There should be no error here because we do not need to extract columns from
|
||||||
|
// the destination type. There may be errors down the line in value iteration, but
|
||||||
|
// that is always the case if the Go types don't match the query.
|
||||||
|
compiled := compileQuery("SELECT id FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
|
assert.Equal(t, "SELECT id FROM greeblies", compiled.query)
|
||||||
|
})
|
||||||
|
t.Run("just one table", func(t *testing.T) {
|
||||||
|
type Dest struct {
|
||||||
|
Foo int `db:"foo"`
|
||||||
|
Bar bool `db:"bar"`
|
||||||
|
Nope string // no tag
|
||||||
|
}
|
||||||
|
|
||||||
|
// The prefix is necessary because otherwise we would have to provide a struct with
|
||||||
|
// a db tag in order to provide the query with the `greeblies.` prefix in the
|
||||||
|
// final query. This comes up a lot when we do a JOIN to help with a condition, but
|
||||||
|
// don't actually care about any of the data we joined to.
|
||||||
|
compiled := compileQuery(
|
||||||
|
"SELECT $columns{greeblies} FROM greeblies NATURAL JOIN props",
|
||||||
|
reflect.TypeOf(Dest{}),
|
||||||
|
)
|
||||||
|
assert.Equal(t, "SELECT greeblies.foo, greeblies.bar FROM greeblies NATURAL JOIN props", compiled.query)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("using $columns without a struct is not allowed", func(t *testing.T) {
|
||||||
|
type Dest int
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
compileQuery("SELECT $columns FROM greeblies", reflect.TypeOf(Dest(0)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestQueryBuilder(t *testing.T) {
|
func TestQueryBuilder(t *testing.T) {
|
||||||
t.Run("happy time", func(t *testing.T) {
|
t.Run("happy time", func(t *testing.T) {
|
||||||
var qb QueryBuilder
|
var qb QueryBuilder
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*
|
||||||
|
This package contains lowish-level APIs for making database queries to our Postgres database. It streamlines the process of mapping query results to Go types, while allowing you to write arbitrary SQL queries.
|
||||||
|
|
||||||
|
The primary functions are Query and QueryIterator. See the package and function examples for detailed usage.
|
||||||
|
|
||||||
|
Query syntax
|
||||||
|
|
||||||
|
This package allows a few small extensions to SQL syntax to streamline the interaction between Go and Postgres.
|
||||||
|
|
||||||
|
Arguments can be provided using placeholders like $1, $2, etc. All arguments will be safely escaped and mapped from their Go type to the correct Postgres type. (This is a direct proxy to pgx.)
|
||||||
|
|
||||||
|
projectIDs, err := db.Query[int](ctx, conn,
|
||||||
|
`
|
||||||
|
SELECT id
|
||||||
|
FROM handmade_project
|
||||||
|
WHERE
|
||||||
|
slug = ANY($1)
|
||||||
|
AND hidden = $2
|
||||||
|
`,
|
||||||
|
[]string{"4coder", "metadesk"},
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
(This also demonstrates a useful tip: if you want to use a slice in your query, use Postgres arrays instead of IN.)
|
||||||
|
|
||||||
|
When querying individual fields, you can simply select the field like so:
|
||||||
|
|
||||||
|
ids, err := db.Query[int](ctx, conn, `SELECT id FROM handmade_project`)
|
||||||
|
|
||||||
|
To query multiple columns at once, you may use a struct type with `db:"column_name"` tags, and the special $columns placeholder:
|
||||||
|
|
||||||
|
type Project struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Slug string `db:"slug"`
|
||||||
|
DateCreated time.Time `db:"date_created"`
|
||||||
|
}
|
||||||
|
projects, err := db.Query[Project](ctx, conn, `SELECT $columns FROM ...`)
|
||||||
|
// Resulting query:
|
||||||
|
// SELECT id, slug, date_created FROM ...
|
||||||
|
|
||||||
|
Sometimes a table name prefix is required on each column to disambiguate between column names, especially when performing a JOIN. In those situations, you can include the prefix in the $columns placeholder like $columns{prefix}:
|
||||||
|
|
||||||
|
type Project struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Slug string `db:"slug"`
|
||||||
|
DateCreated time.Time `db:"date_created"`
|
||||||
|
}
|
||||||
|
orphanedProjects, err := db.Query[Project](ctx, conn, `
|
||||||
|
SELECT $columns{projects}
|
||||||
|
FROM
|
||||||
|
handmade_project AS projects
|
||||||
|
LEFT JOIN handmade_user_projects AS uproj
|
||||||
|
WHERE
|
||||||
|
uproj.user_id IS NULL
|
||||||
|
`)
|
||||||
|
// Resulting query:
|
||||||
|
// SELECT projects.id, projects.slug, projects.date_created FROM ...
|
||||||
|
*/
|
||||||
|
package db
|
|
@ -7,7 +7,7 @@ import (
|
||||||
|
|
||||||
type QueryBuilder struct {
|
type QueryBuilder struct {
|
||||||
sql strings.Builder
|
sql strings.Builder
|
||||||
args []interface{}
|
args []any
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -18,7 +18,7 @@ of `$?` will be replaced with the correct argument number.
|
||||||
foo ARG1 bar ARG2 baz $?
|
foo ARG1 bar ARG2 baz $?
|
||||||
foo ARG1 bar ARG2 baz ARG3
|
foo ARG1 bar ARG2 baz ARG3
|
||||||
*/
|
*/
|
||||||
func (qb *QueryBuilder) Add(sql string, args ...interface{}) {
|
func (qb *QueryBuilder) Add(sql string, args ...any) {
|
||||||
numPlaceholders := strings.Count(sql, "$?")
|
numPlaceholders := strings.Count(sql, "$?")
|
||||||
if numPlaceholders != len(args) {
|
if numPlaceholders != len(args) {
|
||||||
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
|
panic(fmt.Errorf("cannot add chunk to query; expected %d arguments but got %d", numPlaceholders, len(args)))
|
||||||
|
@ -37,6 +37,6 @@ func (qb *QueryBuilder) String() string {
|
||||||
return qb.sql.String()
|
return qb.sql.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qb *QueryBuilder) Args() []interface{} {
|
func (qb *QueryBuilder) Args() []any {
|
||||||
return qb.args
|
return qb.args
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,12 +90,9 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type profileResult struct {
|
hmnUser, err := db.QueryOne[models.User](ctx, bot.dbConn,
|
||||||
HMNUser models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
ires, err := db.QueryOne(ctx, bot.dbConn, profileResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
handmade_discorduser AS duser
|
handmade_discorduser AS duser
|
||||||
JOIN auth_user ON duser.hmn_user_id = auth_user.id
|
JOIN auth_user ON duser.hmn_user_id = auth_user.id
|
||||||
|
@ -122,16 +119,15 @@ func (bot *botInstance) handleProfileCommand(ctx context.Context, i *Interaction
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
res := ires.(*profileResult)
|
|
||||||
|
|
||||||
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
|
projectsAndStuff, err := hmndata.FetchProjects(ctx, bot.dbConn, nil, hmndata.ProjectsQuery{
|
||||||
OwnerIDs: []int{res.HMNUser.ID},
|
OwnerIDs: []int{hmnUser.ID},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects")
|
logging.ExtractLogger(ctx).Error().Err(err).Msg("failed to fetch user projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
url := hmnurl.BuildUserProfile(res.HMNUser.Username)
|
url := hmnurl.BuildUserProfile(hmnUser.Username)
|
||||||
msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url)
|
msg := fmt.Sprintf("<@%s>'s profile can be viewed at %s.", member.User.ID, url)
|
||||||
if len(projectsAndStuff) > 0 {
|
if len(projectsAndStuff) > 0 {
|
||||||
projectNoun := "projects"
|
projectNoun := "projects"
|
||||||
|
|
|
@ -250,7 +250,7 @@ func (bot *botInstance) connect(ctx context.Context) error {
|
||||||
// an old one or starting a new one.
|
// an old one or starting a new one.
|
||||||
|
|
||||||
shouldResume := true
|
shouldResume := true
|
||||||
isession, err := db.QueryOne(ctx, bot.dbConn, models.DiscordSession{}, `SELECT $columns FROM discord_session`)
|
session, err := db.QueryOne[models.DiscordSession](ctx, bot.dbConn, `SELECT $columns FROM discord_session`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.NotFound) {
|
if errors.Is(err, db.NotFound) {
|
||||||
// No session yet! Just identify and get on with it
|
// No session yet! Just identify and get on with it
|
||||||
|
@ -262,8 +262,6 @@ func (bot *botInstance) connect(ctx context.Context) error {
|
||||||
|
|
||||||
if shouldResume {
|
if shouldResume {
|
||||||
// Reconnect to the previous session
|
// Reconnect to the previous session
|
||||||
session := isession.(*models.DiscordSession)
|
|
||||||
|
|
||||||
err := bot.sendGatewayMessage(ctx, GatewayMessage{
|
err := bot.sendGatewayMessage(ctx, GatewayMessage{
|
||||||
Opcode: OpcodeResume,
|
Opcode: OpcodeResume,
|
||||||
Data: Resume{
|
Data: Resume{
|
||||||
|
@ -356,7 +354,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
}
|
}
|
||||||
bot.didAckHeartbeat = false
|
bot.didAckHeartbeat = false
|
||||||
|
|
||||||
latestSequenceNumber, err := db.QueryInt(ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`)
|
latestSequenceNumber, err := db.QueryOneScalar[int](ctx, bot.dbConn, `SELECT sequence_number FROM discord_session`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to fetch latest sequence number from the db")
|
log.Error().Err(err).Msg("failed to fetch latest sequence number from the db")
|
||||||
return false
|
return false
|
||||||
|
@ -408,7 +406,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
msgs, err := db.Query(ctx, tx, models.DiscordOutgoingMessage{}, `
|
msgs, err := db.Query[models.DiscordOutgoingMessage](ctx, tx, `
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM discord_outgoingmessages
|
FROM discord_outgoingmessages
|
||||||
ORDER BY id ASC
|
ORDER BY id ASC
|
||||||
|
@ -418,8 +416,7 @@ func (bot *botInstance) doSender(ctx context.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, imsg := range msgs {
|
for _, msg := range msgs {
|
||||||
msg := imsg.(*models.DiscordOutgoingMessage)
|
|
||||||
if time.Now().After(msg.ExpiresAt) {
|
if time.Now().After(msg.ExpiresAt) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,12 +73,9 @@ func RunHistoryWatcher(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{
|
||||||
func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
log := logging.ExtractLogger(ctx)
|
log := logging.ExtractLogger(ctx)
|
||||||
|
|
||||||
type query struct {
|
messagesWithoutContent, err := db.Query[models.DiscordMessage](ctx, dbConn,
|
||||||
Message models.DiscordMessage `db:"msg"`
|
|
||||||
}
|
|
||||||
imessagesWithoutContent, err := db.Query(ctx, dbConn, query{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{msg}
|
||||||
FROM
|
FROM
|
||||||
handmade_discordmessage AS msg
|
handmade_discordmessage AS msg
|
||||||
JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users
|
JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users
|
||||||
|
@ -95,10 +92,10 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(imessagesWithoutContent) > 0 {
|
if len(messagesWithoutContent) > 0 {
|
||||||
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent))
|
log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(messagesWithoutContent))
|
||||||
msgloop:
|
msgloop:
|
||||||
for _, imsg := range imessagesWithoutContent {
|
for _, msg := range messagesWithoutContent {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info().Msg("Scrape was canceled")
|
log.Info().Msg("Scrape was canceled")
|
||||||
|
@ -106,8 +103,6 @@ func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := imsg.(*query).Message
|
|
||||||
|
|
||||||
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
|
discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID)
|
||||||
if errors.Is(err, NotFound) {
|
if errors.Is(err, NotFound) {
|
||||||
// This message has apparently been deleted; delete it from our database
|
// This message has apparently been deleted; delete it from our database
|
||||||
|
|
|
@ -165,7 +165,7 @@ func InternMessage(
|
||||||
dbConn db.ConnOrTx,
|
dbConn db.ConnOrTx,
|
||||||
msg *Message,
|
msg *Message,
|
||||||
) error {
|
) error {
|
||||||
_, err := db.QueryOne(ctx, dbConn, models.DiscordMessage{},
|
_, err := db.QueryOne[models.DiscordMessage](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessage
|
FROM handmade_discordmessage
|
||||||
|
@ -219,7 +219,7 @@ type InternedMessage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
|
func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string) (*InternedMessage, error) {
|
||||||
result, err := db.QueryOne(ctx, dbConn, InternedMessage{},
|
interned, err := db.QueryOne[InternedMessage](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -235,8 +235,6 @@ func FetchInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msgId string)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
interned := result.(*InternedMessage)
|
|
||||||
return interned, nil
|
return interned, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,7 +281,7 @@ func HandleInternedMessage(ctx context.Context, dbConn db.ConnOrTx, msg *Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
|
func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *InternedMessage) error {
|
||||||
isnippet, err := db.QueryOne(ctx, dbConn, models.Snippet{},
|
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_snippet
|
FROM handmade_snippet
|
||||||
|
@ -294,10 +292,6 @@ func DeleteInternedMessage(ctx context.Context, dbConn db.ConnOrTx, interned *In
|
||||||
if err != nil && !errors.Is(err, db.NotFound) {
|
if err != nil && !errors.Is(err, db.NotFound) {
|
||||||
return oops.New(err, "failed to fetch snippet for discord message")
|
return oops.New(err, "failed to fetch snippet for discord message")
|
||||||
}
|
}
|
||||||
var snippet *models.Snippet
|
|
||||||
if !errors.Is(err, db.NotFound) {
|
|
||||||
snippet = isnippet.(*models.Snippet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE(asaf): Also deletes the following through a db cascade:
|
// NOTE(asaf): Also deletes the following through a db cascade:
|
||||||
// * handmade_discordmessageattachment
|
// * handmade_discordmessageattachment
|
||||||
|
@ -367,7 +361,7 @@ func SaveMessageContents(
|
||||||
return oops.New(err, "failed to create or update message contents")
|
return oops.New(err, "failed to create or update message contents")
|
||||||
}
|
}
|
||||||
|
|
||||||
icontent, err := db.QueryOne(ctx, dbConn, models.DiscordMessageContent{},
|
content, err := db.QueryOne[models.DiscordMessageContent](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -380,7 +374,7 @@ func SaveMessageContents(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to fetch message contents")
|
return oops.New(err, "failed to fetch message contents")
|
||||||
}
|
}
|
||||||
interned.MessageContent = icontent.(*models.DiscordMessageContent)
|
interned.MessageContent = content
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save attachments
|
// Save attachments
|
||||||
|
@ -395,12 +389,12 @@ func SaveMessageContents(
|
||||||
|
|
||||||
// Save / delete embeds
|
// Save / delete embeds
|
||||||
if msg.OriginalHasFields("embeds") {
|
if msg.OriginalHasFields("embeds") {
|
||||||
numSavedEmbeds, err := db.QueryInt(ctx, dbConn,
|
numSavedEmbeds, err := db.QueryOneScalar[int](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
WHERE message_id = $1
|
WHERE message_id = $1
|
||||||
`,
|
`,
|
||||||
msg.ID,
|
msg.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -472,7 +466,7 @@ func saveAttachment(
|
||||||
hmnUserID int,
|
hmnUserID int,
|
||||||
discordMessageID string,
|
discordMessageID string,
|
||||||
) (*models.DiscordMessageAttachment, error) {
|
) (*models.DiscordMessageAttachment, error) {
|
||||||
iexisting, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
|
existing, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -481,7 +475,7 @@ func saveAttachment(
|
||||||
attachment.ID,
|
attachment.ID,
|
||||||
)
|
)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return iexisting.(*models.DiscordMessageAttachment), nil
|
return existing, nil
|
||||||
} else if errors.Is(err, db.NotFound) {
|
} else if errors.Is(err, db.NotFound) {
|
||||||
// this is fine, just create it
|
// this is fine, just create it
|
||||||
} else {
|
} else {
|
||||||
|
@ -534,7 +528,7 @@ func saveAttachment(
|
||||||
return nil, oops.New(err, "failed to save Discord attachment data")
|
return nil, oops.New(err, "failed to save Discord attachment data")
|
||||||
}
|
}
|
||||||
|
|
||||||
iDiscordAttachment, err := db.QueryOne(ctx, tx, models.DiscordMessageAttachment{},
|
discordAttachment, err := db.QueryOne[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -546,7 +540,7 @@ func saveAttachment(
|
||||||
return nil, oops.New(err, "failed to fetch new Discord attachment data")
|
return nil, oops.New(err, "failed to fetch new Discord attachment data")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iDiscordAttachment.(*models.DiscordMessageAttachment), nil
|
return discordAttachment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
|
// Saves an embed from Discord. NOTE: This is _not_ idempotent, so only call it
|
||||||
|
@ -636,7 +630,7 @@ func saveEmbed(
|
||||||
return nil, oops.New(err, "failed to insert new embed")
|
return nil, oops.New(err, "failed to insert new embed")
|
||||||
}
|
}
|
||||||
|
|
||||||
iDiscordEmbed, err := db.QueryOne(ctx, tx, models.DiscordMessageEmbed{},
|
discordEmbed, err := db.QueryOne[models.DiscordMessageEmbed](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
|
@ -648,11 +642,11 @@ func saveEmbed(
|
||||||
return nil, oops.New(err, "failed to fetch new Discord embed data")
|
return nil, oops.New(err, "failed to fetch new Discord embed data")
|
||||||
}
|
}
|
||||||
|
|
||||||
return iDiscordEmbed.(*models.DiscordMessageEmbed), nil
|
return discordEmbed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
|
func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID string) (*models.Snippet, error) {
|
||||||
iresult, err := db.QueryOne(ctx, dbConn, models.Snippet{},
|
snippet, err := db.QueryOne[models.Snippet](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_snippet
|
FROM handmade_snippet
|
||||||
|
@ -669,7 +663,7 @@ func FetchSnippetForMessage(ctx context.Context, dbConn db.ConnOrTx, msgID strin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return iresult.(*models.Snippet), nil
|
return snippet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -805,12 +799,9 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
|
||||||
projectIDs[i] = p.Project.ID
|
projectIDs[i] = p.Project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type tagsRow struct {
|
userTags, err := db.Query[models.Tag](ctx, tx,
|
||||||
Tag models.Tag `db:"tags"`
|
|
||||||
}
|
|
||||||
iUserTags, err := db.Query(ctx, tx, tagsRow{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{tags}
|
||||||
FROM
|
FROM
|
||||||
tags
|
tags
|
||||||
JOIN handmade_project AS project ON project.tag = tags.id
|
JOIN handmade_project AS project ON project.tag = tags.id
|
||||||
|
@ -823,8 +814,7 @@ func HandleSnippetForInternedMessage(ctx context.Context, dbConn db.ConnOrTx, in
|
||||||
return oops.New(err, "failed to fetch tags for user projects")
|
return oops.New(err, "failed to fetch tags for user projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, itag := range iUserTags {
|
for _, tag := range userTags {
|
||||||
tag := itag.(*tagsRow).Tag
|
|
||||||
allTags = append(allTags, tag.ID)
|
allTags = append(allTags, tag.ID)
|
||||||
for _, messageTag := range messageTags {
|
for _, messageTag := range messageTags {
|
||||||
if strings.EqualFold(tag.Text, messageTag) {
|
if strings.EqualFold(tag.Text, messageTag) {
|
||||||
|
@ -890,7 +880,7 @@ var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\
|
||||||
|
|
||||||
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
|
func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) {
|
||||||
// Check attachments
|
// Check attachments
|
||||||
attachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{},
|
attachments, err := db.Query[models.DiscordMessageAttachment](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageattachment
|
FROM handmade_discordmessageattachment
|
||||||
|
@ -901,13 +891,12 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, oops.New(err, "failed to fetch message attachments")
|
return nil, nil, oops.New(err, "failed to fetch message attachments")
|
||||||
}
|
}
|
||||||
for _, iattachment := range attachments {
|
for _, attachment := range attachments {
|
||||||
attachment := iattachment.(*models.DiscordMessageAttachment)
|
|
||||||
return &attachment.AssetID, nil, nil
|
return &attachment.AssetID, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check embeds
|
// Check embeds
|
||||||
embeds, err := db.Query(ctx, tx, models.DiscordMessageEmbed{},
|
embeds, err := db.Query[models.DiscordMessageEmbed](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discordmessageembed
|
FROM handmade_discordmessageembed
|
||||||
|
@ -918,8 +907,7 @@ func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.Disco
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, oops.New(err, "failed to fetch discord embeds")
|
return nil, nil, oops.New(err, "failed to fetch discord embeds")
|
||||||
}
|
}
|
||||||
for _, iembed := range embeds {
|
for _, embed := range embeds {
|
||||||
embed := iembed.(*models.DiscordMessageEmbed)
|
|
||||||
if embed.VideoID != nil {
|
if embed.VideoID != nil {
|
||||||
return embed.VideoID, nil, nil
|
return embed.VideoID, nil, nil
|
||||||
} else if embed.ImageID != nil {
|
} else if embed.ImageID != nil {
|
||||||
|
|
|
@ -140,15 +140,15 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the query
|
// Do the query
|
||||||
iprojects, err := db.Query(ctx, dbConn, projectRow{}, qb.String(), qb.Args()...)
|
projectRows, err := db.Query[projectRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch projects")
|
return nil, oops.New(err, "failed to fetch projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch project owners to do permission checks
|
// Fetch project owners to do permission checks
|
||||||
projectIds := make([]int, len(iprojects))
|
projectIds := make([]int, len(projectRows))
|
||||||
for i, iproject := range iprojects {
|
for i, p := range projectRows {
|
||||||
projectIds[i] = iproject.(*projectRow).Project.ID
|
projectIds[i] = p.Project.ID
|
||||||
}
|
}
|
||||||
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
|
projectOwners, err := FetchMultipleProjectsOwners(ctx, tx, projectIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -156,8 +156,7 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
var res []ProjectAndStuff
|
var res []ProjectAndStuff
|
||||||
for i, iproject := range iprojects {
|
for i, p := range projectRows {
|
||||||
row := iproject.(*projectRow)
|
|
||||||
owners := projectOwners[i].Owners
|
owners := projectOwners[i].Owners
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -191,10 +190,10 @@ func FetchProjects(
|
||||||
}
|
}
|
||||||
|
|
||||||
projectGenerallyVisible := true &&
|
projectGenerallyVisible := true &&
|
||||||
row.Project.Lifecycle.In(models.VisibleProjectLifecycles) &&
|
p.Project.Lifecycle.In(models.VisibleProjectLifecycles) &&
|
||||||
!row.Project.Hidden &&
|
!p.Project.Hidden &&
|
||||||
(!row.Project.Personal || allOwnersApproved || row.Project.IsHMN())
|
(!p.Project.Personal || allOwnersApproved || p.Project.IsHMN())
|
||||||
if row.Project.IsHMN() {
|
if p.Project.IsHMN() {
|
||||||
projectGenerallyVisible = true // hard override
|
projectGenerallyVisible = true // hard override
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,11 +204,11 @@ func FetchProjects(
|
||||||
|
|
||||||
if projectVisible {
|
if projectVisible {
|
||||||
res = append(res, ProjectAndStuff{
|
res = append(res, ProjectAndStuff{
|
||||||
Project: row.Project,
|
Project: p.Project,
|
||||||
LogoLightAsset: row.LogoLightAsset,
|
LogoLightAsset: p.LogoLightAsset,
|
||||||
LogoDarkAsset: row.LogoDarkAsset,
|
LogoDarkAsset: p.LogoDarkAsset,
|
||||||
Owners: owners,
|
Owners: owners,
|
||||||
Tag: row.Tag,
|
Tag: p.Tag,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -334,7 +333,7 @@ func FetchMultipleProjectsOwners(
|
||||||
UserID int `db:"user_id"`
|
UserID int `db:"user_id"`
|
||||||
ProjectID int `db:"project_id"`
|
ProjectID int `db:"project_id"`
|
||||||
}
|
}
|
||||||
iuserprojects, err := db.Query(ctx, tx, userProject{},
|
userProjects, err := db.Query[userProject](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_user_projects
|
FROM handmade_user_projects
|
||||||
|
@ -348,9 +347,7 @@ func FetchMultipleProjectsOwners(
|
||||||
|
|
||||||
// Get the unique user IDs from this set and fetch the users from the db
|
// Get the unique user IDs from this set and fetch the users from the db
|
||||||
var userIds []int
|
var userIds []int
|
||||||
for _, iuserproject := range iuserprojects {
|
for _, userProject := range userProjects {
|
||||||
userProject := iuserproject.(*userProject)
|
|
||||||
|
|
||||||
addUserId := true
|
addUserId := true
|
||||||
for _, uid := range userIds {
|
for _, uid := range userIds {
|
||||||
if uid == userProject.UserID {
|
if uid == userProject.UserID {
|
||||||
|
@ -361,14 +358,12 @@ func FetchMultipleProjectsOwners(
|
||||||
userIds = append(userIds, userProject.UserID)
|
userIds = append(userIds, userProject.UserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
type userQuery struct {
|
users, err := db.Query[models.User](ctx, tx,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
iusers, err := db.Query(ctx, tx, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
auth_user
|
||||||
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE
|
WHERE
|
||||||
auth_user.id = ANY($1)
|
auth_user.id = ANY($1)
|
||||||
`,
|
`,
|
||||||
|
@ -383,9 +378,7 @@ func FetchMultipleProjectsOwners(
|
||||||
for i, pid := range projectIds {
|
for i, pid := range projectIds {
|
||||||
res[i] = ProjectOwners{ProjectID: pid}
|
res[i] = ProjectOwners{ProjectID: pid}
|
||||||
}
|
}
|
||||||
for _, iuserproject := range iuserprojects {
|
for _, userProject := range userProjects {
|
||||||
userProject := iuserproject.(*userProject)
|
|
||||||
|
|
||||||
// Get a pointer to the existing record in the result
|
// Get a pointer to the existing record in the result
|
||||||
var projectOwners *ProjectOwners
|
var projectOwners *ProjectOwners
|
||||||
for i := range res {
|
for i := range res {
|
||||||
|
@ -396,10 +389,9 @@ func FetchMultipleProjectsOwners(
|
||||||
|
|
||||||
// Get the full user record we fetched
|
// Get the full user record we fetched
|
||||||
var user *models.User
|
var user *models.User
|
||||||
for _, iuser := range iusers {
|
for _, u := range users {
|
||||||
u := iuser.(*userQuery).User
|
|
||||||
if u.ID == userProject.UserID {
|
if u.ID == userProject.UserID {
|
||||||
user = &u
|
user = u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
@ -473,7 +465,7 @@ func SetProjectTag(
|
||||||
resultTag = p.Tag
|
resultTag = p.Tag
|
||||||
} else if p.Project.TagID == nil {
|
} else if p.Project.TagID == nil {
|
||||||
// Create a tag
|
// Create a tag
|
||||||
itag, err := db.QueryOne(ctx, tx, models.Tag{},
|
tag, err := db.QueryOne[models.Tag](ctx, tx,
|
||||||
`
|
`
|
||||||
INSERT INTO tags (text) VALUES ($1)
|
INSERT INTO tags (text) VALUES ($1)
|
||||||
RETURNING $columns
|
RETURNING $columns
|
||||||
|
@ -483,7 +475,7 @@ func SetProjectTag(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to create new tag for project")
|
return nil, oops.New(err, "failed to create new tag for project")
|
||||||
}
|
}
|
||||||
resultTag = itag.(*models.Tag)
|
resultTag = tag
|
||||||
|
|
||||||
// Attach it to the project
|
// Attach it to the project
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
|
@ -499,7 +491,7 @@ func SetProjectTag(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Update the text of an existing one
|
// Update the text of an existing one
|
||||||
itag, err := db.QueryOne(ctx, tx, models.Tag{},
|
tag, err := db.QueryOne[models.Tag](ctx, tx,
|
||||||
`
|
`
|
||||||
UPDATE tags
|
UPDATE tags
|
||||||
SET text = $1
|
SET text = $1
|
||||||
|
@ -511,7 +503,7 @@ func SetProjectTag(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to update existing tag")
|
return nil, oops.New(err, "failed to update existing tag")
|
||||||
}
|
}
|
||||||
resultTag = itag.(*models.Tag)
|
resultTag = tag
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit(ctx)
|
err = tx.Commit(ctx)
|
||||||
|
|
|
@ -44,10 +44,7 @@ func FetchSnippets(
|
||||||
|
|
||||||
if len(q.Tags) > 0 {
|
if len(q.Tags) > 0 {
|
||||||
// Get snippet IDs with this tag, then use that in the main query
|
// Get snippet IDs with this tag, then use that in the main query
|
||||||
type snippetIDRow struct {
|
snippetIDs, err := db.QueryScalar[int](ctx, tx,
|
||||||
SnippetID int `db:"snippet_id"`
|
|
||||||
}
|
|
||||||
iSnippetIDs, err := db.Query(ctx, tx, snippetIDRow{},
|
|
||||||
`
|
`
|
||||||
SELECT DISTINCT snippet_id
|
SELECT DISTINCT snippet_id
|
||||||
FROM
|
FROM
|
||||||
|
@ -63,14 +60,11 @@ func FetchSnippets(
|
||||||
}
|
}
|
||||||
|
|
||||||
// special early-out: no snippets found for these tags at all
|
// special early-out: no snippets found for these tags at all
|
||||||
if len(iSnippetIDs) == 0 {
|
if len(snippetIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
q.IDs = make([]int, len(iSnippetIDs))
|
q.IDs = snippetIDs
|
||||||
for i := range iSnippetIDs {
|
|
||||||
q.IDs[i] = iSnippetIDs[i].(*snippetIDRow).SnippetID
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var qb db.QueryBuilder
|
var qb db.QueryBuilder
|
||||||
|
@ -125,16 +119,14 @@ func FetchSnippets(
|
||||||
DiscordMessage *models.DiscordMessage `db:"discord_message"`
|
DiscordMessage *models.DiscordMessage `db:"discord_message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, tx, resultRow{}, qb.String(), qb.Args()...)
|
results, err := db.Query[resultRow](ctx, tx, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch threads")
|
return nil, oops.New(err, "failed to fetch threads")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]SnippetAndStuff, len(iresults)) // allocate extra space because why not
|
result := make([]SnippetAndStuff, len(results)) // allocate extra space because why not
|
||||||
snippetIDs := make([]int, len(iresults))
|
snippetIDs := make([]int, len(results))
|
||||||
for i, iresult := range iresults {
|
for i, row := range results {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
result[i] = SnippetAndStuff{
|
result[i] = SnippetAndStuff{
|
||||||
Snippet: row.Snippet,
|
Snippet: row.Snippet,
|
||||||
Owner: row.Owner,
|
Owner: row.Owner,
|
||||||
|
@ -150,7 +142,7 @@ func FetchSnippets(
|
||||||
SnippetID int `db:"snippet_tags.snippet_id"`
|
SnippetID int `db:"snippet_tags.snippet_id"`
|
||||||
Tag *models.Tag `db:"tags"`
|
Tag *models.Tag `db:"tags"`
|
||||||
}
|
}
|
||||||
iSnippetTags, err := db.Query(ctx, tx, snippetTagRow{},
|
snippetTags, err := db.Query[snippetTagRow](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -170,8 +162,7 @@ func FetchSnippets(
|
||||||
for i := range result {
|
for i := range result {
|
||||||
resultBySnippetId[result[i].Snippet.ID] = &result[i]
|
resultBySnippetId[result[i].Snippet.ID] = &result[i]
|
||||||
}
|
}
|
||||||
for _, iSnippetTag := range iSnippetTags {
|
for _, snippetTag := range snippetTags {
|
||||||
snippetTag := iSnippetTag.(*snippetTagRow)
|
|
||||||
item := resultBySnippetId[snippetTag.SnippetID]
|
item := resultBySnippetId[snippetTag.SnippetID]
|
||||||
item.Tags = append(item.Tags, snippetTag.Tag)
|
item.Tags = append(item.Tags, snippetTag.Tag)
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,18 +40,11 @@ func FetchTags(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) ([]*models.T
|
||||||
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
itags, err := db.Query(ctx, dbConn, models.Tag{}, qb.String(), qb.Args()...)
|
tags, err := db.Query[models.Tag](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch tags")
|
return nil, oops.New(err, "failed to fetch tags")
|
||||||
}
|
}
|
||||||
|
return tags, nil
|
||||||
res := make([]*models.Tag, len(itags))
|
|
||||||
for i, itag := range itags {
|
|
||||||
tag := itag.(*models.Tag)
|
|
||||||
res[i] = tag
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {
|
func FetchTag(ctx context.Context, dbConn db.ConnOrTx, q TagQuery) (*models.Tag, error) {
|
||||||
|
|
|
@ -145,15 +145,13 @@ func FetchThreads(
|
||||||
ForumLastReadTime *time.Time `db:"slri.lastread"`
|
ForumLastReadTime *time.Time `db:"slri.lastread"`
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
|
rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch threads")
|
return nil, oops.New(err, "failed to fetch threads")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]ThreadAndStuff, len(iresults))
|
result := make([]ThreadAndStuff, len(rows))
|
||||||
for i, iresult := range iresults {
|
for i, row := range rows {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
hasRead := false
|
hasRead := false
|
||||||
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
|
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.LastPost.PostDate) {
|
||||||
hasRead = true
|
hasRead = true
|
||||||
|
@ -263,7 +261,7 @@ func CountThreads(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...)
|
count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, oops.New(err, "failed to fetch count of threads")
|
return 0, oops.New(err, "failed to fetch count of threads")
|
||||||
}
|
}
|
||||||
|
@ -405,15 +403,13 @@ func FetchPosts(
|
||||||
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
qb.Add(`LIMIT $? OFFSET $?`, q.Limit, q.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
iresults, err := db.Query(ctx, dbConn, resultRow{}, qb.String(), qb.Args()...)
|
rows, err := db.Query[resultRow](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch posts")
|
return nil, oops.New(err, "failed to fetch posts")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]PostAndStuff, len(iresults))
|
result := make([]PostAndStuff, len(rows))
|
||||||
for i, iresult := range iresults {
|
for i, row := range rows {
|
||||||
row := *iresult.(*resultRow)
|
|
||||||
|
|
||||||
hasRead := false
|
hasRead := false
|
||||||
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
|
if currentUser != nil && currentUser.MarkedAllReadAt.After(row.Post.PostDate) {
|
||||||
hasRead = true
|
hasRead = true
|
||||||
|
@ -595,7 +591,7 @@ func CountPosts(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := db.QueryInt(ctx, dbConn, qb.String(), qb.Args()...)
|
count, err := db.QueryOneScalar[int](ctx, dbConn, qb.String(), qb.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, oops.New(err, "failed to count posts")
|
return 0, oops.New(err, "failed to count posts")
|
||||||
}
|
}
|
||||||
|
@ -608,12 +604,9 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
type postResult struct {
|
authorID, err := db.QueryOneScalar[*int](ctx, connOrTx,
|
||||||
AuthorID *int `db:"post.author_id"`
|
|
||||||
}
|
|
||||||
iresult, err := db.QueryOne(ctx, connOrTx, postResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT post.author_id
|
||||||
FROM
|
FROM
|
||||||
handmade_post AS post
|
handmade_post AS post
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -629,9 +622,8 @@ func UserCanEditPost(ctx context.Context, connOrTx db.ConnOrTx, user models.User
|
||||||
panic(oops.New(err, "failed to get author of post when checking permissions"))
|
panic(oops.New(err, "failed to get author of post when checking permissions"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result := iresult.(*postResult)
|
|
||||||
|
|
||||||
return result.AuthorID != nil && *result.AuthorID == user.ID
|
return authorID != nil && *authorID == user.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateNewPost(
|
func CreateNewPost(
|
||||||
|
@ -709,7 +701,7 @@ func DeletePost(
|
||||||
FirstPostID int `db:"first_id"`
|
FirstPostID int `db:"first_id"`
|
||||||
Deleted bool `db:"deleted"`
|
Deleted bool `db:"deleted"`
|
||||||
}
|
}
|
||||||
ti, err := db.QueryOne(ctx, tx, threadInfo{},
|
info, err := db.QueryOne[threadInfo](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -722,7 +714,6 @@ func DeletePost(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(oops.New(err, "failed to fetch thread info"))
|
panic(oops.New(err, "failed to fetch thread info"))
|
||||||
}
|
}
|
||||||
info := ti.(*threadInfo)
|
|
||||||
if info.Deleted {
|
if info.Deleted {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -848,12 +839,9 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
|
||||||
keys = append(keys, key)
|
keys = append(keys, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
type assetId struct {
|
assetIDs, err := db.QueryScalar[uuid.UUID](ctx, tx,
|
||||||
AssetID uuid.UUID `db:"id"`
|
|
||||||
}
|
|
||||||
assetResult, err := db.Query(ctx, tx, assetId{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT id
|
||||||
FROM handmade_asset
|
FROM handmade_asset
|
||||||
WHERE s3_key = ANY($1)
|
WHERE s3_key = ANY($1)
|
||||||
`,
|
`,
|
||||||
|
@ -865,8 +853,8 @@ func CreatePostVersion(ctx context.Context, tx pgx.Tx, postId int, unparsedConte
|
||||||
|
|
||||||
var values [][]interface{}
|
var values [][]interface{}
|
||||||
|
|
||||||
for _, asset := range assetResult {
|
for _, assetID := range assetIDs {
|
||||||
values = append(values, []interface{}{postId, asset.(*assetId).AssetID})
|
values = append(values, []interface{}{postId, assetID})
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
|
_, err = tx.CopyFrom(ctx, pgx.Identifier{"handmade_post_asset_usage"}, []string{"post_id", "asset_id"}, pgx.CopyFromRows(values))
|
||||||
|
@ -886,7 +874,7 @@ Returns errThreadEmpty if the thread contains no visible posts any more.
|
||||||
You should probably mark the thread as deleted in this case.
|
You should probably mark the thread as deleted in this case.
|
||||||
*/
|
*/
|
||||||
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
||||||
postsIter, err := db.Query(ctx, tx, models.Post{},
|
posts, err := db.Query[models.Post](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_post
|
FROM handmade_post
|
||||||
|
@ -901,9 +889,7 @@ func FixThreadPostIds(ctx context.Context, tx pgx.Tx, threadId int) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var firstPost, lastPost *models.Post
|
var firstPost, lastPost *models.Post
|
||||||
for _, ipost := range postsIter {
|
for _, post := range posts {
|
||||||
post := ipost.(*models.Post)
|
|
||||||
|
|
||||||
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
|
if firstPost == nil || post.PostDate.Before(firstPost.PostDate) {
|
||||||
firstPost = post
|
firstPost = post
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,12 +22,9 @@ type TwitchStreamer struct {
|
||||||
var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P<login>[^/]+)$`)
|
var twitchRegex = regexp.MustCompile(`twitch\.tv/(?P<login>[^/]+)$`)
|
||||||
|
|
||||||
func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) {
|
func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStreamer, error) {
|
||||||
type linkResult struct {
|
dbStreamers, err := db.Query[models.Link](ctx, dbConn,
|
||||||
Link models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
streamers, err := db.Query(ctx, dbConn, linkResult{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{link}
|
||||||
FROM
|
FROM
|
||||||
handmade_links AS link
|
handmade_links AS link
|
||||||
LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id
|
LEFT JOIN auth_user AS link_owner ON link_owner.id = link.user_id
|
||||||
|
@ -49,10 +46,8 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre
|
||||||
return nil, oops.New(err, "failed to fetch twitch links")
|
return nil, oops.New(err, "failed to fetch twitch links")
|
||||||
}
|
}
|
||||||
|
|
||||||
result := make([]TwitchStreamer, 0, len(streamers))
|
result := make([]TwitchStreamer, 0, len(dbStreamers))
|
||||||
for _, s := range streamers {
|
for _, dbStreamer := range dbStreamers {
|
||||||
dbStreamer := s.(*linkResult).Link
|
|
||||||
|
|
||||||
streamer := TwitchStreamer{
|
streamer := TwitchStreamer{
|
||||||
UserID: dbStreamer.UserID,
|
UserID: dbStreamer.UserID,
|
||||||
ProjectID: dbStreamer.ProjectID,
|
ProjectID: dbStreamer.ProjectID,
|
||||||
|
@ -81,7 +76,7 @@ func FetchTwitchStreamers(ctx context.Context, dbConn db.ConnOrTx) ([]TwitchStre
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) {
|
func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx, userId *int, projectId *int) ([]string, error) {
|
||||||
links, err := db.Query(ctx, dbConn, models.Link{},
|
links, err := db.Query[models.Link](ctx, dbConn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -100,8 +95,7 @@ func FetchTwitchLoginsForUserOrProject(ctx context.Context, dbConn db.ConnOrTx,
|
||||||
result := make([]string, 0, len(links))
|
result := make([]string, 0, len(links))
|
||||||
|
|
||||||
for _, l := range links {
|
for _, l := range links {
|
||||||
url := l.(*models.Link).URL
|
match := twitchRegex.FindStringSubmatch(l.URL)
|
||||||
match := twitchRegex.FindStringSubmatch(url)
|
|
||||||
if match != nil {
|
if match != nil {
|
||||||
login := strings.ToLower(match[twitchRegex.SubexpIndex("login")])
|
login := strings.ToLower(match[twitchRegex.SubexpIndex("login")])
|
||||||
result = append(result, login)
|
result = append(result, login)
|
||||||
|
|
|
@ -110,7 +110,7 @@ func (m PersonalProjects) Up(ctx context.Context, tx pgx.Tx) error {
|
||||||
// Port "jam snippets" to use a tag
|
// Port "jam snippets" to use a tag
|
||||||
//
|
//
|
||||||
|
|
||||||
jamTagId, err := db.QueryInt(ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
|
jamTagId, err := db.QueryOneScalar[int](ctx, tx, `INSERT INTO tags (text) VALUES ('wheeljam') RETURNING id`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to create jam tag")
|
return oops.New(err, "failed to create jam tag")
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,14 +44,10 @@ func (node *SubforumTreeNode) GetLineage() []*Subforum {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
type subforumRow struct {
|
subforums, err := db.Query[Subforum](ctx, conn,
|
||||||
Subforum Subforum `db:"sf"`
|
|
||||||
}
|
|
||||||
rowsSlice, err := db.Query(ctx, conn, subforumRow{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM handmade_subforum
|
||||||
handmade_subforum as sf
|
|
||||||
ORDER BY sort, id ASC
|
ORDER BY sort, id ASC
|
||||||
`,
|
`,
|
||||||
)
|
)
|
||||||
|
@ -59,10 +55,9 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
panic(oops.New(err, "failed to fetch subforum tree"))
|
panic(oops.New(err, "failed to fetch subforum tree"))
|
||||||
}
|
}
|
||||||
|
|
||||||
sfTreeMap := make(map[int]*SubforumTreeNode, len(rowsSlice))
|
sfTreeMap := make(map[int]*SubforumTreeNode, len(subforums))
|
||||||
for _, row := range rowsSlice {
|
for _, sf := range subforums {
|
||||||
sf := row.(*subforumRow).Subforum
|
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: *sf}
|
||||||
sfTreeMap[sf.ID] = &SubforumTreeNode{Subforum: sf}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range sfTreeMap {
|
for _, node := range sfTreeMap {
|
||||||
|
@ -71,9 +66,8 @@ func GetFullSubforumTree(ctx context.Context, conn *pgxpool.Pool) SubforumTree {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, row := range rowsSlice {
|
for _, cat := range subforums {
|
||||||
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
|
// NOTE(asaf): Doing this in a separate loop over rowsSlice to ensure that Children are in db order.
|
||||||
cat := row.(*subforumRow).Subforum
|
|
||||||
node := sfTreeMap[cat.ID]
|
node := sfTreeMap[cat.ID]
|
||||||
if node.Parent != nil {
|
if node.Parent != nil {
|
||||||
node.Parent.Children = append(node.Parent.Children, node)
|
node.Parent.Children = append(node.Parent.Children, node)
|
||||||
|
|
|
@ -440,7 +440,7 @@ func updateStreamStatusInDB(ctx context.Context, conn db.ConnOrTx, status *strea
|
||||||
inserted := false
|
inserted := false
|
||||||
if isStatusRelevant(status) {
|
if isStatusRelevant(status) {
|
||||||
log.Debug().Msg("Status relevant")
|
log.Debug().Msg("Status relevant")
|
||||||
_, err := db.QueryOne(ctx, conn, models.TwitchStream{},
|
_, err := db.QueryOne[models.TwitchStream](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM twitch_streams
|
FROM twitch_streams
|
||||||
|
|
|
@ -207,7 +207,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
userIds = append(userIds, u.User.ID)
|
userIds = append(userIds, u.User.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
userLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -222,8 +222,7 @@ func AdminApprovalQueue(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user links"))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ul := range userLinks {
|
for _, link := range userLinks {
|
||||||
link := ul.(*models.Link)
|
|
||||||
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
userData := unapprovedUsers[userIDToDataIdx[*link.UserID]]
|
||||||
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
userData.UserLinks = append(userData.UserLinks, templates.LinkToTemplate(link))
|
||||||
}
|
}
|
||||||
|
@ -257,12 +256,9 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
return RejectRequest(c, "User id can't be parsed")
|
return RejectRequest(c, "User id can't be parsed")
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
u, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE auth_user.id = $1
|
WHERE auth_user.id = $1
|
||||||
|
@ -276,7 +272,6 @@ func AdminApprovalQueueSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := u.(*userQuery).User
|
|
||||||
|
|
||||||
whatHappened := ""
|
whatHappened := ""
|
||||||
if action == ApprovalQueueActionApprove {
|
if action == ApprovalQueueActionApprove {
|
||||||
|
@ -337,7 +332,7 @@ type UnapprovedPost struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
it, err := db.Query(c.Context(), c.Conn, UnapprovedPost{},
|
posts, err := db.Query[UnapprovedPost](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -358,11 +353,7 @@ func fetchUnapprovedPosts(c *RequestContext) ([]*UnapprovedPost, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved posts")
|
return nil, oops.New(err, "failed to fetch unapproved posts")
|
||||||
}
|
}
|
||||||
var res []*UnapprovedPost
|
return posts, nil
|
||||||
for _, iresult := range it {
|
|
||||||
res = append(res, iresult.(*UnapprovedPost))
|
|
||||||
}
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UnapprovedProject struct {
|
type UnapprovedProject struct {
|
||||||
|
@ -372,12 +363,9 @@ type UnapprovedProject struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
type unapprovedUser struct {
|
ownerIDs, err := db.QueryScalar[int](c.Context(), c.Conn,
|
||||||
ID int `db:"id"`
|
|
||||||
}
|
|
||||||
it, err := db.Query(c.Context(), c.Conn, unapprovedUser{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT id
|
||||||
FROM
|
FROM
|
||||||
auth_user AS u
|
auth_user AS u
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -388,10 +376,6 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, oops.New(err, "failed to fetch unapproved users")
|
return nil, oops.New(err, "failed to fetch unapproved users")
|
||||||
}
|
}
|
||||||
ownerIDs := make([]int, 0, len(it))
|
|
||||||
for _, uid := range it {
|
|
||||||
ownerIDs = append(ownerIDs, uid.(*unapprovedUser).ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
projects, err := hmndata.FetchProjects(c.Context(), c.Conn, c.CurrentUser, hmndata.ProjectsQuery{
|
||||||
OwnerIDs: ownerIDs,
|
OwnerIDs: ownerIDs,
|
||||||
|
@ -406,7 +390,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
projectIDs = append(projectIDs, p.Project.ID)
|
projectIDs = append(projectIDs, p.Project.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
projectLinks, err := db.Query(c.Context(), c.Conn, models.Link{},
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -425,8 +409,7 @@ func fetchUnapprovedProjects(c *RequestContext) ([]UnapprovedProject, error) {
|
||||||
|
|
||||||
for idx, proj := range projects {
|
for idx, proj := range projects {
|
||||||
links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects.
|
links := make([]*models.Link, 0, 10) // NOTE(asaf): 10 should be enough for most projects.
|
||||||
for _, l := range projectLinks {
|
for _, link := range projectLinks {
|
||||||
link := l.(*models.Link)
|
|
||||||
if *link.ProjectID == proj.Project.ID {
|
if *link.ProjectID == proj.Project.ID {
|
||||||
links = append(links, link)
|
links = append(links, link)
|
||||||
}
|
}
|
||||||
|
@ -455,7 +438,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int)
|
||||||
ThreadID int `db:"thread.id"`
|
ThreadID int `db:"thread.id"`
|
||||||
PostID int `db:"post.id"`
|
PostID int `db:"post.id"`
|
||||||
}
|
}
|
||||||
it, err := db.Query(ctx, tx, toDelete{},
|
rows, err := db.Query[toDelete](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -471,8 +454,7 @@ func deleteAllPostsForUser(ctx context.Context, conn *pgxpool.Pool, userId int)
|
||||||
return oops.New(err, "failed to fetch posts to delete for user")
|
return oops.New(err, "failed to fetch posts to delete for user")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, iResult := range it {
|
for _, row := range rows {
|
||||||
row := iResult.(*toDelete)
|
|
||||||
hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID)
|
hmndata.DeletePost(ctx, tx, row.ThreadID, row.PostID)
|
||||||
}
|
}
|
||||||
err = tx.Commit(ctx)
|
err = tx.Commit(ctx)
|
||||||
|
@ -489,9 +471,9 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in
|
||||||
}
|
}
|
||||||
defer tx.Rollback(ctx)
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
toDelete, err := db.Query(ctx, tx, models.Project{},
|
projectIDsToDelete, err := db.QueryScalar[int](ctx, tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT project.id
|
||||||
FROM
|
FROM
|
||||||
handmade_project AS project
|
handmade_project AS project
|
||||||
JOIN handmade_user_projects AS up ON up.project_id = project.id
|
JOIN handmade_user_projects AS up ON up.project_id = project.id
|
||||||
|
@ -504,17 +486,12 @@ func deleteAllProjectsForUser(ctx context.Context, conn *pgxpool.Pool, userId in
|
||||||
return oops.New(err, "failed to fetch user's projects")
|
return oops.New(err, "failed to fetch user's projects")
|
||||||
}
|
}
|
||||||
|
|
||||||
var projectIds []int
|
if len(projectIDsToDelete) > 0 {
|
||||||
for _, p := range toDelete {
|
|
||||||
projectIds = append(projectIds, p.(*models.Project).ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(projectIds) > 0 {
|
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
`
|
`
|
||||||
DELETE FROM handmade_project WHERE id = ANY($1)
|
DELETE FROM handmade_project WHERE id = ANY($1)
|
||||||
`,
|
`,
|
||||||
projectIds,
|
projectIDsToDelete,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oops.New(err, "failed to delete user's projects")
|
return oops.New(err, "failed to delete user's projects")
|
||||||
|
|
|
@ -19,12 +19,9 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
requestedUsername := usernameArgs[0]
|
requestedUsername := usernameArgs[0]
|
||||||
found = true
|
found = true
|
||||||
c.Perf.StartBlock("SQL", "Fetch user")
|
c.Perf.StartBlock("SQL", "Fetch user")
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
auth_user
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
|
@ -43,7 +40,7 @@ func APICheckUsername(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", requestedUsername))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
canonicalUsername = userResult.(*userQuery).User.Username
|
canonicalUsername = user.Username
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,13 +75,11 @@ func Login(c *RequestContext) ResponseData {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE LOWER(username) = LOWER($1)
|
WHERE LOWER(username) = LOWER($1)
|
||||||
`,
|
`,
|
||||||
|
@ -94,7 +92,6 @@ func Login(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := &userRow.(*userQuery).User
|
|
||||||
|
|
||||||
success, err := tryLogin(c, user, password)
|
success, err := tryLogin(c, user, password)
|
||||||
|
|
||||||
|
@ -174,7 +171,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
|
c.Perf.StartBlock("SQL", "Check for existing usernames and emails")
|
||||||
userAlreadyExists := true
|
userAlreadyExists := true
|
||||||
_, err := db.QueryInt(c.Context(), c.Conn,
|
_, err := db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -195,7 +192,7 @@ func RegisterNewUserSubmit(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
emailAlreadyExists := true
|
emailAlreadyExists := true
|
||||||
_, err = db.QueryInt(c.Context(), c.Conn,
|
_, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -454,17 +451,16 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return RejectRequest(c, "You must provide a username and an email address.")
|
return RejectRequest(c, "You must provide a username and an email address.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *models.User
|
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching user")
|
c.Perf.StartBlock("SQL", "Fetching user")
|
||||||
type userQuery struct {
|
type userQuery struct {
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
}
|
}
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
auth_user
|
||||||
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE
|
WHERE
|
||||||
LOWER(username) = LOWER($1)
|
LOWER(username) = LOWER($1)
|
||||||
AND LOWER(email) = LOWER($2)
|
AND LOWER(email) = LOWER($2)
|
||||||
|
@ -478,13 +474,10 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if userRow != nil {
|
|
||||||
user = &userRow.(*userQuery).User
|
|
||||||
}
|
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
c.Perf.StartBlock("SQL", "Fetching existing token")
|
c.Perf.StartBlock("SQL", "Fetching existing token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
resetToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_onetimetoken
|
FROM handmade_onetimetoken
|
||||||
|
@ -501,10 +494,6 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch onetimetoken for user"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var resetToken *models.OneTimeToken
|
|
||||||
if tokenRow != nil {
|
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
|
||||||
}
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
if resetToken != nil {
|
if resetToken != nil {
|
||||||
|
@ -527,7 +516,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
if resetToken == nil {
|
if resetToken == nil {
|
||||||
c.Perf.StartBlock("SQL", "Creating new token")
|
c.Perf.StartBlock("SQL", "Creating new token")
|
||||||
tokenRow, err := db.QueryOne(c.Context(), c.Conn, models.OneTimeToken{},
|
newToken, err := db.QueryOne[models.OneTimeToken](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
INSERT INTO handmade_onetimetoken (token_type, created, expires, token_content, owner_id)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
@ -543,7 +532,7 @@ func RequestPasswordResetSubmit(c *RequestContext) ResponseData {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create onetimetoken"))
|
||||||
}
|
}
|
||||||
resetToken = tokenRow.(*models.OneTimeToken)
|
resetToken = newToken
|
||||||
|
|
||||||
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
err = email.SendPasswordReset(user.Email, user.BestName(), user.Username, resetToken.Content, resetToken.Expires, c.Perf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -787,7 +776,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
User models.User `db:"auth_user"`
|
User models.User `db:"auth_user"`
|
||||||
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
OneTimeToken *models.OneTimeToken `db:"onetimetoken"`
|
||||||
}
|
}
|
||||||
row, err := db.QueryOne(c.Context(), c.Conn, userAndTokenQuery{},
|
data, err := db.QueryOne[userAndTokenQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM auth_user
|
FROM auth_user
|
||||||
|
@ -807,8 +796,7 @@ func validateUsernameAndToken(c *RequestContext, username string, token string,
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if row != nil {
|
if data != nil {
|
||||||
data := row.(*userAndTokenQuery)
|
|
||||||
result.User = &data.User
|
result.User = &data.User
|
||||||
result.OneTimeToken = data.OneTimeToken
|
result.OneTimeToken = data.OneTimeToken
|
||||||
if result.OneTimeToken != nil {
|
if result.OneTimeToken != nil {
|
||||||
|
|
|
@ -558,7 +558,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
||||||
res.ThreadID = threadId
|
res.ThreadID = threadId
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Verify that the thread exists")
|
c.Perf.StartBlock("SQL", "Verify that the thread exists")
|
||||||
threadExists, err := db.QueryBool(c.Context(), c.Conn,
|
threadExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*) > 0
|
SELECT COUNT(*) > 0
|
||||||
FROM handmade_thread
|
FROM handmade_thread
|
||||||
|
@ -586,7 +586,7 @@ func getCommonBlogData(c *RequestContext) (commonBlogData, bool) {
|
||||||
res.PostID = postId
|
res.PostID = postId
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Verify that the post exists")
|
c.Perf.StartBlock("SQL", "Verify that the post exists")
|
||||||
postExists, err := db.QueryBool(c.Context(), c.Conn,
|
postExists, err := db.QueryOneScalar[bool](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*) > 0
|
SELECT COUNT(*) > 0
|
||||||
FROM handmade_post
|
FROM handmade_post
|
||||||
|
|
|
@ -104,7 +104,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c.Context())
|
defer tx.Rollback(c.Context())
|
||||||
|
|
||||||
iDiscordUser, err := db.QueryOne(c.Context(), tx, models.DiscordUser{},
|
discordUser, err := db.QueryOne[models.DiscordUser](c.Context(), tx,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discorduser
|
FROM handmade_discorduser
|
||||||
|
@ -119,7 +119,6 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get Discord user for unlink"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
discordUser := iDiscordUser.(*models.DiscordUser)
|
|
||||||
|
|
||||||
_, err = tx.Exec(c.Context(),
|
_, err = tx.Exec(c.Context(),
|
||||||
`
|
`
|
||||||
|
@ -146,7 +145,7 @@ func DiscordUnlink(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||||
`SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`,
|
`SELECT $columns FROM handmade_discorduser WHERE hmn_user_id = $1`,
|
||||||
c.CurrentUser.ID,
|
c.CurrentUser.ID,
|
||||||
)
|
)
|
||||||
|
@ -157,14 +156,10 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get discord user"))
|
||||||
}
|
}
|
||||||
duser := iduser.(*models.DiscordUser)
|
|
||||||
|
|
||||||
type messageIdQuery struct {
|
msgIDs, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||||
MessageID string `db:"msg.id"`
|
|
||||||
}
|
|
||||||
iMsgIDs, err := db.Query(c.Context(), c.Conn, messageIdQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT msg.id
|
||||||
FROM
|
FROM
|
||||||
handmade_discordmessage AS msg
|
handmade_discordmessage AS msg
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -178,10 +173,6 @@ func DiscordShowcaseBacklog(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, err)
|
return c.ErrorResponse(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var msgIDs []string
|
|
||||||
for _, imsgId := range iMsgIDs {
|
|
||||||
msgIDs = append(msgIDs, imsgId.(*messageIdQuery).MessageID)
|
|
||||||
}
|
|
||||||
for _, msgID := range msgIDs {
|
for _, msgID := range msgIDs {
|
||||||
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
|
interned, err := discord.FetchInternedMessage(c.Context(), c.Conn, msgID)
|
||||||
if err != nil && !errors.Is(err, db.NotFound) {
|
if err != nil && !errors.Is(err, db.NotFound) {
|
||||||
|
|
|
@ -936,7 +936,7 @@ func addForumUrlsToPost(urlContext *hmnurl.UrlContext, p *templates.Post, subfor
|
||||||
// Takes a template post and adds information about how many posts the user has made
|
// Takes a template post and adds information about how many posts the user has made
|
||||||
// on the site.
|
// on the site.
|
||||||
func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) {
|
func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.Post) {
|
||||||
numPosts, err := db.QueryInt(ctx, conn,
|
numPosts, err := db.QueryOneScalar[int](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
@ -956,7 +956,7 @@ func addAuthorCountsToPost(ctx context.Context, conn db.ConnOrTx, p *templates.P
|
||||||
p.AuthorNumPosts = numPosts
|
p.AuthorNumPosts = numPosts
|
||||||
}
|
}
|
||||||
|
|
||||||
numProjects, err := db.QueryInt(ctx, conn,
|
numProjects, err := db.QueryOneScalar[int](ctx, conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
|
|
@ -89,8 +89,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
|
||||||
img.Seek(0, io.SeekStart)
|
img.Seek(0, io.SeekStart)
|
||||||
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
|
io.Copy(hasher, img) // NOTE(asaf): Writing to hash.Hash never returns an error according to the docs
|
||||||
sha1sum := hasher.Sum(nil)
|
sha1sum := hasher.Sum(nil)
|
||||||
// TODO(db): Should use insert helper
|
imageFile, err := db.QueryOne[models.ImageFile](c.Context(), dbConn,
|
||||||
imageFile, err := db.QueryOne(c.Context(), dbConn, models.ImageFile{},
|
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height)
|
INSERT INTO handmade_imagefile (file, size, sha1sum, protected, width, height)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
@ -105,7 +104,7 @@ func SaveImageFile(c *RequestContext, dbConn db.ConnOrTx, fileFieldName string,
|
||||||
}
|
}
|
||||||
|
|
||||||
return SaveImageFileResult{
|
return SaveImageFileResult{
|
||||||
ImageFile: imageFile.(*models.ImageFile),
|
ImageFile: imageFile,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,10 +30,9 @@ func ParseLinks(text string) []ParsedLink {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func LinksToText(links []interface{}) string {
|
func LinksToText(links []*models.Link) string {
|
||||||
linksText := ""
|
linksText := ""
|
||||||
for _, l := range links {
|
for _, link := range links {
|
||||||
link := l.(*models.Link)
|
|
||||||
linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name)
|
linksText += fmt.Sprintf("%s %s\n", link.URL, link.Name)
|
||||||
}
|
}
|
||||||
return linksText
|
return linksText
|
||||||
|
|
|
@ -532,11 +532,12 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
Podcast models.Podcast `db:"podcast"`
|
Podcast models.Podcast `db:"podcast"`
|
||||||
ImageFilename string `db:"imagefile.file"`
|
ImageFilename string `db:"imagefile.file"`
|
||||||
}
|
}
|
||||||
podcastQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastQuery{},
|
podcastQueryResult, err := db.QueryOne[podcastQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcast AS podcast
|
FROM
|
||||||
LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id
|
handmade_podcast AS podcast
|
||||||
|
LEFT JOIN handmade_imagefile AS imagefile ON imagefile.id = podcast.image_id
|
||||||
WHERE podcast.project_id = $1
|
WHERE podcast.project_id = $1
|
||||||
`,
|
`,
|
||||||
projectId,
|
projectId,
|
||||||
|
@ -549,18 +550,15 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
return result, oops.New(err, "failed to fetch podcast")
|
return result, oops.New(err, "failed to fetch podcast")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
podcast := podcastQueryResult.(*podcastQuery).Podcast
|
podcast := podcastQueryResult.Podcast
|
||||||
podcastImageFilename := podcastQueryResult.(*podcastQuery).ImageFilename
|
podcastImageFilename := podcastQueryResult.ImageFilename
|
||||||
result.Podcast = &podcast
|
result.Podcast = &podcast
|
||||||
result.ImageFile = podcastImageFilename
|
result.ImageFile = podcastImageFilename
|
||||||
|
|
||||||
if fetchEpisodes {
|
if fetchEpisodes {
|
||||||
type podcastEpisodeQuery struct {
|
|
||||||
Episode models.PodcastEpisode `db:"episode"`
|
|
||||||
}
|
|
||||||
if episodeGUID == "" {
|
if episodeGUID == "" {
|
||||||
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
|
c.Perf.StartBlock("SQL", "Fetch podcast episodes")
|
||||||
podcastEpisodeQueryResult, err := db.Query(c.Context(), c.Conn, podcastEpisodeQuery{},
|
episodes, err := db.Query[models.PodcastEpisode](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcastepisode AS episode
|
FROM handmade_podcastepisode AS episode
|
||||||
|
@ -573,16 +571,14 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, oops.New(err, "failed to fetch podcast episodes")
|
return result, oops.New(err, "failed to fetch podcast episodes")
|
||||||
}
|
}
|
||||||
for _, episodeRow := range podcastEpisodeQueryResult {
|
result.Episodes = episodes
|
||||||
result.Episodes = append(result.Episodes, &episodeRow.(*podcastEpisodeQuery).Episode)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
guid, err := uuid.Parse(episodeGUID)
|
guid, err := uuid.Parse(episodeGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
c.Perf.StartBlock("SQL", "Fetch podcast episode")
|
c.Perf.StartBlock("SQL", "Fetch podcast episode")
|
||||||
podcastEpisodeQueryResult, err := db.QueryOne(c.Context(), c.Conn, podcastEpisodeQuery{},
|
episode, err := db.QueryOne[models.PodcastEpisode](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_podcastepisode AS episode
|
FROM handmade_podcastepisode AS episode
|
||||||
|
@ -599,8 +595,7 @@ func FetchPodcast(c *RequestContext, projectId int, fetchEpisodes bool, episodeG
|
||||||
return result, oops.New(err, "failed to fetch podcast episode")
|
return result, oops.New(err, "failed to fetch podcast episode")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
episode := podcastEpisodeQueryResult.(*podcastEpisodeQuery).Episode
|
result.Episodes = append(result.Episodes, episode)
|
||||||
result.Episodes = append(result.Episodes, &episode)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,12 +187,9 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching screenshots")
|
c.Perf.StartBlock("SQL", "Fetching screenshots")
|
||||||
type screenshotQuery struct {
|
screenshotFilenames, err := db.QueryScalar[string](c.Context(), c.Conn,
|
||||||
Filename string `db:"screenshot.file"`
|
|
||||||
}
|
|
||||||
screenshotQueryResult, err := db.Query(c.Context(), c.Conn, screenshotQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT screenshot.file
|
||||||
FROM
|
FROM
|
||||||
handmade_imagefile AS screenshot
|
handmade_imagefile AS screenshot
|
||||||
INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id
|
INNER JOIN handmade_project_screenshots ON screenshot.id = handmade_project_screenshots.imagefile_id
|
||||||
|
@ -207,10 +204,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
c.Perf.EndBlock()
|
c.Perf.EndBlock()
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||||
type projectLinkQuery struct {
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
Link models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
projectLinkResult, err := db.Query(c.Context(), c.Conn, projectLinkQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -237,7 +231,7 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
Thread models.Thread `db:"thread"`
|
Thread models.Thread `db:"thread"`
|
||||||
Author models.User `db:"author"`
|
Author models.User `db:"author"`
|
||||||
}
|
}
|
||||||
postQueryResult, err := db.Query(c.Context(), c.Conn, postQuery{},
|
posts, err := db.Query[postQuery](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -318,21 +312,21 @@ func ProjectHomepage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, screenshot := range screenshotQueryResult {
|
for _, screenshotFilename := range screenshotFilenames {
|
||||||
templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshot.(*screenshotQuery).Filename))
|
templateData.Screenshots = append(templateData.Screenshots, hmnurl.BuildUserFile(screenshotFilename))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, link := range projectLinkResult {
|
for _, link := range projectLinks {
|
||||||
templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(&link.(*projectLinkQuery).Link))
|
templateData.ProjectLinks = append(templateData.ProjectLinks, templates.LinkToTemplate(link))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, post := range postQueryResult {
|
for _, post := range posts {
|
||||||
templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem(
|
templateData.RecentActivity = append(templateData.RecentActivity, PostToTimelineItem(
|
||||||
c.UrlContext,
|
c.UrlContext,
|
||||||
lineageBuilder,
|
lineageBuilder,
|
||||||
&post.(*postQuery).Post,
|
&post.Post,
|
||||||
&post.(*postQuery).Thread,
|
&post.Thread,
|
||||||
&post.(*postQuery).Author,
|
&post.Author,
|
||||||
c.Theme,
|
c.Theme,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
@ -498,7 +492,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetching project links")
|
c.Perf.StartBlock("SQL", "Fetching project links")
|
||||||
projectLinkResult, err := db.Query(c.Context(), c.Conn, models.Link{},
|
projectLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -525,7 +519,7 @@ func ProjectEdit(c *RequestContext) ResponseData {
|
||||||
c.Theme,
|
c.Theme,
|
||||||
)
|
)
|
||||||
|
|
||||||
projectSettings.LinksText = LinksToText(projectLinkResult)
|
projectSettings.LinksText = LinksToText(projectLinks)
|
||||||
|
|
||||||
var res ResponseData
|
var res ResponseData
|
||||||
res.MustWriteTemplate("project_edit.html", ProjectEditData{
|
res.MustWriteTemplate("project_edit.html", ProjectEditData{
|
||||||
|
@ -822,14 +816,12 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
owners, err := db.Query[models.User](ctx, tx,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
ownerRows, err := db.Query(ctx, tx, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
auth_user
|
||||||
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE LOWER(username) = ANY ($1)
|
WHERE LOWER(username) = ANY ($1)
|
||||||
`,
|
`,
|
||||||
payload.OwnerUsernames,
|
payload.OwnerUsernames,
|
||||||
|
@ -849,7 +841,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
return oops.New(err, "Failed to delete project owners")
|
return oops.New(err, "Failed to delete project owners")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ownerRow := range ownerRows {
|
for _, owner := range owners {
|
||||||
_, err = tx.Exec(ctx,
|
_, err = tx.Exec(ctx,
|
||||||
`
|
`
|
||||||
INSERT INTO handmade_user_projects
|
INSERT INTO handmade_user_projects
|
||||||
|
@ -857,7 +849,7 @@ func updateProject(ctx context.Context, tx pgx.Tx, user *models.User, payload *P
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2)
|
($1, $2)
|
||||||
`,
|
`,
|
||||||
ownerRow.(*userQuery).User.ID,
|
owner.ID,
|
||||||
payload.ProjectID,
|
payload.ProjectID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -548,13 +548,11 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userRow, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM auth_user
|
FROM
|
||||||
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
`,
|
`,
|
||||||
|
@ -568,7 +566,6 @@ func getCurrentUserAndSession(c *RequestContext, sessionId string) (*models.User
|
||||||
return nil, nil, oops.New(err, "failed to get user for session")
|
return nil, nil, oops.New(err, "failed to get user for session")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user := &userRow.(*userQuery).User
|
|
||||||
|
|
||||||
return user, session, nil
|
return user, session, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ func TwitchEventSubCallback(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TwitchDebugPage(c *RequestContext) ResponseData {
|
func TwitchDebugPage(c *RequestContext) ResponseData {
|
||||||
streams, err := db.Query(c.Context(), c.Conn, models.TwitchStream{},
|
streams, err := db.Query[models.TwitchStream](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -83,8 +83,7 @@ func TwitchDebugPage(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
html := ""
|
html := ""
|
||||||
for _, stream := range streams {
|
for _, s := range streams {
|
||||||
s := stream.(*models.TwitchStream)
|
|
||||||
html += fmt.Sprintf(`<a href="https://twitch.tv/%s">%s</a>%s<br />`, s.Login, s.Login, s.Title)
|
html += fmt.Sprintf(`<a href="https://twitch.tv/%s">%s</a>%s<br />`, s.Login, s.Login, s.Title)
|
||||||
}
|
}
|
||||||
var res ResponseData
|
var res ResponseData
|
||||||
|
|
|
@ -53,12 +53,9 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
profileUser = c.CurrentUser
|
profileUser = c.CurrentUser
|
||||||
} else {
|
} else {
|
||||||
c.Perf.StartBlock("SQL", "Fetch user")
|
c.Perf.StartBlock("SQL", "Fetch user")
|
||||||
type userQuery struct {
|
user, err := db.QueryOne[models.User](c.Context(), c.Conn,
|
||||||
User models.User `db:"auth_user"`
|
|
||||||
}
|
|
||||||
userResult, err := db.QueryOne(c.Context(), c.Conn, userQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns{auth_user}
|
||||||
FROM
|
FROM
|
||||||
auth_user
|
auth_user
|
||||||
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
LEFT JOIN handmade_asset AS auth_user_avatar ON auth_user_avatar.id = auth_user.avatar_asset_id
|
||||||
|
@ -75,7 +72,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user: %s", username))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
profileUser = &userResult.(*userQuery).User
|
profileUser = user
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -87,10 +84,7 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Perf.StartBlock("SQL", "Fetch user links")
|
c.Perf.StartBlock("SQL", "Fetch user links")
|
||||||
type userLinkQuery struct {
|
userLinks, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
UserLink models.Link `db:"link"`
|
|
||||||
}
|
|
||||||
userLinksSlice, err := db.Query(c.Context(), c.Conn, userLinkQuery{},
|
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM
|
FROM
|
||||||
|
@ -104,9 +98,9 @@ func UserProfile(c *RequestContext) ResponseData {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch links for user: %s", username))
|
||||||
}
|
}
|
||||||
profileUserLinks := make([]templates.Link, 0, len(userLinksSlice))
|
profileUserLinks := make([]templates.Link, 0, len(userLinks))
|
||||||
for _, l := range userLinksSlice {
|
for _, l := range userLinks {
|
||||||
profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(&l.(*userLinkQuery).UserLink))
|
profileUserLinks = append(profileUserLinks, templates.LinkToTemplate(l))
|
||||||
}
|
}
|
||||||
c.Perf.EndBlock()
|
c.Perf.EndBlock()
|
||||||
|
|
||||||
|
@ -231,7 +225,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
DiscordShowcaseBacklogUrl string
|
DiscordShowcaseBacklogUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
links, err := db.Query(c.Context(), c.Conn, models.Link{},
|
links, err := db.Query[models.Link](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_links
|
FROM handmade_links
|
||||||
|
@ -248,7 +242,7 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
|
|
||||||
var tduser *templates.DiscordUser
|
var tduser *templates.DiscordUser
|
||||||
var numUnsavedMessages int
|
var numUnsavedMessages int
|
||||||
iduser, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
duser, err := db.QueryOne[models.DiscordUser](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT $columns
|
SELECT $columns
|
||||||
FROM handmade_discorduser
|
FROM handmade_discorduser
|
||||||
|
@ -261,11 +255,10 @@ func UserSettings(c *RequestContext) ResponseData {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))
|
return c.ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch user's Discord account"))
|
||||||
} else {
|
} else {
|
||||||
duser := iduser.(*models.DiscordUser)
|
|
||||||
tmp := templates.DiscordUserToTemplate(duser)
|
tmp := templates.DiscordUserToTemplate(duser)
|
||||||
tduser = &tmp
|
tduser = &tmp
|
||||||
|
|
||||||
numUnsavedMessages, err = db.QueryInt(c.Context(), c.Conn,
|
numUnsavedMessages, err = db.QueryOneScalar[int](c.Context(), c.Conn,
|
||||||
`
|
`
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM
|
FROM
|
||||||
|
|
Loading…
Reference in New Issue