Add Discord account linking
This commit is contained in:
parent
38a1188be7
commit
d92bf9a9b8
|
@ -67,6 +67,11 @@ type DiscordConfig struct {
|
|||
BotToken string
|
||||
BotUserID string
|
||||
|
||||
OAuthClientID string
|
||||
OAuthClientSecret string
|
||||
|
||||
GuildID string
|
||||
MemberRoleID string
|
||||
ShowcaseChannelID string
|
||||
LibraryChannelID string
|
||||
}
|
||||
|
|
|
@ -220,7 +220,9 @@ type User struct {
|
|||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Discriminator string `json:"discriminator"`
|
||||
Avatar *string `json:"avatar"`
|
||||
IsBot bool `json:"bot"`
|
||||
Locale string `json:"locale"`
|
||||
}
|
||||
|
||||
func UserFromMap(m interface{}) User {
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
|
@ -26,13 +28,17 @@ var UserAgent = fmt.Sprintf("%s (%s, %s)", BotName, UserAgentURL, UserAgentVersi
|
|||
|
||||
var httpClient = &http.Client{}
|
||||
|
||||
func buildUrl(path string) string {
|
||||
return fmt.Sprintf("%s%s", BaseUrl, path)
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, path string, body []byte) *http.Request {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewBuffer(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fmt.Sprintf("%s%s", BaseUrl, path), bodyReader)
|
||||
req, err := http.NewRequestWithContext(ctx, method, buildUrl(path), bodyReader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -171,6 +177,121 @@ func CreateDM(ctx context.Context, recipientID string) (*Channel, error) {
|
|||
return &channel, nil
|
||||
}
|
||||
|
||||
type OAuthCodeExchangeResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
func ExchangeOAuthCode(ctx context.Context, code, redirectURI string) (*OAuthCodeExchangeResponse, error) {
|
||||
const name = "OAuth Code Exchange"
|
||||
|
||||
body := make(url.Values)
|
||||
body.Set("client_id", config.Config.Discord.OAuthClientID)
|
||||
body.Set("client_secret", config.Config.Discord.OAuthClientSecret)
|
||||
body.Set("grant_type", "authorization_code")
|
||||
body.Set("code", code)
|
||||
body.Set("redirect_uri", redirectURI)
|
||||
bodyStr := body.Encode()
|
||||
|
||||
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
"https://discord.com/api/oauth2/token",
|
||||
strings.NewReader(bodyStr),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Add("User-Agent", UserAgent)
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
return req
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
logErrorResponse(ctx, name, res, "")
|
||||
return nil, oops.New(nil, "received error from Discord")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var tokenResponse OAuthCodeExchangeResponse
|
||||
err = json.Unmarshal(bodyBytes, &tokenResponse)
|
||||
if err != nil {
|
||||
return nil, oops.New(err, "failed to unmarshal Discord OAuth token")
|
||||
}
|
||||
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func GetCurrentUserAsOAuth(ctx context.Context, accessToken string) (*User, error) {
|
||||
const name = "Get Current User"
|
||||
|
||||
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildUrl("/users/@me"), nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
req.Header.Add("User-Agent", UserAgent)
|
||||
|
||||
return req
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
logErrorResponse(ctx, name, res, "")
|
||||
return nil, oops.New(nil, "received error from Discord")
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var user User
|
||||
err = json.Unmarshal(bodyBytes, &user)
|
||||
if err != nil {
|
||||
return nil, oops.New(err, "failed to unmarshal Discord user")
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func AddGuildMemberRole(ctx context.Context, userID, roleID string) error {
|
||||
const name = "Delete Message"
|
||||
|
||||
path := fmt.Sprintf("/guilds/%s/members/%s/roles/%s", config.Config.Discord.GuildID, userID, roleID)
|
||||
res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request {
|
||||
return makeRequest(ctx, http.MethodPut, path, nil)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
logErrorResponse(ctx, name, res, "")
|
||||
return oops.New(nil, "got unexpected status code when adding role")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) {
|
||||
dump, err := httputil.DumpResponse(res, true)
|
||||
if err != nil {
|
||||
|
|
|
@ -545,6 +545,22 @@ func BuildLibraryResource(projectSlug string, resourceId int) string {
|
|||
return ProjectUrl(builder.String(), nil, projectSlug)
|
||||
}
|
||||
|
||||
/*
|
||||
* Discord OAuth
|
||||
*/
|
||||
|
||||
var RegexDiscordTest = regexp.MustCompile("^/discord$")
|
||||
|
||||
func BuildDiscordTest() string {
|
||||
return Url("/discord", nil)
|
||||
}
|
||||
|
||||
var RegexDiscordOAuthCallback = regexp.MustCompile("^/_discord_callback$")
|
||||
|
||||
func BuildDiscordOAuthCallback() string {
|
||||
return Url("/_discord_callback", nil)
|
||||
}
|
||||
|
||||
/*
|
||||
* Assets
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/migration/types"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerMigration(DiscordData{})
|
||||
}
|
||||
|
||||
type DiscordData struct{}
|
||||
|
||||
func (m DiscordData) Version() types.MigrationVersion {
|
||||
return types.MigrationVersion(time.Date(2021, 8, 16, 2, 34, 40, 0, time.UTC))
|
||||
}
|
||||
|
||||
func (m DiscordData) Name() string {
|
||||
return "DiscordData"
|
||||
}
|
||||
|
||||
func (m DiscordData) Description() string {
|
||||
return "Clean up Discord data models"
|
||||
}
|
||||
|
||||
func (m DiscordData) Up(ctx context.Context, tx pgx.Tx) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
ALTER TABLE handmade_discord RENAME TO handmade_discorduser;
|
||||
ALTER TABLE handmade_discorduser
|
||||
ALTER username SET NOT NULL,
|
||||
ALTER discriminator SET NOT NULL,
|
||||
ALTER access_token SET NOT NULL,
|
||||
ALTER refresh_token SET NOT NULL,
|
||||
ALTER locale SET NOT NULL,
|
||||
ALTER userid SET NOT NULL,
|
||||
ALTER expiry SET NOT NULL;
|
||||
`)
|
||||
if err != nil {
|
||||
return oops.New(err, "failed to fix up discord table")
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
ALTER TABLE handmade_discordmessagecontent
|
||||
DROP CONSTRAINT handmade_discordmess_discord_id_1acc147f_fk_handmade_,
|
||||
DROP CONSTRAINT handmade_discordmess_message_id_4dfde67d_fk_handmade_,
|
||||
ADD FOREIGN KEY (discord_id) REFERENCES handmade_discorduser (id) ON DELETE CASCADE,
|
||||
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE handmade_discordmessageattachment
|
||||
DROP CONSTRAINT handmade_discordmess_asset_id_c64a3c31_fk_handmade_,
|
||||
DROP CONSTRAINT handmade_discordmess_message_id_d39da9b3_fk_handmade_,
|
||||
ADD FOREIGN KEY (asset_id) REFERENCES handmade_asset (id) ON DELETE CASCADE,
|
||||
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE handmade_discordmessageembed
|
||||
DROP CONSTRAINT handmade_discordmess_image_id_9b04bb5f_fk_handmade_,
|
||||
DROP CONSTRAINT handmade_discordmess_message_id_04f15ce6_fk_handmade_,
|
||||
DROP CONSTRAINT handmade_discordmess_video_id_1c41289f_fk_handmade_,
|
||||
ADD FOREIGN KEY (image_id) REFERENCES handmade_asset (id) ON DELETE SET NULL,
|
||||
ADD FOREIGN KEY (message_id) REFERENCES handmade_discordmessage (id) ON DELETE CASCADE,
|
||||
ADD FOREIGN KEY (video_id) REFERENCES handmade_asset (id) ON DELETE SET NULL;
|
||||
`)
|
||||
if err != nil {
|
||||
return oops.New(err, "failed to fix constraints")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m DiscordData) Down(ctx context.Context, tx pgx.Tx) error {
|
||||
panic("Implement me")
|
||||
}
|
|
@ -2,8 +2,65 @@ package models
|
|||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DiscordUser struct {
|
||||
ID int `db:"id"`
|
||||
Username string `db:"username"`
|
||||
Discriminator string `db:"discriminator"`
|
||||
AccessToken string `db:"access_token"`
|
||||
RefreshToken string `db:"refresh_token"`
|
||||
Avatar *string `db:"avatar"`
|
||||
Locale string `db:"locale"`
|
||||
UserID string `db:"userid"`
|
||||
Expiry time.Time `db:"expiry"`
|
||||
HMNUserId int `db:"hmn_user_id"`
|
||||
}
|
||||
|
||||
/*
|
||||
Logs the existence of a Discord message and what we've done with it.
|
||||
Created unconditionally for all users, regardless of link status.
|
||||
Therefore, it must not contain any actual content.
|
||||
*/
|
||||
type DiscordMessage struct {
|
||||
ID string `db:"id"`
|
||||
ChannelID string `db:"channel_id"`
|
||||
GuildID *string `db:"guild_id"`
|
||||
Url string `db:"url"`
|
||||
UserID string `db:"user_id"`
|
||||
SentAt time.Time `db:"sent_at"`
|
||||
SnippetCreated bool `db:"snippet_created"`
|
||||
}
|
||||
|
||||
/*
|
||||
Stores the content of a Discord message for users with a linked
|
||||
Discord account. Always created for users with a linked Discord
|
||||
account, regardless of whether we create snippets or not.
|
||||
*/
|
||||
type DiscordMessageContent struct {
|
||||
MessageID string `db:"message_id"`
|
||||
LastContent string `db:"last_content"`
|
||||
DiscordID int `db:"discord_id"`
|
||||
}
|
||||
|
||||
type DiscordMessageAttachment struct {
|
||||
ID string `db:"id"`
|
||||
AssetID uuid.UUID `db:"asset_id"`
|
||||
MessageID string `db:"message_id"`
|
||||
}
|
||||
|
||||
type DiscordMessageEmbed struct {
|
||||
ID int `db:"id"`
|
||||
Title *string `db:"title"`
|
||||
Description *string `db:"description"`
|
||||
URL *string `db:"url"`
|
||||
ImageID *uuid.UUID `db:"image_id"`
|
||||
MessageID string `db:"message_id"`
|
||||
VideoID *uuid.UUID `db:"video_id"`
|
||||
}
|
||||
|
||||
type DiscordSession struct {
|
||||
ID string `db:"session_id"`
|
||||
SequenceNumber int `db:"sequence_number"`
|
||||
|
@ -15,13 +72,3 @@ type DiscordOutgoingMessage struct {
|
|||
PayloadJSON string `db:"payload_json"`
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
}
|
||||
|
||||
type DiscordMessage struct {
|
||||
ID string `db:"id"`
|
||||
ChannelID string `db:"channel_id"`
|
||||
GuildID *string `db:"guild_id"`
|
||||
Url string `db:"url"`
|
||||
UserID string `db:"user_id"`
|
||||
SentAt time.Time `db:"sent_at"`
|
||||
SnippetCreated bool `db:"snippet_created"`
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package templates
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"regexp"
|
||||
|
@ -331,6 +332,19 @@ func PodcastEpisodeToTemplate(projectSlug string, episode *models.PodcastEpisode
|
|||
}
|
||||
}
|
||||
|
||||
func DiscordUserToTemplate(d *models.DiscordUser) DiscordUser {
|
||||
var avatarUrl string // TODO: Default avatar image
|
||||
if d.Avatar != nil {
|
||||
avatarUrl = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.png", d.UserID, *d.Avatar)
|
||||
}
|
||||
|
||||
return DiscordUser{
|
||||
Username: d.Username,
|
||||
Discriminator: d.Discriminator,
|
||||
Avatar: avatarUrl,
|
||||
}
|
||||
}
|
||||
|
||||
func maybeString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
{{ template "base.html" . }}
|
||||
|
||||
{{ define "content" }}
|
||||
Wow a discord
|
||||
|
||||
<div>
|
||||
{{ with .DiscordUser }}
|
||||
<img src="{{ .Avatar }}">
|
||||
{{ .Username }}#{{ .Discriminator }}
|
||||
{{ else }}
|
||||
<a href="{{ $.AuthorizeURL }}">Link your account</a>
|
||||
{{ end }}
|
||||
</div>
|
||||
{{ end }}
|
|
@ -303,3 +303,9 @@ type EmailBaseData struct {
|
|||
Subject template.HTML
|
||||
Separator template.HTML
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
Username string
|
||||
Discriminator string
|
||||
Avatar string
|
||||
}
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/auth"
|
||||
"git.handmade.network/hmn/hmn/src/config"
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"git.handmade.network/hmn/hmn/src/discord"
|
||||
"git.handmade.network/hmn/hmn/src/hmnurl"
|
||||
"git.handmade.network/hmn/hmn/src/models"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
)
|
||||
|
||||
func DiscordTest(c *RequestContext) ResponseData {
|
||||
var userDiscord *models.DiscordUser
|
||||
iUserDiscord, err := db.QueryOne(c.Context(), c.Conn, models.DiscordUser{},
|
||||
`
|
||||
SELECT $columns
|
||||
FROM handmade_discorduser
|
||||
WHERE hmn_user_id = $1
|
||||
`,
|
||||
c.CurrentUser.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||
// we're ok, just no user
|
||||
} else {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current user's Discord account"))
|
||||
}
|
||||
} else {
|
||||
userDiscord = iUserDiscord.(*models.DiscordUser)
|
||||
}
|
||||
|
||||
type templateData struct {
|
||||
templates.BaseData
|
||||
DiscordUser *templates.DiscordUser
|
||||
AuthorizeURL string
|
||||
}
|
||||
|
||||
baseData := getBaseData(c)
|
||||
baseData.Title = "Discord Test"
|
||||
|
||||
params := make(url.Values)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", config.Config.Discord.OAuthClientID)
|
||||
params.Set("scope", "identify")
|
||||
params.Set("state", c.CurrentSession.CSRFToken)
|
||||
params.Set("redirect_uri", hmnurl.BuildDiscordOAuthCallback())
|
||||
|
||||
td := templateData{
|
||||
BaseData: baseData,
|
||||
AuthorizeURL: fmt.Sprintf("https://discord.com/api/oauth2/authorize?%s", params.Encode()),
|
||||
}
|
||||
|
||||
if userDiscord != nil {
|
||||
u := templates.DiscordUserToTemplate(userDiscord)
|
||||
td.DiscordUser = &u
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
res.MustWriteTemplate("discordtest.html", td, c.Perf)
|
||||
return res
|
||||
}
|
||||
|
||||
func DiscordOAuthCallback(c *RequestContext) ResponseData {
|
||||
query := c.Req.URL.Query()
|
||||
|
||||
// Check the state
|
||||
state := query.Get("state")
|
||||
if state != c.CurrentSession.CSRFToken {
|
||||
// CSRF'd!!!!
|
||||
|
||||
// TODO(compression): Should this and the CSRF middleware be pulled out to
|
||||
// a separate function?
|
||||
|
||||
c.Logger.Warn().Str("userId", c.CurrentUser.Username).Msg("user failed Discord OAuth state validation - potential attack?")
|
||||
|
||||
err := auth.DeleteSession(c.Context(), c.Conn, c.CurrentSession.ID)
|
||||
if err != nil {
|
||||
c.Logger.Error().Err(err).Msg("failed to delete session on Discord OAuth state failure")
|
||||
}
|
||||
|
||||
res := c.Redirect("/", http.StatusSeeOther)
|
||||
res.SetCookie(auth.DeleteSessionCookie)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// Check for error values and redirect back to ????
|
||||
if query.Get("error") != "" {
|
||||
// TODO: actually handle these errors
|
||||
return ErrorResponse(http.StatusBadRequest, errors.New(query.Get("error")))
|
||||
}
|
||||
|
||||
// Do the actual token exchange and redirect back to ????
|
||||
code := query.Get("code")
|
||||
res, err := discord.ExchangeOAuthCode(c.Context(), code, hmnurl.BuildDiscordOAuthCallback()) // TODO: Redirect to the right place
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to exchange Discord authorization code"))
|
||||
}
|
||||
expiry := time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
|
||||
|
||||
user, err := discord.GetCurrentUserAsOAuth(c.Context(), res.AccessToken)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch Discord user info"))
|
||||
}
|
||||
|
||||
// TODO: Add the role on Discord
|
||||
err = discord.AddGuildMemberRole(c.Context(), user.ID, config.Config.Discord.MemberRoleID)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to add member role"))
|
||||
}
|
||||
|
||||
_, err = c.Conn.Exec(c.Context(),
|
||||
`
|
||||
INSERT INTO handmade_discorduser (username, discriminator, access_token, refresh_token, avatar, locale, userid, expiry, hmn_user_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`,
|
||||
user.Username,
|
||||
user.Discriminator,
|
||||
res.AccessToken,
|
||||
res.RefreshToken,
|
||||
user.Avatar,
|
||||
user.Locale,
|
||||
user.ID,
|
||||
expiry,
|
||||
c.CurrentUser.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to save new Discord user info"))
|
||||
}
|
||||
|
||||
return c.Redirect(hmnurl.BuildDiscordTest(), http.StatusSeeOther)
|
||||
}
|
|
@ -197,6 +197,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt
|
|||
mainRoutes.GET(hmnurl.RegexPodcastEpisode, PodcastEpisode)
|
||||
mainRoutes.GET(hmnurl.RegexPodcastRSS, PodcastRSS)
|
||||
|
||||
mainRoutes.GET(hmnurl.RegexDiscordTest, authMiddleware(DiscordTest)) // TODO: Delete this route
|
||||
mainRoutes.GET(hmnurl.RegexDiscordOAuthCallback, authMiddleware(DiscordOAuthCallback))
|
||||
|
||||
mainRoutes.GET(hmnurl.RegexProjectCSS, ProjectCSS)
|
||||
mainRoutes.GET(hmnurl.RegexEditorPreviewsJS, func(c *RequestContext) ResponseData {
|
||||
var res ResponseData
|
||||
|
|
Loading…
Reference in New Issue