diff --git a/src/auth/session.go b/src/auth/session.go index 4a9c9e4..ae4adca 100644 --- a/src/auth/session.go +++ b/src/auth/session.go @@ -11,6 +11,7 @@ import ( "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" + "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" "github.com/jackc/pgx/v4/pgxpool" @@ -94,3 +95,39 @@ var DeleteSessionCookie = &http.Cookie{ Domain: config.Config.Auth.CookieDomain, MaxAge: -1, } + +func DeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) (int64, error) { + tag, err := conn.Exec(ctx, "DELETE FROM sessions WHERE expires_at <= CURRENT_TIMESTAMP") + if err != nil { + return 0, oops.New(err, "failed to delete expired sessions") + } + + return tag.RowsAffected(), nil +} + +func PeriodicallyDeleteExpiredSessions(ctx context.Context, conn *pgxpool.Pool) <-chan struct{} { + done := make(chan struct{}) + go func() { + defer close(done) + + t := time.NewTicker(1 * time.Minute) + for { + select { + case <-t.C: + n, err := DeleteExpiredSessions(ctx, conn) + if err == nil { + if n > 0 { + logging.Info().Int64("num deleted sessions", n).Msg("Deleted expired sessions") + } else { + logging.Debug().Msg("no sessions to delete") + } + } else { + logging.Error().Err(err).Msg("Failed to delete expired sessions") + } + case <-ctx.Done(): + return + } + } + }() + return done +} diff --git a/src/website/website.go b/src/website/website.go index ea041bd..8727289 100644 --- a/src/website/website.go +++ b/src/website/website.go @@ -8,6 +8,7 @@ import ( "os/signal" "time" + "git.handmade.network/hmn/hmn/src/auth" "git.handmade.network/hmn/hmn/src/config" "git.handmade.network/hmn/hmn/src/db" "git.handmade.network/hmn/hmn/src/logging" @@ -31,6 +32,11 @@ var WebsiteCommand = &cobra.Command{ Handler: NewWebsiteRoutes(conn), } + backgroundJobContext, cancelBackgroundJobs := context.WithCancel(context.Background()) + backgroundJobsDone := zipJobs( + auth.PeriodicallyDeleteExpiredSessions(backgroundJobContext, conn), + ) + signals := make(chan os.Signal, 1) signal.Notify(signals, os.Interrupt) go func() { @@ -39,6 +45,7 @@ var WebsiteCommand = &cobra.Command{ timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() server.Shutdown(timeout) + cancelBackgroundJobs() <-signals logging.Warn().Msg("Forcibly killed the website") @@ -50,5 +57,18 @@ var WebsiteCommand = &cobra.Command{ if !errors.Is(serverErr, http.ErrServerClosed) { logging.Error().Err(serverErr).Msg("Server shut down unexpectedly") } + + <-backgroundJobsDone }, } + +func zipJobs(cs ...<-chan struct{}) <-chan struct{} { + out := make(chan struct{}) + go func() { + for _, c := range cs { + <-c + } + close(out) + }() + return out +}