hmn/src/website/middlewares.go

144 lines
3.9 KiB
Go
Raw Normal View History

package website
import (
	"crypto/subtle"
	"fmt"
	"math/rand"
	"net/http"
	"time"

	"git.handmade.network/hmn/hmn/src/auth"
	"git.handmade.network/hmn/hmn/src/hmnurl"
	"git.handmade.network/hmn/hmn/src/logging"
	"git.handmade.network/hmn/hmn/src/oops"
	"git.handmade.network/hmn/hmn/src/perf"
	"git.handmade.network/hmn/hmn/src/utils"
)
// panicCatcherMiddleware wraps h so that a panic raised while handling a
// request is recovered and turned into a 500 Internal Server Error response
// (via c.ErrorResponse) instead of propagating.
//
// The recovered value is converted to an error as follows:
//   - an error value is used as-is (the common `panic(err)` case);
//   - a non-nil *error pointing at a non-nil error is dereferenced;
//   - anything else (including a nil *error) becomes a new oops error
//     describing the recovered value.
func panicCatcherMiddleware(h Handler) Handler {
	return func(c *RequestContext) (res ResponseData) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}

			var err error
			switch v := recovered.(type) {
			case error:
				err = v
			case *error:
				// Guard against a nil pointer or a pointer to a nil error;
				// dereferencing blindly would panic again inside recovery.
				if v != nil && *v != nil {
					err = *v
				}
			}
			if err == nil {
				err = oops.New(nil, fmt.Sprintf("Recovered from panic with value: %v", recovered))
			}

			// res is a named result, so assigning it here replaces whatever
			// h would have returned.
			res = c.ErrorResponse(http.StatusInternalServerError, err)
		}()
		return h(c)
	}
}
// trackRequestPerf attaches a fresh perf record (perf.MakeNewRequestPerf) to
// c.Perf and then runs h. When h returns, it ends the record and logs, at info
// level, one field per recorded block followed by a summary message with the
// total request duration.
//
// Blocks are indented by nesting depth. blockStack holds the end times of the
// currently open enclosing blocks; any entry that ended before the current
// block ended is popped before the block is printed.
func trackRequestPerf(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		c.Perf = perf.MakeNewRequestPerf(c.Route, c.Req.Method, c.Req.URL.Path)
		defer func() {
			c.Perf.EndRequest()
			log := logging.Info()
			blockStack := make([]time.Time, 0)
			for i, block := range c.Perf.Blocks {
				for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) {
					blockStack = blockStack[:len(blockStack)-1]
				}
				// Use %4d, not %4.d: an explicit zero precision makes fmt print
				// nothing at all for the value 0, which blanked the first index.
				// %*.s with "" prints len(blockStack)*2 spaces of indentation.
				log.Str(
					fmt.Sprintf("[%4d] At %9.2fms", i, c.Perf.MsFromStart(&block)),
					fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs()),
				)
				blockStack = append(blockStack, block.End)
			}
			log.Msg(fmt.Sprintf("Served [%s] %s in %.4fms", c.Perf.Method, c.Perf.Path, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000))
			// perfCollector.SubmitRun(c.Perf) // TODO(asaf): Implement a use for this
		}()
		return h(c)
	}
}
// needsAuth lets only logged-in users (non-nil c.CurrentUser) through to h.
// Anonymous visitors get a 303 See Other redirect to the login page URL built
// by hmnurl.BuildLoginPage from the current full URL (presumably so login can
// send them back afterwards).
func needsAuth(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		if c.CurrentUser != nil {
			return h(c)
		}
		loginURL := hmnurl.BuildLoginPage(c.FullUrl())
		return c.Redirect(loginURL, http.StatusSeeOther)
	}
}
// adminsOnly lets only staff users (CurrentUser.IsStaff) through to h.
// Everyone else, including anonymous visitors, gets the FourOhFour response
// rather than a permission error.
func adminsOnly(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		if user := c.CurrentUser; user != nil && user.IsStaff {
			return h(c)
		}
		return FourOhFour(c)
	}
}
// educationBetaTestersOnly lets through to h only users for whom
// CanSeeUnpublishedEducationContent reports true. Anonymous visitors and all
// other users get the FourOhFour response.
func educationBetaTestersOnly(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		if user := c.CurrentUser; user != nil && user.CanSeeUnpublishedEducationContent() {
			return h(c)
		}
		return FourOhFour(c)
	}
}
// educationAuthorsOnly lets through to h only users for whom CanAuthorEducation
// reports true. Anonymous visitors and all other users get the FourOhFour
// response.
func educationAuthorsOnly(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		if user := c.CurrentUser; user != nil && user.CanAuthorEducation() {
			return h(c)
		}
		return FourOhFour(c)
	}
}
// csrfMiddleware only runs h when the CSRF token submitted in the request form
// (field auth.CSRFFieldName) matches the token stored on the current session.
// On a mismatch, a missing token, or a missing session, it logs a warning,
// logs the user out, and redirects to "/" with 303 See Other.
//
// CSRF mitigation actions per the OWASP cheat sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
func csrfMiddleware(h Handler) Handler {
	// Bytes of a multipart body kept in memory while parsing (100 MiB).
	const maxFormMemory = 100 * 1024 * 1024

	return func(c *RequestContext) ResponseData {
		// ParseMultipartForm calls ParseForm first, so urlencoded bodies and
		// query strings are parsed too. For non-multipart bodies it then
		// returns http.ErrNotMultipart, which is expected here. If parsing
		// genuinely fails, the token is expected to be absent and the check
		// below rejects the request.
		_ = c.Req.ParseMultipartForm(maxFormMemory)
		csrfToken := c.Req.Form.Get(auth.CSRFFieldName)

		var sessionToken string
		if c.CurrentSession != nil {
			sessionToken = c.CurrentSession.CSRFToken
		}

		// Reject empty tokens outright, so a session without a token cannot be
		// matched by simply omitting the field. Compare in constant time so the
		// token cannot be probed through response timing.
		valid := csrfToken != "" &&
			subtle.ConstantTimeCompare([]byte(csrfToken), []byte(sessionToken)) == 1
		if !valid {
			username := ""
			if c.CurrentUser != nil {
				username = c.CurrentUser.Username
			}
			c.Logger.Warn().Str("userId", username).Msg("user failed CSRF validation - potential attack?")
			res := c.Redirect("/", http.StatusSeeOther)
			logoutUser(c, &res)
			return res
		}
		return h(c)
	}
}
// securityTimerMiddleware runs h and then waits until at least `duration` plus
// a random jitter in [0, duration/10) has elapsed since the request entered
// this middleware, before returning h's response. The wait ends early if
// c.Done() fires first. The jitter bound is clamped to at least 1ns because
// rand.Int63n panics for non-positive arguments.
//
// NOTE(review): presumably used to mask timing differences in sensitive
// handlers (e.g. auth flows) — confirm at call sites.
func securityTimerMiddleware(duration time.Duration, h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		jitter := time.Duration(rand.Int63n(utils.Int64Max(1, int64(duration)/10)))
		timer := time.NewTimer(duration + jitter)
		// Stop the timer when we return early because the request's context
		// finished; otherwise it stays live until it fires.
		defer timer.Stop()

		res := h(c)
		select {
		case <-c.Done():
		case <-timer.C:
		}
		return res
	}
}
// logContextErrors emits one error-level log entry per element of errs on
// c.Logger. Each entry carries a timestamp, has Stack() enabled, records the
// full requested URL under "Requested", and attaches the error itself.
func logContextErrors(c *RequestContext, errs ...error) {
	for idx := range errs {
		entry := c.Logger.Error().Timestamp().Stack()
		entry.Str("Requested", c.FullUrl()).
			Err(errs[idx]).
			Msg("error occurred during request")
	}
}
// logContextErrorsMiddleware runs h and then passes every error collected in
// the response's Errors field to logContextErrors. The response itself is
// returned unchanged.
func logContextErrorsMiddleware(h Handler) Handler {
	return func(c *RequestContext) ResponseData {
		response := h(c)
		collected := response.Errors
		logContextErrors(c, collected...)
		return response
	}
}