Add Discord account linking
This commit is contained in:
parent
38a1188be7
commit
d92bf9a9b8
|
@ -67,6 +67,11 @@ type DiscordConfig struct {
|
||||||
BotToken string
|
BotToken string
|
||||||
BotUserID string
|
BotUserID string
|
||||||
|
|
||||||
|
OAuthClientID string
|
||||||
|
OAuthClientSecret string
|
||||||
|
|
||||||
|
GuildID string
|
||||||
|
MemberRoleID string
|
||||||
ShowcaseChannelID string
|
ShowcaseChannelID string
|
||||||
LibraryChannelID string
|
LibraryChannelID string
|
||||||
}
|
}
|
||||||
|
|
|
@ -217,10 +217,12 @@ func MessageFromMap(m interface{}) Message {
|
||||||
|
|
||||||
// https://discord.com/developers/docs/resources/user#user-object
|
// https://discord.com/developers/docs/resources/user#user-object
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Discriminator string `json:"discriminator"`
|
Discriminator string `json:"discriminator"`
|
||||||
IsBot bool `json:"bot"`
|
Avatar *string `json:"avatar"`
|
||||||
|
IsBot bool `json:"bot"`
|
||||||
|
Locale string `json:"locale"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func UserFromMap(m interface{}) User {
|
func UserFromMap(m interface{}) User {
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/src/config"
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
"git.handmade.network/hmn/hmn/src/logging"
|
"git.handmade.network/hmn/hmn/src/logging"
|
||||||
|
@ -26,13 +28,17 @@ var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersi
|
||||||
|
|
||||||
var httpClient = &http.Client{}
|
var httpClient = &http.Client{}
|
||||||
|
|
||||||
|
func buildUrl(path string) string {
|
||||||
|
return fmt.Sprintf("%s%s", BaseUrl, path)
|
||||||
|
}
|
||||||
|
|
||||||
func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request {
|
func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request {
|
||||||
var bodyReader io.Reader
|
var bodyReader io.Reader
|
||||||
if body != nil {
|
if body != nil {
|
||||||
bodyReader = bytes.NewBuffer(body)
|
bodyReader = bytes.NewBuffer(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader)
|
req, err := http.NewRequestWithContext(ctx, method, buildUrl(path), bodyReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -171,6 +177,121 @@ func CreateDM(ctx context.Context, recipientID string) (*Channel, error) {
|
||||||
return &channel, nil
|
return &channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OAuthCodeExchangeResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExchangeOAuthCode(ctx context.Context, code, redirectURI string) (*OAuthCodeExchangeResponse, error) {
|
||||||
|
const name = "OAuth Code Exchange"
|
||||||
|
|
||||||
|
body := make(url.Values)
|
||||||
|
body.Set("client_id", config.Config.Discord.OAuthClientID)
|
||||||
|
body.Set("client_secret", config.Config.Discord.OAuthClientSecret)
|
||||||
|
body.Set("grant_type", "authorization_code")
|
||||||
|
body.Set("code", code)
|
||||||
|
body.Set("redirect_uri", redirectURI)
|
||||||
|
bodyStr := body.Encode()
|
||||||
|
|
||||||
|
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||||
|
req, err := http.NewRequestWithContext(
|
||||||
|
ctx,
|
||||||
|
http.MethodPost,
|
||||||
|
"https://discord.com/api/oauth2/token",
|
||||||
|
strings.NewReader(bodyStr),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
req.Header.Add("User-Agent", UserAgent)
|
||||||
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
return req
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
logErrorResponse(ctx, name, res, "")
|
||||||
|
return nil, oops.New(nil, "received error from Discord")
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResponse OAuthCodeExchangeResponse
|
||||||
|
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oops.New(err, "failed to unmarshal Discord OAuth token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCurrentUserAsOAuth(ctx context.Context, accessToken string) (*User, error) {
|
||||||
|
const name = "Get Current User"
|
||||||
|
|
||||||
|
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildUrl("/users/@me"), nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
req.Header.Add("User-Agent", UserAgent)
|
||||||
|
|
||||||
|
return req
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
logErrorResponse(ctx, name, res, "")
|
||||||
|
return nil, oops.New(nil, "received error from Discord")
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
err = json.Unmarshal(bodyBytes, &user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oops.New(err, "failed to unmarshal Discord user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddGuildMemberRole(ctx context.Context, userID, roleID string) error {
|
||||||
|
const name = "Delete Message"
|
||||||
|
|
||||||
|
path := fmt.Sprintf("/guilds/%s/members/%s/roles/%s", config.Config.Discord.GuildID, userID, roleID)
|
||||||
|
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||||
|
return makeRequest(ctx, http.MethodPut, path, nil)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode != http.StatusNoContent {
|
||||||
|
logErrorResponse(ctx, name, res, "")
|
||||||
|
return oops.New(nil, "got unexpected status code when adding role")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) {
|
func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) {
|
||||||
dump, err := httputil.DumpResponse(res, true)
|
dump, err := httputil.DumpResponse(res, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -545,6 +545,22 @@ func BuildLibraryResource(projectSlug string, resourceId int) string {
|
||||||
return ProjectUrl(builder.String(), nil, projectSlug)
|
return ProjectUrl(builder.String(), nil, projectSlug)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Discord OAuth
|
||||||
|
*/
|
||||||
|
|
||||||
|
var RegexDiscordTest = regexp.MustCompile("^/discord$")
|
||||||
|
|
||||||
|
func BuildDiscordTest() string {
|
||||||
|
return Url("/discord", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var RegexDiscordOAuthCallback = regexp.MustCompile("^/_discord_callback$")
|
||||||
|
|
||||||
|
func BuildDiscordOAuthCallback() string {
|
||||||
|
return Url("/_discord_callback", nil)
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Assets
|
* Assets
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.handmade.network/hmn/hmn/src/migration/types"
|
||||||
|
"git.handmade.network/hmn/hmn/src/oops"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registerMigration(DiscordData{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiscordData struct{}
|
||||||
|
|
||||||
|
func (m DiscordData) Version() types.MigrationVersion {
|
||||||
|
return types.MigrationVersion(time.Date(2021, 8, 16, 2, 34, 40, 0, time.UTC))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DiscordData) Name() string {
|
||||||
|
return "DiscordData"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DiscordData) Description() string {
|
||||||
|
return "Clean up Discord data models"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DiscordData) Up(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(ctx, `
|
||||||
|
ALTER TABLE handmade_discord RENAME TO handmade_discorduser;
|
||||||
|
ALTER TABLE handmade_discorduser
|
||||||
|
ALTER username SET NOT NULL,
|
||||||
|
ALTER discriminator SET NOT NULL,
|
||||||
|
ALTER access_token SET NOT NULL,
|
||||||
|
ALTER refresh_token SET NOT NULL,
|
||||||
|
ALTER locale SET NOT NULL,
|
||||||
|
ALTER userid SET NOT NULL,
|
||||||
|
ALTER expiry SET NOT NULL;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return oops.New(err, "failed to fix up discord table")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, `
|
||||||
|
ALTER TABLE handmade_discordmessagecontent
|
||||||
|
DROP CONSTRAINT handmade_discordmess_discord_id_1acc147f_fk_handmade_,
|
||||||
|
DROP CONSTRAINT handmade_discordmess_message_id_4dfde67d_fk_handmade_,
|
||||||
|
ADD FOREIGN KEY (discord_id) REFERENCES handmade_discorduser (id) ON DELETE CASCADE,
|
||||||
|
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
|
||||||
|
|
||||||
|
ALTER TABLE handmade_discordmessageattachment
|
||||||
|
DROP CONSTRAINT handmade_discordmess_asset_id_c64a3c31_fk_handmade_,
|
||||||
|
DROP CONSTRAINT handmade_discordmess_message_id_d39da9b3_fk_handmade_,
|
||||||
|
ADD FOREIGN KEY (asset_id) REFERENCES handmade_asset (id) ON DELETE CASCADE,
|
||||||
|
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
|
||||||
|
|
||||||
|
ALTER TABLE handmade_discordmessageembed
|
||||||
|
DROP CONSTRAINT handmade_discordmess_image_id_9b04bb5f_fk_handmade_,
|
||||||
|
DROP CONSTRAINT handmade_discordmess_message_id_04f15ce6_fk_handmade_,
|
||||||
|
DROP CONSTRAINT handmade_discordmess_video_id_1c41289f_fk_handmade_,
|
||||||
|
ADD FOREIGN KEY (image_id) REFERENCES handmade_asset (id) ON DELETE SET NULL,
|
||||||
|
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE,
|
||||||
|
ADD FOREIGN KEY (video_id) REFERENCES handmade_asset (id) ON DELETE SET NULL;
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return oops.New(err, "failed to fix constraints")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DiscordData) Down(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
panic("Implement me")
|
||||||
|
}
|
|
@ -2,8 +2,65 @@ package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type DiscordUser struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Username string `db:"username"`
|
||||||
|
Discriminator string `db:"discriminator"`
|
||||||
|
AccessToken string `db:"access_token"`
|
||||||
|
RefreshToken string `db:"refresh_token"`
|
||||||
|
Avatar *string `db:"avatar"`
|
||||||
|
Locale string `db:"locale"`
|
||||||
|
UserID string `db:"userid"`
|
||||||
|
Expiry time.Time `db:"expiry"`
|
||||||
|
HMNUserId int `db:"hmn_user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Logs the existence of a Discord message and what we've done with it.
|
||||||
|
Created unconditionally for all users, regardless of link status.
|
||||||
|
Therefore, it must not contain any actual content.
|
||||||
|
*/
|
||||||
|
type DiscordMessage struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
ChannelID string `db:"channel_id"`
|
||||||
|
GuildID *string `db:"guild_id"`
|
||||||
|
Url string `db:"url"`
|
||||||
|
UserID string `db:"user_id"`
|
||||||
|
SentAt time.Time `db:"sent_at"`
|
||||||
|
SnippetCreated bool `db:"snippet_created"`
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Stores the content of a Discord message for users with a linked
|
||||||
|
Discord account. Always created for users with a linked Discord
|
||||||
|
account, regardless of whether we create snippets or not.
|
||||||
|
*/
|
||||||
|
type DiscordMessageContent struct {
|
||||||
|
MessageID string `db:"message_id"`
|
||||||
|
LastContent string `db:"last_content"`
|
||||||
|
DiscordID int `db:"discord_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiscordMessageAttachment struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
AssetID uuid.UUID `db:"asset_id"`
|
||||||
|
MessageID string `db:"message_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiscordMessageEmbed struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Title *string `db:"title"`
|
||||||
|
Description *string `db:"description"`
|
||||||
|
URL *string `db:"url"`
|
||||||
|
ImageID *uuid.UUID `db:"image_id"`
|
||||||
|
MessageID string `db:"message_id"`
|
||||||
|
VideoID *uuid.UUID `db:"video_id"`
|
||||||
|
}
|
||||||
|
|
||||||
type DiscordSession struct {
|
type DiscordSession struct {
|
||||||
ID string `db:"session_id"`
|
ID string `db:"session_id"`
|
||||||
SequenceNumber int `db:"sequence_number"`
|
SequenceNumber int `db:"sequence_number"`
|
||||||
|
@ -15,13 +72,3 @@ type DiscordOutgoingMessage struct {
|
||||||
PayloadJSON string `db:"payload_json"`
|
PayloadJSON string `db:"payload_json"`
|
||||||
ExpiresAt time.Time `db:"expires_at"`
|
ExpiresAt time.Time `db:"expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DiscordMessage struct {
|
|
||||||
ID string `db:"id"`
|
|
||||||
ChannelID string `db:"channel_id"`
|
|
||||||
GuildID *string `db:"guild_id"`
|
|
||||||
Url string `db:"url"`
|
|
||||||
UserID string `db:"user_id"`
|
|
||||||
SentAt time.Time `db:"sent_at"`
|
|
||||||
SnippetCreated bool `db:"snippet_created"`
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package templates
|
package templates
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -331,6 +332,19 @@ func PodcastEpisodeToTemplate(projectSlug string, episode *models.PodcastEpisode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DiscordUserToTemplate(d *models.DiscordUser) DiscordUser {
|
||||||
|
var avatarUrl string // TODO: Default avatar image
|
||||||
|
if d.Avatar != nil {
|
||||||
|
avatarUrl = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", d.UserID, *d.Avatar)
|
||||||
|
}
|
||||||
|
|
||||||
|
return DiscordUser{
|
||||||
|
Username: d.Username,
|
||||||
|
Discriminator: d.Discriminator,
|
||||||
|
Avatar: avatarUrl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func maybeString(s *string) string {
|
func maybeString(s *string) string {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
{{ template "base.html" . }}
|
||||||
|
|
||||||
|
{{ define "content" }}
|
||||||
|
Wow a discord
|
||||||
|
|
||||||
|
<div>
|
||||||
|
{{ with .DiscordUser }}
|
||||||
|
<img src="{{ .Avatar }}">
|
||||||
|
{{ .Username }}#{{ .Discriminator }}
|
||||||
|
{{ else }}
|
||||||
|
<a href="{{ $.AuthorizeURL }}">Link your account</a>
|
||||||
|
{{ end }}
|
||||||
|
</div>
|
||||||
|
{{ end }}
|
|
@ -303,3 +303,9 @@ type EmailBaseData struct {
|
||||||
Subject template.HTML
|
Subject template.HTML
|
||||||
Separator template.HTML
|
Separator template.HTML
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DiscordUser struct {
|
||||||
|
Username string
|
||||||
|
Discriminator string
|
||||||
|
Avatar string
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
package website
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.handmade.network/hmn/hmn/src/auth"
|
||||||
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
|
"git.handmade.network/hmn/hmn/src/db"
|
||||||
|
"git.handmade.network/hmn/hmn/src/discord"
|
||||||
|
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||||
|
"git.handmade.network/hmn/hmn/src/models"
|
||||||
|
"git.handmade.network/hmn/hmn/src/oops"
|
||||||
|
"git.handmade.network/hmn/hmn/src/templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DiscordTest(c *RequestContext) ResponseData {
|
||||||
|
var userDiscord *models.DiscordUser
|
||||||
|
iUserDiscord, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
||||||
|
`
|
||||||
|
SELECT $columns
|
||||||
|
FROM handmade_discorduser
|
||||||
|
WHERE hmn_user_id = $1
|
||||||
|
`,
|
||||||
|
c.CurrentUser.ID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||||
|
// we're ok, just no user
|
||||||
|
} else {
|
||||||
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current user's Discord account"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
userDiscord = iUserDiscord.(*models.DiscordUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
type templateData struct {
|
||||||
|
templates.BaseData
|
||||||
|
DiscordUser *templates.DiscordUser
|
||||||
|
AuthorizeURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
baseData := getBaseData(c)
|
||||||
|
baseData.Title = "Discord Test"
|
||||||
|
|
||||||
|
params := make(url.Values)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("client_id", config.Config.Discord.OAuthClientID)
|
||||||
|
params.Set("scope", "identify")
|
||||||
|
params.Set("state", c.CurrentSession.CSRFToken)
|
||||||
|
params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback())
|
||||||
|
|
||||||
|
td := templateData{
|
||||||
|
BaseData: baseData,
|
||||||
|
AuthorizeURL: fmt.Sprintf("https://discord.com/api/oauth2/authorize?%s", params.Encode()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if userDiscord != nil {
|
||||||
|
u := templates.DiscordUserToTemplate(userDiscord)
|
||||||
|
td.DiscordUser = &u
|
||||||
|
}
|
||||||
|
|
||||||
|
var res ResponseData
|
||||||
|
res.MustWriteTemplate("discordtest.html", td, c.Perf)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func DiscordOAuthCallback(c *RequestContext) ResponseData {
|
||||||
|
query := c.Req.URL.Query()
|
||||||
|
|
||||||
|
// Check the state
|
||||||
|
state := query.Get("state")
|
||||||
|
if state != c.CurrentSession.CSRFToken {
|
||||||
|
// CSRF'd!!!!
|
||||||
|
|
||||||
|
// TODO(compression): Should this and the CSRF middleware be pulled out to
|
||||||
|
// a separate function?
|
||||||
|
|
||||||
|
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?")
|
||||||
|
|
||||||
|
err := auth.DeleteSession(c.Context(), c.Conn, c.CurrentSession.ID)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error().Err(err).Msg("failed to delete session on Discord OAuth state failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
res := c.Redirect("/", http.StatusSeeOther)
|
||||||
|
res.SetCookie(auth.DeleteSessionCookie)
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for error values and redirect back to ????
|
||||||
|
if query.Get("error") != "" {
|
||||||
|
// TODO: actually handle these errors
|
||||||
|
return ErrorResponse(http.StatusBadRequest, errors.New(query.Get("error")))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do the actual token exchange and redirect back to ????
|
||||||
|
code := query.Get("code")
|
||||||
|
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) // TODO: Redirect to the right place
|
||||||
|
if err != nil {
|
||||||
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
|
||||||
|
}
|
||||||
|
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add the role on Discord
|
||||||
|
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.Conn.Exec(c.Context(),
|
||||||
|
`
|
||||||
|
INSERT INTO handmade_discorduser (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
`,
|
||||||
|
user.Username,
|
||||||
|
user.Discriminator,
|
||||||
|
res.AccessToken,
|
||||||
|
res.RefreshToken,
|
||||||
|
user.Avatar,
|
||||||
|
user.Locale,
|
||||||
|
user.ID,
|
||||||
|
expiry,
|
||||||
|
c.CurrentUser.ID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Redirect(hmnurl.BuildDiscordTest(), http.StatusSeeOther)
|
||||||
|
}
|
|
@ -197,6 +197,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt
|
||||||
mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode)
|
mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode)
|
||||||
mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS)
|
mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS)
|
||||||
|
|
||||||
|
mainRoutes.GET(hmnurl.RegexDiscordTest, authMiddleware(DiscordTest)) // TODO: Delete this route
|
||||||
|
mainRoutes.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback))
|
||||||
|
|
||||||
mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS)
|
mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS)
|
||||||
mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData {
|
mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData {
|
||||||
var res ResponseData
|
var res ResponseData
|
||||||
|
|
Loading…
Reference in New Issue