diff --git a/src/discord/rest.go b/src/discord/rest.go index b802063..61e9762 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -7,8 +7,10 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" "net/http/httputil" + "net/textproto" "net/url" "strconv" "strings" @@ -187,13 +189,15 @@ type CreateMessageRequest struct { Content string `json:"content"` } -func CreateMessage(ctx context.Context, channelID string, payloadJSON string) (*Message, error) { +func CreateMessage(ctx context.Context, channelID string, payloadJSON string, files ...FileUpload) (*Message, error) { const name = "Create Message" + contentType, body := makeNewMessageBody(payloadJSON, files) + path := fmt.Sprintf("/channels/%s/messages", channelID) res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { - req := makeRequest(ctx, http.MethodPost, path, []byte(payloadJSON)) - req.Header.Add("Content-Type", "application/json") + req := makeRequest(ctx, http.MethodPost, path, body) + req.Header.Add("Content-Type", contentType) return req }) if err != nil { @@ -568,6 +572,41 @@ func CreateInteractionResponse(ctx context.Context, interactionID, interactionTo return nil } +func EditOriginalInteractionResponse(ctx context.Context, interactionToken string, payloadJSON string, files ...FileUpload) (*Message, error) { + const name = "Edit Original Interaction Response" + + contentType, body := makeNewMessageBody(payloadJSON, files) + + path := fmt.Sprintf("/webhooks/%s/%s/messages/@original", config.Config.Discord.BotUserID, interactionToken) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodPatch, path, body) + req.Header.Add("Content-Type", contentType) + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var msg Message + err = json.Unmarshal(bodyBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord message") + } + + return &msg, nil +} + func GetAuthorizeUrl(state string) string { params := make(url.Values) params.Set("response_type", "code") @@ -578,6 +617,43 @@ func GetAuthorizeUrl(state string) string { return fmt.Sprintf("https://discord.com/api/oauth2/authorize?%s", params.Encode()) } +type FileUpload struct { + Name string + Data []byte +} + +func makeNewMessageBody(payloadJSON string, files []FileUpload) (contentType string, body []byte) { + if len(files) == 0 { + contentType = "application/json" + body = []byte(payloadJSON) + } else { + var bodyBuffer bytes.Buffer + w := multipart.NewWriter(&bodyBuffer) + contentType = w.FormDataContentType() + + jsonHeader := textproto.MIMEHeader{} + jsonHeader.Set("Content-Disposition", `form-data; name="payload_json"`) + jsonHeader.Set("Content-Type", "application/json") + jsonWriter, _ := w.CreatePart(jsonHeader) + jsonWriter.Write([]byte(payloadJSON)) + + for _, f := range files { + formFile, _ := w.CreateFormFile("file", f.Name) + formFile.Write(f.Data) + } + + w.Close() + + body = bodyBuffer.Bytes() + } + + if len(body) == 0 { + panic("somehow we generated an empty body for Discord") + } + + return +} + func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { dump, err := httputil.DumpResponse(res, true) if err != nil {