Add Discord account linking

This commit is contained in:
Ben Visness 2021-08-15 23:40:56 -05:00
parent 38a1188be7
commit d92bf9a9b8
11 changed files with 459 additions and 15 deletions

View File

@ -67,6 +67,11 @@ type DiscordConfig struct {
BotToken string BotToken string
BotUserID string BotUserID string
OAuthClientID string
OAuthClientSecret string
GuildID string
MemberRoleID string
ShowcaseChannelID string ShowcaseChannelID string
LibraryChannelID string LibraryChannelID string
} }

View File

@ -220,7 +220,9 @@ type User struct {
ID string `json:"id"` ID string `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Discriminator string `json:"discriminator"` Discriminator string `json:"discriminator"`
Avatar *string `json:"avatar"`
IsBot bool `json:"bot"` IsBot bool `json:"bot"`
Locale string `json:"locale"`
} }
func UserFromMap(m interface{}) User { func UserFromMap(m interface{}) User {

View File

@ -8,6 +8,8 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url"
"strings"
"git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/logging"
@ -26,13 +28,17 @@ var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersi
var httpClient = &http.Client{} var httpClient = &http.Client{}
func buildUrl(path string) string {
return fmt.Sprintf("%s%s", BaseUrl, path)
}
func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request { func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request {
var bodyReader io.Reader var bodyReader io.Reader
if body != nil { if body != nil {
bodyReader = bytes.NewBuffer(body) bodyReader = bytes.NewBuffer(body)
} }
req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader) req, err := http.NewRequestWithContext(ctx, method, buildUrl(path), bodyReader)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -171,6 +177,121 @@ func CreateDM(ctx context.Context, recipientID string) (*Channel, error) {
return &channel, nil return &channel, nil
} }
type OAuthCodeExchangeResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
}
func ExchangeOAuthCode(ctx context.Context, code, redirectURI string) (*OAuthCodeExchangeResponse, error) {
const name = "OAuth Code Exchange"
body := make(url.Values)
body.Set("client_id", config.Config.Discord.OAuthClientID)
body.Set("client_secret", config.Config.Discord.OAuthClientSecret)
body.Set("grant_type", "authorization_code")
body.Set("code", code)
body.Set("redirect_uri", redirectURI)
bodyStr := body.Encode()
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
"https://discord.com/api/oauth2/token",
strings.NewReader(bodyStr),
)
if err != nil {
panic(err)
}
req.Header.Add("User-Agent", UserAgent)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
return req
})
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
logErrorResponse(ctx, name, res, "")
return nil, oops.New(nil, "received error from Discord")
}
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
panic(err)
}
var tokenResponse OAuthCodeExchangeResponse
err = json.Unmarshal(bodyBytes, &tokenResponse)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord OAuth token")
}
return &tokenResponse, nil
}
func GetCurrentUserAsOAuth(ctx context.Context, accessToken string) (*User, error) {
const name = "Get Current User"
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildUrl("/users/@me"), nil)
if err != nil {
panic(err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken))
req.Header.Add("User-Agent", UserAgent)
return req
})
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
logErrorResponse(ctx, name, res, "")
return nil, oops.New(nil, "received error from Discord")
}
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
panic(err)
}
var user User
err = json.Unmarshal(bodyBytes, &user)
if err != nil {
return nil, oops.New(err, "failed to unmarshal Discord user")
}
return &user, nil
}
func AddGuildMemberRole(ctx context.Context, userID, roleID string) error {
const name = "Delete Message"
path := fmt.Sprintf("/guilds/%s/members/%s/roles/%s", config.Config.Discord.GuildID, userID, roleID)
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
return makeRequest(ctx, http.MethodPut, path, nil)
})
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
logErrorResponse(ctx, name, res, "")
return oops.New(nil, "got unexpected status code when adding role")
}
return nil
}
func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) {
dump, err := httputil.DumpResponse(res, true) dump, err := httputil.DumpResponse(res, true)
if err != nil { if err != nil {

View File

@ -545,6 +545,22 @@ func BuildLibraryResource(projectSlug string, resourceId int) string {
return ProjectUrl(builder.String(), nil, projectSlug) return ProjectUrl(builder.String(), nil, projectSlug)
} }
/*
* Discord OAuth
*/
var RegexDiscordTest = regexp.MustCompile("^/discord$")
func BuildDiscordTest() string {
return Url("/discord", nil)
}
var RegexDiscordOAuthCallback = regexp.MustCompile("^/_discord_callback$")
func BuildDiscordOAuthCallback() string {
return Url("/_discord_callback", nil)
}
/* /*
* Assets * Assets
*/ */

View File

@ -0,0 +1,76 @@
package migrations
import (
"context"
"time"
"git.handmade.network/hmn/hmn/src/migration/types"
"git.handmade.network/hmn/hmn/src/oops"
"github.com/jackc/pgx/v4"
)
func init() {
registerMigration(DiscordData{})
}
type DiscordData struct{}
func (m DiscordData) Version() types.MigrationVersion {
return types.MigrationVersion(time.Date(2021, 8, 16, 2, 34, 40, 0, time.UTC))
}
func (m DiscordData) Name() string {
return "DiscordData"
}
func (m DiscordData) Description() string {
return "Clean up Discord data models"
}
func (m DiscordData) Up(ctx context.Context, tx pgx.Tx) error {
_, err := tx.Exec(ctx, `
ALTER TABLE handmade_discord RENAME TO handmade_discorduser;
ALTER TABLE handmade_discorduser
ALTER username SET NOT NULL,
ALTER discriminator SET NOT NULL,
ALTER access_token SET NOT NULL,
ALTER refresh_token SET NOT NULL,
ALTER locale SET NOT NULL,
ALTER userid SET NOT NULL,
ALTER expiry SET NOT NULL;
`)
if err != nil {
return oops.New(err, "failed to fix up discord table")
}
_, err = tx.Exec(ctx, `
ALTER TABLE handmade_discordmessagecontent
DROP CONSTRAINT handmade_discordmess_discord_id_1acc147f_fk_handmade_,
DROP CONSTRAINT handmade_discordmess_message_id_4dfde67d_fk_handmade_,
ADD FOREIGN KEY (discord_id) REFERENCES handmade_discorduser (id) ON DELETE CASCADE,
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
ALTER TABLE handmade_discordmessageattachment
DROP CONSTRAINT handmade_discordmess_asset_id_c64a3c31_fk_handmade_,
DROP CONSTRAINT handmade_discordmess_message_id_d39da9b3_fk_handmade_,
ADD FOREIGN KEY (asset_id) REFERENCES handmade_asset (id) ON DELETE CASCADE,
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
ALTER TABLE handmade_discordmessageembed
DROP CONSTRAINT handmade_discordmess_image_id_9b04bb5f_fk_handmade_,
DROP CONSTRAINT handmade_discordmess_message_id_04f15ce6_fk_handmade_,
DROP CONSTRAINT handmade_discordmess_video_id_1c41289f_fk_handmade_,
ADD FOREIGN KEY (image_id) REFERENCES handmade_asset (id) ON DELETE SET NULL,
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE,
ADD FOREIGN KEY (video_id) REFERENCES handmade_asset (id) ON DELETE SET NULL;
`)
if err != nil {
return oops.New(err, "failed to fix constraints")
}
return nil
}
func (m DiscordData) Down(ctx context.Context, tx pgx.Tx) error {
panic("Implement me")
}

View File

@ -2,8 +2,65 @@ package models
import ( import (
"time" "time"
"github.com/google/uuid"
) )
type DiscordUser struct {
ID int `db:"id"`
Username string `db:"username"`
Discriminator string `db:"discriminator"`
AccessToken string `db:"access_token"`
RefreshToken string `db:"refresh_token"`
Avatar *string `db:"avatar"`
Locale string `db:"locale"`
UserID string `db:"userid"`
Expiry time.Time `db:"expiry"`
HMNUserId int `db:"hmn_user_id"`
}
/*
Logs the existence of a Discord message and what we've done with it.
Created unconditionally for all users, regardless of link status.
Therefore, it must not contain any actual content.
*/
type DiscordMessage struct {
ID string `db:"id"`
ChannelID string `db:"channel_id"`
GuildID *string `db:"guild_id"`
Url string `db:"url"`
UserID string `db:"user_id"`
SentAt time.Time `db:"sent_at"`
SnippetCreated bool `db:"snippet_created"`
}
/*
Stores the content of a Discord message for users with a linked
Discord account. Always created for users with a linked Discord
account, regardless of whether we create snippets or not.
*/
type DiscordMessageContent struct {
MessageID string `db:"message_id"`
LastContent string `db:"last_content"`
DiscordID int `db:"discord_id"`
}
type DiscordMessageAttachment struct {
ID string `db:"id"`
AssetID uuid.UUID `db:"asset_id"`
MessageID string `db:"message_id"`
}
type DiscordMessageEmbed struct {
ID int `db:"id"`
Title *string `db:"title"`
Description *string `db:"description"`
URL *string `db:"url"`
ImageID *uuid.UUID `db:"image_id"`
MessageID string `db:"message_id"`
VideoID *uuid.UUID `db:"video_id"`
}
type DiscordSession struct { type DiscordSession struct {
ID string `db:"session_id"` ID string `db:"session_id"`
SequenceNumber int `db:"sequence_number"` SequenceNumber int `db:"sequence_number"`
@ -15,13 +72,3 @@ type DiscordOutgoingMessage struct {
PayloadJSON string `db:"payload_json"` PayloadJSON string `db:"payload_json"`
ExpiresAt time.Time `db:"expires_at"` ExpiresAt time.Time `db:"expires_at"`
} }
type DiscordMessage struct {
ID string `db:"id"`
ChannelID string `db:"channel_id"`
GuildID *string `db:"guild_id"`
Url string `db:"url"`
UserID string `db:"user_id"`
SentAt time.Time `db:"sent_at"`
SnippetCreated bool `db:"snippet_created"`
}

View File

@ -1,6 +1,7 @@
package templates package templates
import ( import (
"fmt"
"html/template" "html/template"
"net" "net"
"regexp" "regexp"
@ -331,6 +332,19 @@ func PodcastEpisodeToTemplate(projectSlug string, episode *models.PodcastEpisode
} }
} }
func DiscordUserToTemplate(d *models.DiscordUser) DiscordUser {
var avatarUrl string // TODO: Default avatar image
if d.Avatar != nil {
avatarUrl = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", d.UserID, *d.Avatar)
}
return DiscordUser{
Username: d.Username,
Discriminator: d.Discriminator,
Avatar: avatarUrl,
}
}
func maybeString(s *string) string { func maybeString(s *string) string {
if s == nil { if s == nil {
return "" return ""

View File

@ -0,0 +1,14 @@
{{ template "base.html" . }}
{{ define "content" }}
Wow a discord
<div>
{{ with .DiscordUser }}
<img src="{{ .Avatar }}">
{{ .Username }}#{{ .Discriminator }}
{{ else }}
<a href="{{ $.AuthorizeURL }}">Link your account</a>
{{ end }}
</div>
{{ end }}

View File

@ -303,3 +303,9 @@ type EmailBaseData struct {
Subject template.HTML Subject template.HTML
Separator template.HTML Separator template.HTML
} }
type DiscordUser struct {
Username string
Discriminator string
Avatar string
}

140
src/website/discord.go Normal file
View File

@ -0,0 +1,140 @@
package website
import (
"errors"
"fmt"
"net/http"
"net/url"
"time"
"git.handmade.network/hmn/hmn/src/auth"
"git.handmade.network/hmn/hmn/src/config"
"git.handmade.network/hmn/hmn/src/db"
"git.handmade.network/hmn/hmn/src/discord"
"git.handmade.network/hmn/hmn/src/hmnurl"
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates"
)
func DiscordTest(c *RequestContext) ResponseData {
var userDiscord *models.DiscordUser
iUserDiscord, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
`
SELECT $columns
FROM handmade_discorduser
WHERE hmn_user_id = $1
`,
c.CurrentUser.ID,
)
if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) {
// we're ok, just no user
} else {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current user's Discord account"))
}
} else {
userDiscord = iUserDiscord.(*models.DiscordUser)
}
type templateData struct {
templates.BaseData
DiscordUser *templates.DiscordUser
AuthorizeURL string
}
baseData := getBaseData(c)
baseData.Title = "Discord Test"
params := make(url.Values)
params.Set("response_type", "code")
params.Set("client_id", config.Config.Discord.OAuthClientID)
params.Set("scope", "identify")
params.Set("state", c.CurrentSession.CSRFToken)
params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback())
td := templateData{
BaseData: baseData,
AuthorizeURL: fmt.Sprintf("https://discord.com/api/oauth2/authorize?%s", params.Encode()),
}
if userDiscord != nil {
u := templates.DiscordUserToTemplate(userDiscord)
td.DiscordUser = &u
}
var res ResponseData
res.MustWriteTemplate("discordtest.html", td, c.Perf)
return res
}
func DiscordOAuthCallback(c *RequestContext) ResponseData {
query := c.Req.URL.Query()
// Check the state
state := query.Get("state")
if state != c.CurrentSession.CSRFToken {
// CSRF'd!!!!
// TODO(compression): Should this and the CSRF middleware be pulled out to
// a separate function?
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?")
err := auth.DeleteSession(c.Context(), c.Conn, c.CurrentSession.ID)
if err != nil {
c.Logger.Error().Err(err).Msg("failed to delete session on Discord OAuth state failure")
}
res := c.Redirect("/", http.StatusSeeOther)
res.SetCookie(auth.DeleteSessionCookie)
return res
}
// Check for error values and redirect back to ????
if query.Get("error") != "" {
// TODO: actually handle these errors
return ErrorResponse(http.StatusBadRequest, errors.New(query.Get("error")))
}
// Do the actual token exchange and redirect back to ????
code := query.Get("code")
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) // TODO: Redirect to the right place
if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
}
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken)
if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
}
// TODO: Add the role on Discord
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID)
if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
}
_, err = c.Conn.Exec(c.Context(),
`
INSERT INTO handmade_discorduser (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`,
user.Username,
user.Discriminator,
res.AccessToken,
res.RefreshToken,
user.Avatar,
user.Locale,
user.ID,
expiry,
c.CurrentUser.ID,
)
if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info"))
}
return c.Redirect(hmnurl.BuildDiscordTest(), http.StatusSeeOther)
}

View File

@ -197,6 +197,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt
mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode) mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode)
mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS) mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS)
mainRoutes.GET(hmnurl.RegexDiscordTest, authMiddleware(DiscordTest)) // TODO: Delete this route
mainRoutes.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback))
mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS) mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS)
mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData { mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData {
var res ResponseData var res ResponseData