diff --git a/src/db/db.go b/src/db/db.go index 3962082..e6713b2 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -58,6 +58,7 @@ func typeIsQueryable(t reflect.Type) bool { // This interface should match both a direct pgx connection or a pgx transaction. type ConnOrTx interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) } diff --git a/src/discord/cmd/cmd.go b/src/discord/cmd/cmd.go new file mode 100644 index 0000000..c9f52e6 --- /dev/null +++ b/src/discord/cmd/cmd.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "context" + "time" + + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/discord" + "git.handmade.network/hmn/hmn/src/website" + "github.com/spf13/cobra" +) + +func init() { + scrapeCommand := &cobra.Command{ + Use: "discordscrapechannel [...]", + Short: "Scrape the entire history of Discord channels", + Long: "Scrape the entire history of Discord channels, saving message content (but not creating snippets)", + Run: func(cmd *cobra.Command, args []string) { + ctx := context.Background() + conn := db.NewConnPool(1, 1) + defer conn.Close() + + for _, channelID := range args { + discord.Scrape(ctx, conn, channelID, time.Time{}, false) + } + }, + } + + website.WebsiteCommand.AddCommand(scrapeCommand) +} diff --git a/src/discord/history.go b/src/discord/history.go new file mode 100644 index 0000000..a042f0e --- /dev/null +++ b/src/discord/history.go @@ -0,0 +1,210 @@ +package discord + +import ( + "context" + "errors" + "time" + + "git.handmade.network/hmn/hmn/src/config" + "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/logging" + "git.handmade.network/hmn/hmn/src/models" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" +) + +func RunHistoryWatcher(ctx context.Context, dbConn *pgxpool.Pool) <-chan struct{} { + log := logging.ExtractLogger(ctx).With().Str("discord goroutine", "history watcher").Logger() + ctx = logging.AttachLoggerToContext(&log, ctx) + + done := make(chan struct{}) + + go func() { + defer func() { + log.Debug().Msg("shut down Discord history watcher") + done <- struct{}{} + }() + + backfillInterval := 1 * time.Hour + + newUserTicker := time.NewTicker(5 * time.Second) + backfillTicker := time.NewTicker(backfillInterval) + + lastBackfillTime := time.Now().Add(-backfillInterval) + for { + select { + case <-ctx.Done(): + return + case <-newUserTicker.C: + // Get content for messages when a user links their account (but do not create snippets) + fetchMissingContent(ctx, dbConn) + case <-backfillTicker.C: + // Run a backfill to patch up places where the Discord bot missed (does create snippets) + Scrape(ctx, dbConn, + config.Config.Discord.ShowcaseChannelID, + lastBackfillTime, + true, + ) + } + } + }() + + return done +} + +func fetchMissingContent(ctx context.Context, dbConn *pgxpool.Pool) { + log := logging.ExtractLogger(ctx) + + type query struct { + Message models.DiscordMessage `db:"msg"` + } + result, err := db.Query(ctx, dbConn, query{}, + ` + SELECT $columns + FROM + handmade_discordmessage AS msg + JOIN handmade_discorduser AS duser ON msg.user_id = duser.userid -- only fetch messages for linked discord users + LEFT JOIN handmade_discordmessagecontent AS c ON c.message_id = msg.id + WHERE + c.last_content IS NULL + AND msg.guild_id = $1 + `, + config.Config.Discord.GuildID, + ) + if err != nil { + log.Error().Err(err).Msg("failed to check for messages without content") + return + } + imessagesWithoutContent := result.ToSlice() + + if len(imessagesWithoutContent) > 0 { + log.Info().Msgf("There are %d Discord messages without content, fetching their content now...", len(imessagesWithoutContent)) + msgloop: + for _, imsg := range imessagesWithoutContent { + select { + case <-ctx.Done(): + log.Info().Msg("Scrape was canceled") + break msgloop + default: + } + + msg := imsg.(*query).Message + + discordMsg, err := GetChannelMessage(ctx, msg.ChannelID, msg.ID) + if errors.Is(err, NotFound) { + // This message has apparently been deleted; delete it from our database + _, err = dbConn.Exec(ctx, + ` + DELETE FROM handmade_discordmessage + WHERE id = $1 + `, + msg.ID, + ) + if err != nil { + log.Error().Err(err).Msg("failed to delete missing message") + continue + } + log.Info().Str("msg id", msg.ID).Msg("deleted missing Discord message") + continue + } else if err != nil { + log.Error().Err(err).Msg("failed to get message") + continue + } + + log.Info().Str("msg", discordMsg.ShortString()).Msg("fetched message for content") + + err = handleHistoryMessage(ctx, dbConn, discordMsg, false) + if err != nil { + log.Error().Err(err).Msg("failed to save content for message") + continue + } + } + log.Info().Msgf("Done fetching missing content") + } +} + +func Scrape(ctx context.Context, dbConn *pgxpool.Pool, channelID string, earliestMessageTime time.Time, createSnippets bool) { + log := logging.ExtractLogger(ctx) + + log.Info().Msg("Starting scrape") + defer log.Info().Msg("Done with scrape!") + + before := "" + for { + msgs, err := GetChannelMessages(ctx, channelID, GetChannelMessagesInput{ + Limit: 100, + Before: before, + }) + if err != nil { + panic(err) // TODO + } + + if len(msgs) == 0 { + logging.Debug().Msg("out of messages, stopping scrape") + return + } + + for _, msg := range msgs { + select { + case <-ctx.Done(): + log.Info().Msg("Scrape was canceled") + return + default: + } + + log.Info().Str("msg", msg.ShortString()).Msg("") + + if !earliestMessageTime.IsZero() && msg.Time().Before(earliestMessageTime) { + logging.ExtractLogger(ctx).Info().Time("earliest", earliestMessageTime).Msg("Saw a message before the specified earliest time; exiting") + return + } + + err := handleHistoryMessage(ctx, dbConn, &msg, true) + if err != nil { + errLog := logging.ExtractLogger(ctx).Error() + if errors.Is(err, errNotEnoughInfo) { + errLog = logging.ExtractLogger(ctx).Warn() + } + errLog.Err(err).Msg("failed to process Discord message") + } + + before = msg.ID + } + } +} + +func handleHistoryMessage(ctx context.Context, dbConn *pgxpool.Pool, msg *Message, createSnippets bool) error { + var tx pgx.Tx + for { + var err error + tx, err = dbConn.Begin(ctx) + if err != nil { + logging.ExtractLogger(ctx).Warn().Err(err).Msg("failed to start transaction for message") + time.Sleep(1 * time.Second) + continue + } + break + } + + newMsg, err := saveMessageAndContents(ctx, tx, msg) + if err != nil { + return err + } + if createSnippets { + if doSnippet, err := allowedToCreateMessageSnippet(ctx, tx, newMsg.UserID); doSnippet && err == nil { + _, err := createMessageSnippet(ctx, tx, msg) + if err != nil { + return err + } + } else if err != nil { + return err + } + } + + err = tx.Commit(ctx) + if err != nil { + return err + } + + return nil +} diff --git a/src/discord/message_test.go b/src/discord/message_test.go new file mode 100644 index 0000000..abea1c6 --- /dev/null +++ b/src/discord/message_test.go @@ -0,0 +1,16 @@ +package discord + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetMessage(t *testing.T) { + // t.Skip("this test is only for debugging") + + msg, err := GetChannelMessage(context.Background(), "404399251276169217", "764575065772916790") + assert.Nil(t, err) + t.Logf("%+v", msg) +} diff --git a/src/discord/payloads.go b/src/discord/payloads.go index 1256cd3..4b6e3c5 100644 --- a/src/discord/payloads.go +++ b/src/discord/payloads.go @@ -252,9 +252,16 @@ func (m *Message) Time() time.Time { return t } +func (m *Message) ShortString() string { + return fmt.Sprintf("%s / %s: \"%s\" (%d attachments, %d embeds)", m.Timestamp, m.Author.Username, m.Content, len(m.Attachments), len(m.Embeds)) +} + func (m *Message) OriginalHasFields(fields ...string) bool { if m.originalMap == nil { - return false + // If we don't know, we assume the fields are there. + // Usually this is because it came from their API, where we + // always have all fields. + return true } for _, field := range fields { diff --git a/src/discord/rest.go b/src/discord/rest.go index 8387e11..db91554 100644 --- a/src/discord/rest.go +++ b/src/discord/rest.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "strconv" "strings" "git.handmade.network/hmn/hmn/src/config" @@ -411,6 +412,93 @@ func RemoveGuildMemberRole(ctx context.Context, userID, roleID string) error { return nil } +func GetChannelMessage(ctx context.Context, channelID, messageID string) (*Message, error) { + const name = "Get Channel Message" + + path := fmt.Sprintf("/channels/%s/messages/%s", channelID, messageID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + return makeRequest(ctx, http.MethodGet, path, nil) + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode == http.StatusNotFound { + return nil, NotFound + } else if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var msg Message + err = json.Unmarshal(bodyBytes, &msg) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord message") + } + + return &msg, nil +} + +type GetChannelMessagesInput struct { + Around string + Before string + After string + Limit int +} + +func GetChannelMessages(ctx context.Context, channelID string, in GetChannelMessagesInput) ([]Message, error) { + const name = "Get Channel Messages" + + path := fmt.Sprintf("/channels/%s/messages", channelID) + res, err := doWithRateLimiting(ctx, name, func(ctx context.Context) *http.Request { + req := makeRequest(ctx, http.MethodGet, path, nil) + q := req.URL.Query() + if in.Around != "" { + q.Add("around", in.Around) + } + if in.Before != "" { + q.Add("before", in.Before) + } + if in.After != "" { + q.Add("after", in.After) + } + if in.Limit != 0 { + q.Add("limit", strconv.Itoa(in.Limit)) + } + req.URL.RawQuery = q.Encode() + + return req + }) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + logErrorResponse(ctx, name, res, "") + return nil, oops.New(nil, "received error from Discord") + } + + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + panic(err) + } + + var msgs []Message + err = json.Unmarshal(bodyBytes, &msgs) + if err != nil { + return nil, oops.New(err, "failed to unmarshal Discord message") + } + + return msgs, nil +} + func logErrorResponse(ctx context.Context, name string, res *http.Response, msg string) { dump, err := httputil.DumpResponse(res, true) if err != nil { diff --git a/src/discord/showcase.go b/src/discord/showcase.go index 3f1f68a..df4ad71 100644 --- a/src/discord/showcase.go +++ b/src/discord/showcase.go @@ -12,13 +12,13 @@ import ( "time" "git.handmade.network/hmn/hmn/src/assets" + "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" "git.handmade.network/hmn/hmn/src/parsing" "github.com/google/uuid" - "github.com/jackc/pgx/v4" ) var reDiscordMessageLink = regexp.MustCompile(`https?://.+?(\s|$)`) @@ -47,7 +47,7 @@ func (bot *botInstance) processShowcaseMsg(ctx context.Context, msg *Message) er defer tx.Rollback(ctx) // save the message, maybe save its contents, and maybe make a snippet too - newMsg, err := bot.saveMessageAndContents(ctx, tx, msg) + newMsg, err := saveMessageAndContents(ctx, tx, msg) if errors.Is(err, errNotEnoughInfo) { logging.ExtractLogger(ctx).Warn(). Interface("msg", msg). @@ -56,8 +56,8 @@ func (bot *botInstance) processShowcaseMsg(ctx context.Context, msg *Message) er } else if err != nil { return err } - if doSnippet, err := bot.allowedToCreateMessageSnippet(ctx, tx, newMsg.UserID); doSnippet && err == nil { - _, err := bot.createMessageSnippet(ctx, tx, msg) + if doSnippet, err := allowedToCreateMessageSnippet(ctx, tx, newMsg.UserID); doSnippet && err == nil { + _, err := createMessageSnippet(ctx, tx, msg) if err != nil { return oops.New(err, "failed to create snippet in gateway") } @@ -120,9 +120,9 @@ the database. This does not create snippets or do anything besides save the message itself. */ -func (bot *botInstance) saveMessage( +func saveMessage( ctx context.Context, - tx pgx.Tx, + tx db.ConnOrTx, msg *Message, ) (*models.DiscordMessage, error) { iDiscordMessage, err := db.QueryOne(ctx, tx, models.DiscordMessage{}, @@ -138,6 +138,16 @@ func (bot *botInstance) saveMessage( return nil, errNotEnoughInfo } + guildID := msg.GuildID + if guildID == nil { + /* + This is weird, but it can happen when we fetch messages from + history instead of receiving it from the gateway. In this case + we just assume it's from the HMN server. + */ + guildID = &config.Config.Discord.GuildID + } + _, err = tx.Exec(ctx, ` INSERT INTO handmade_discordmessage (id, channel_id, guild_id, url, user_id, sent_at, snippet_created) @@ -145,7 +155,7 @@ func (bot *botInstance) saveMessage( `, msg.ID, msg.ChannelID, - *msg.GuildID, + *guildID, msg.JumpURL(), msg.Author.ID, msg.Time(), @@ -184,12 +194,12 @@ snippets. Idempotent; can be called any time whether the message exists or not. */ -func (bot *botInstance) saveMessageAndContents( +func saveMessageAndContents( ctx context.Context, - tx pgx.Tx, + tx db.ConnOrTx, msg *Message, ) (*models.DiscordMessage, error) { - newMsg, err := bot.saveMessage(ctx, tx, msg) + newMsg, err := saveMessage(ctx, tx, msg) if err != nil { return nil, err } @@ -231,7 +241,7 @@ func (bot *botInstance) saveMessageAndContents( // Save attachments if msg.OriginalHasFields("attachments") { for _, attachment := range msg.Attachments { - _, err := bot.saveAttachment(ctx, tx, &attachment, discordUser.HMNUserId, msg.ID) + _, err := saveAttachment(ctx, tx, &attachment, discordUser.HMNUserId, msg.ID) if err != nil { return nil, oops.New(err, "failed to save attachment") } @@ -254,7 +264,7 @@ func (bot *botInstance) saveMessageAndContents( if numSavedEmbeds == 0 { // No embeds yet, so save new ones for _, embed := range msg.Embeds { - _, err := bot.saveEmbed(ctx, tx, &embed, discordUser.HMNUserId, msg.ID) + _, err := saveEmbed(ctx, tx, &embed, discordUser.HMNUserId, msg.ID) if err != nil { return nil, oops.New(err, "failed to save embed") } @@ -310,9 +320,9 @@ func downloadDiscordResource(ctx context.Context, url string) ([]byte, string, e Saves a Discord attachment as an HMN asset. Idempotent; will not create an attachment that already exists */ -func (bot *botInstance) saveAttachment( +func saveAttachment( ctx context.Context, - tx pgx.Tx, + tx db.ConnOrTx, attachment *Attachment, hmnUserID int, discordMessageID string, @@ -394,9 +404,9 @@ func (bot *botInstance) saveAttachment( return iDiscordAttachment.(*models.DiscordMessageAttachment), nil } -func (bot *botInstance) saveEmbed( +func saveEmbed( ctx context.Context, - tx pgx.Tx, + tx db.ConnOrTx, embed *Embed, hmnUserID int, discordMessageID string, @@ -497,8 +507,8 @@ func (bot *botInstance) saveEmbed( return iDiscordEmbed.(*models.DiscordMessageEmbed), nil } -func (bot *botInstance) allowedToCreateMessageSnippet(ctx context.Context, tx pgx.Tx, discordUserId string) (bool, error) { - canSave, err := db.QueryBool(ctx, bot.dbConn, +func allowedToCreateMessageSnippet(ctx context.Context, tx db.ConnOrTx, discordUserId string) (bool, error) { + canSave, err := db.QueryBool(ctx, tx, ` SELECT u.discord_save_showcase FROM @@ -518,7 +528,7 @@ func (bot *botInstance) allowedToCreateMessageSnippet(ctx context.Context, tx pg return canSave, nil } -func (bot *botInstance) createMessageSnippet(ctx context.Context, tx pgx.Tx, msg *Message) (*models.Snippet, error) { +func createMessageSnippet(ctx context.Context, tx db.ConnOrTx, msg *Message) (*models.Snippet, error) { // Check for existing snippet, maybe return it type existingSnippetResult struct { Message models.DiscordMessage `db:"msg"` @@ -548,7 +558,7 @@ func (bot *botInstance) createMessageSnippet(ctx context.Context, tx pgx.Tx, msg // A snippet already exists - maybe update its content, then return it if msg.OriginalHasFields("content") && !existing.Snippet.EditedOnWebsite { contentMarkdown := existing.MessageContent.LastContent - contentHTML := parsing.ParseMarkdown(contentMarkdown, parsing.RealMarkdown) + contentHTML := parsing.ParseMarkdown(contentMarkdown, parsing.DiscordMarkdown) _, err := tx.Exec(ctx, ` @@ -580,14 +590,14 @@ func (bot *botInstance) createMessageSnippet(ctx context.Context, tx pgx.Tx, msg } // Get an asset ID or URL to make a snippet from - assetId, url, err := bot.getSnippetAssetOrUrl(ctx, tx, &existing.Message) - if assetId == nil && url == "" { + assetId, url, err := getSnippetAssetOrUrl(ctx, tx, &existing.Message) + if assetId == nil && url == nil { // Nothing to make a snippet from! return nil, nil } contentMarkdown := existing.MessageContent.LastContent - contentHTML := parsing.ParseMarkdown(contentMarkdown, parsing.RealMarkdown) + contentHTML := parsing.ParseMarkdown(contentMarkdown, parsing.DiscordMarkdown) // TODO(db): Insert isnippet, err := db.QueryOne(ctx, tx, models.Snippet{}, @@ -596,7 +606,7 @@ func (bot *botInstance) createMessageSnippet(ctx context.Context, tx pgx.Tx, msg VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING $columns `, - nil, + url, existing.Message.SentAt, contentMarkdown, contentHTML, @@ -626,7 +636,7 @@ func (bot *botInstance) createMessageSnippet(ctx context.Context, tx pgx.Tx, msg // do we actually want to reuse those, or should we keep them separate? var RESnippetableUrl = regexp.MustCompile(`^https?://(youtu\.be|(www\.)?youtube\.com/watch)`) -func (bot *botInstance) getSnippetAssetOrUrl(ctx context.Context, tx pgx.Tx, msg *models.DiscordMessage) (*uuid.UUID, string, error) { +func getSnippetAssetOrUrl(ctx context.Context, tx db.ConnOrTx, msg *models.DiscordMessage) (*uuid.UUID, *string, error) { // Check attachments itAttachments, err := db.Query(ctx, tx, models.DiscordMessageAttachment{}, ` @@ -637,12 +647,12 @@ func (bot *botInstance) getSnippetAssetOrUrl(ctx context.Context, tx pgx.Tx, msg msg.ID, ) if err != nil { - return nil, "", oops.New(err, "failed to fetch message attachments") + return nil, nil, oops.New(err, "failed to fetch message attachments") } attachments := itAttachments.ToSlice() for _, iattachment := range attachments { attachment := iattachment.(*models.DiscordMessageAttachment) - return &attachment.AssetID, "", nil + return &attachment.AssetID, nil, nil } // Check embeds @@ -655,23 +665,23 @@ func (bot *botInstance) getSnippetAssetOrUrl(ctx context.Context, tx pgx.Tx, msg msg.ID, ) if err != nil { - return nil, "", oops.New(err, "failed to fetch discord embeds") + return nil, nil, oops.New(err, "failed to fetch discord embeds") } embeds := itEmbeds.ToSlice() for _, iembed := range embeds { embed := iembed.(*models.DiscordMessageEmbed) if embed.VideoID != nil { - return embed.VideoID, "", nil + return embed.VideoID, nil, nil } else if embed.ImageID != nil { - return embed.ImageID, "", nil + return embed.ImageID, nil, nil } else if embed.URL != nil { if RESnippetableUrl.MatchString(*embed.URL) { - return nil, *embed.URL, nil + return nil, embed.URL, nil } } } - return nil, "", nil + return nil, nil, nil } func messageHasLinks(content string) bool { diff --git a/src/discord/todo.txt b/src/discord/todo.txt deleted file mode 100644 index 41c3d9b..0000000 --- a/src/discord/todo.txt +++ /dev/null @@ -1,48 +0,0 @@ -the goal: port the old discord showcase bot - -what it does: save #project-showcase posts to your HMN user profile if you have your account linked - -stuff we need to worry about: -- old posts from before you linked your account -- posts that come in while the bot is down -- what to do with posts if you unlink your account -- what to do with posts if you re-link your account -✔ - what to do if you edit the original discord message -- what to do if you delete the original discord message -✔ - the user's preferences re: saving content - - we don't want to save content without the user's consent, especially since it may persist after they disable the integration -- manually adding content for various reasons - - maybe a bug prevented something from saving - - ryan used to post everything in #projects for some reason - - -✔ real-time stuff: -✔ - on new showcase message - - always save the lightweight record - - if we have permission, create a snippet -✔ - on edit - - re-save the lightweight record and content as if it was new - - create snippet, unconditionally???? (bug??) - - update snippet contents if the edit makes sense -✔ - on delete - - delete snippet if the user so desires - - delete the message records -✔ - on bulk delete - - same stuff - -background stuff: -- watch mode - - every five seconds - - fetch all HMN users with Discord accounts - - check if we have message records without content - - if so, run a full scrape (no snippets) - - every hour - - run a full scrape, creating snippets -- scrape behavior - - look at every message ever in the channel - - do exactly what the real-time bot does on new messages (although maybe don't do snippets depending on context) - - -what the heck do we do with discord's markdown -- when we save message contents, we should save both the raw discord markdown and a version with their custom stuff replaced. We do _not_ (yet) need a full markdown parse with HTML tags and stuff. (That arguably doesn't make sense for the handmade_discordmessagecontent record anyway.) -- when we create a snippet, we should store both markdown that makes sense to a user and the rendered version of that HTML. THIS MEANS: The markdown we save is the "clean" version of the Discord markdown. diff --git a/src/main.go b/src/main.go index 96bcdab..2377cbc 100644 --- a/src/main.go +++ b/src/main.go @@ -4,6 +4,7 @@ import ( _ "git.handmade.network/hmn/hmn/src/admintools" _ "git.handmade.network/hmn/hmn/src/assets" _ "git.handmade.network/hmn/hmn/src/buildscss" + _ "git.handmade.network/hmn/hmn/src/discord/cmd" _ "git.handmade.network/hmn/hmn/src/initimage" _ "git.handmade.network/hmn/hmn/src/migration" "git.handmade.network/hmn/hmn/src/website" diff --git a/src/parsing/parsing.go b/src/parsing/parsing.go index fbd70da..68bf27b 100644 --- a/src/parsing/parsing.go +++ b/src/parsing/parsing.go @@ -10,21 +10,38 @@ import ( ) // Used for rendering real-time previews of post content. -var PreviewMarkdown = goldmark.New( - goldmark.WithExtensions(makeGoldmarkExtensions(true)...), +var ForumPreviewMarkdown = goldmark.New( + goldmark.WithExtensions(makeGoldmarkExtensions(MarkdownOptions{ + Previews: true, + Embeds: true, + })...), ) // Used for generating the final HTML for a post. -var RealMarkdown = goldmark.New( - goldmark.WithExtensions(makeGoldmarkExtensions(false)...), +var ForumRealMarkdown = goldmark.New( + goldmark.WithExtensions(makeGoldmarkExtensions(MarkdownOptions{ + Previews: false, + Embeds: true, + })...), ) // Used for generating plain-text previews of posts. var PlaintextMarkdown = goldmark.New( - goldmark.WithExtensions(makeGoldmarkExtensions(false)...), + goldmark.WithExtensions(makeGoldmarkExtensions(MarkdownOptions{ + Previews: false, + Embeds: true, + })...), goldmark.WithRenderer(plaintextRenderer{}), ) +// Used for processing Discord messages +var DiscordMarkdown = goldmark.New( + goldmark.WithExtensions(makeGoldmarkExtensions(MarkdownOptions{ + Previews: false, + Embeds: false, + })...), +) + func ParseMarkdown(source string, md goldmark.Markdown) string { var buf bytes.Buffer if err := md.Convert([]byte(source), &buf); err != nil { @@ -34,19 +51,35 @@ func ParseMarkdown(source string, md goldmark.Markdown) string { return buf.String() } -func makeGoldmarkExtensions(preview bool) []goldmark.Extender { - return []goldmark.Extender{ +type MarkdownOptions struct { + Previews bool + Embeds bool +} + +func makeGoldmarkExtensions(opts MarkdownOptions) []goldmark.Extender { + var extenders []goldmark.Extender + extenders = append(extenders, extension.GFM, highlightExtension, SpoilerExtension{}, - EmbedExtension{ - Preview: preview, - }, + ) + + if opts.Embeds { + extenders = append(extenders, + EmbedExtension{ + Preview: opts.Previews, + }, + ) + } + + extenders = append(extenders, MathjaxExtension{}, BBCodeExtension{ - Preview: preview, + Preview: opts.Previews, }, - } + ) + + return extenders } var highlightExtension = highlighting.NewHighlighting( diff --git a/src/parsing/parsing_test.go b/src/parsing/parsing_test.go index 1e68bdf..42ad807 100644 --- a/src/parsing/parsing_test.go +++ b/src/parsing/parsing_test.go @@ -10,14 +10,14 @@ import ( func TestMarkdown(t *testing.T) { t.Run("fenced code blocks", func(t *testing.T) { t.Run("multiple lines", func(t *testing.T) { - html := ParseMarkdown("```\nmultiple lines\n\tof code\n```", RealMarkdown) + html := ParseMarkdown("```\nmultiple lines\n\tof code\n```", ForumRealMarkdown) t.Log(html) assert.Equal(t, 1, strings.Count(html, "