From ce582df6108eacd076eee7dc006e87b40a314e97 Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Wed, 28 Apr 2021 22:07:14 -0500 Subject: [PATCH] Redo the request handling system again --- src/website/feed.go | 4 +- src/website/requesthandling.go | 191 ++++++++++++++++++--------------- src/website/routes.go | 132 +++++++++++------------ 3 files changed, 171 insertions(+), 156 deletions(-) diff --git a/src/website/feed.go b/src/website/feed.go index 4c7d7f8..b41629a 100644 --- a/src/website/feed.go +++ b/src/website/feed.go @@ -44,8 +44,8 @@ func Feed(c *RequestContext) ResponseData { numPages := int(math.Ceil(float64(numPosts) / 30)) page := 1 - pageString := c.PathParams.ByName("page") - if pageString != "" { + pageString, hasPage := c.PathParams["page"] + if hasPage && pageString != "" { if pageParsed, err := strconv.Atoi(pageString); err == nil { page = pageParsed } else { diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index 1daacdf..bf3463a 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -3,11 +3,13 @@ package website import ( "bytes" "context" + "fmt" "html" "io" "net/http" "net/url" "path" + "regexp" "strings" "git.handmade.network/hmn/hmn/src/logging" @@ -15,51 +17,114 @@ import ( "git.handmade.network/hmn/hmn/src/perf" "git.handmade.network/hmn/hmn/src/templates" "github.com/jackc/pgx/v4/pgxpool" - "github.com/julienschmidt/httprouter" "github.com/rs/zerolog" ) -// The typical handler. Handles a request and returns data about the response. -type HMNHandler func(c *RequestContext) ResponseData +type Router struct { + Routes []Route +} -// A special handler that runs before the primary handler. Intended to set -// information on the context for later handlers, or to give the request a -// means to bail out early if preconditions are not met (like auth). If `ok` -// is false, the request will immediately bail out, no further handlers will -// be run, and it will respond with the provided response data. -// -// The response data from this function will still be fed through any after -// handlers, to ensure that errors will get logged and whatnot. -type HMNBeforeHandler func(c *RequestContext) (ok bool, res ResponseData) +type Route struct { + Method string + Regex *regexp.Regexp + Handler Handler +} -// A special handler that runs after the primary handler and can modify the -// response information. Intended for error logging, error pages, -// cleanup, etc. -type HMNAfterHandler func(c *RequestContext, res ResponseData) ResponseData +type RouteBuilder struct { + Router *Router + Middleware Middleware +} -func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - c := NewRequestContext(rw, req, nil, "") - doRequest(rw, c, h) +type Handler func(c *RequestContext) ResponseData + +func WrapStdHandler(h http.Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + h.ServeHTTP(&res, c.Req) + return res + } +} + +type Middleware func(h Handler) Handler + +func (rb *RouteBuilder) Handle(method string, regexStr string, h Handler) { + h = rb.Middleware(h) + rb.Router.Routes = append(rb.Router.Routes, Route{ + Method: method, + Regex: regexp.MustCompile(regexStr), + Handler: h, + }) +} + +func (rb *RouteBuilder) AnyMethod(regexStr string, h Handler) { + rb.Handle("", regexStr, h) +} + +func (rb *RouteBuilder) GET(regexStr string, h Handler) { + rb.Handle(http.MethodGet, regexStr, h) +} + +func (rb *RouteBuilder) POST(regexStr string, h Handler) { + rb.Handle(http.MethodGet, regexStr, h) +} + +func (rb *RouteBuilder) StdHandler(regexStr string, h http.Handler) { + rb.Router.Routes = append(rb.Router.Routes, Route{ + Method: "", + Regex: regexp.MustCompile(regexStr), + Handler: WrapStdHandler(h), + }) +} + +func (r *Router) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + path := req.URL.Path + for _, route := range r.Routes { + if route.Method != "" && req.Method != route.Method { + continue + } + + match := route.Regex.FindStringSubmatch(path) + if match == nil { + continue + } + + c := &RequestContext{ + Route: "", // TODO + Logger: logging.GlobalLogger(), + Req: req, + } + + if len(match) > 0 { + params := map[string]string{} + subexpNames := route.Regex.SubexpNames() + for i, paramValue := range match { + paramName := subexpNames[i] + if paramName == "" { + continue + } + params[paramName] = paramValue + } + c.PathParams = params + } + + doRequest(rw, c, route.Handler) + + return + } + + panic(fmt.Sprintf("Path '%s' did not match any routes! Make sure to register a wildcard route to act as a 404.", path)) } type RequestContext struct { + Route string Logger *zerolog.Logger Req *http.Request - PathParams httprouter.Params + PathParams map[string]string Conn *pgxpool.Pool - Perf *perf.RequestPerf CurrentProject *models.Project CurrentUser *models.User -} -func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params, route string) *RequestContext { - return &RequestContext{ - Logger: logging.GlobalLogger(), - Req: req, - PathParams: pathParams, - Perf: perf.MakeNewRequestPerf(route), - } + Perf *perf.RequestPerf } func (c *RequestContext) Context() context.Context { @@ -122,9 +187,9 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData { destUrl, _ := url.Parse(dest) dest = destUrl.String() - res.Headers().Set("Location", dest) + res.Header().Set("Location", dest) if c.Req.Method == "GET" || c.Req.Method == "HEAD" { - res.Headers().Set("Content-Type", "text/html; charset=utf-8") + res.Header().Set("Content-Type", "text/html; charset=utf-8") } res.StatusCode = code @@ -144,7 +209,9 @@ type ResponseData struct { header http.Header } -func (rd *ResponseData) Headers() http.Header { +var _ http.ResponseWriter = &ResponseData{} + +func (rd *ResponseData) Header() http.Header { if rd.header == nil { rd.header = make(http.Header) } @@ -160,8 +227,12 @@ func (rd *ResponseData) Write(p []byte) (n int, err error) { return rd.Body.Write(p) } +func (rd *ResponseData) WriteHeader(status int) { + rd.StatusCode = status +} + func (rd *ResponseData) SetCookie(cookie *http.Cookie) { - rd.Headers().Add("Set-Cookie", cookie.String()) + rd.Header().Add("Set-Cookie", cookie.String()) } func (rd *ResponseData) WriteTemplate(name string, data interface{}, rp *perf.RequestPerf) error { @@ -172,56 +243,6 @@ func (rd *ResponseData) WriteTemplate(name string, data interface{}, rp *perf.Re return templates.Templates[name].Execute(rd, data) } -type RouteBuilder struct { - Router *httprouter.Router - BeforeHandlers []HMNBeforeHandler - AfterHandlers []HMNAfterHandler -} - -func (b RouteBuilder) ChainHandlers(h HMNHandler) HMNHandler { - return func(c *RequestContext) ResponseData { - beforeOk := true - var res ResponseData - for _, before := range b.BeforeHandlers { - if ok, errorRes := before(c); !ok { - beforeOk = false - res = errorRes - } - } - - if beforeOk { - res = h(c) - } - - for _, after := range b.AfterHandlers { - res = after(c, res) - } - return res - } -} - -func (b *RouteBuilder) Handle(method, route string, handler HMNHandler) { - h := b.ChainHandlers(handler) - b.Router.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) { - c := NewRequestContext(rw, req, p, route) - doRequest(rw, c, h) - }) -} - -func (b *RouteBuilder) GET(route string, handler HMNHandler) { - b.Handle(http.MethodGet, route, handler) -} - -func (b *RouteBuilder) POST(route string, handler HMNHandler) { - b.Handle(http.MethodPost, route, handler) -} - -// TODO: More methods - -func (b *RouteBuilder) ServeFiles(path string, root http.FileSystem) { - b.Router.ServeFiles(path, root) -} - func ErrorResponse(status int, errs ...error) ResponseData { return ResponseData{ StatusCode: status, @@ -229,7 +250,7 @@ func ErrorResponse(status int, errs ...error) ResponseData { } } -func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) { +func doRequest(rw http.ResponseWriter, c *RequestContext, h Handler) { defer func() { /* This panic recovery is the last resort. If you want to render @@ -247,7 +268,7 @@ func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) { res.StatusCode = http.StatusOK } - for name, vals := range res.Headers() { + for name, vals := range res.Header() { for _, val := range vals { rw.Header().Add(name, val) } diff --git a/src/website/routes.go b/src/website/routes.go index 7bf87ef..e684aec 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -17,51 +17,52 @@ import ( "git.handmade.network/hmn/hmn/src/perf" "git.handmade.network/hmn/hmn/src/templates" "github.com/jackc/pgx/v4/pgxpool" - "github.com/julienschmidt/httprouter" ) func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) http.Handler { - router := httprouter.New() + router := &Router{} routes := RouteBuilder{ Router: router, - BeforeHandlers: []HMNBeforeHandler{ - func(c *RequestContext) (bool, ResponseData) { + Middleware: func(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { c.Conn = conn - return true, ResponseData{} - }, - // TODO: Add a timeout? We don't want routes hanging forever - }, - AfterHandlers: []HMNAfterHandler{ - ErrorLoggingHandler, - func(c *RequestContext, res ResponseData) ResponseData { - // Do perf printout - c.Perf.EndRequest() - log := logging.Info() - blockStack := make([]time.Time, 0) - for _, block := range c.Perf.Blocks { - for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) { - blockStack = blockStack[:len(blockStack)-1] - } - log.Str(fmt.Sprintf("At %9.2fms", c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs())) - blockStack = append(blockStack, block.End) - } - log.Msg(fmt.Sprintf("Served %s in %.4fms", c.Perf.Route, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000)) - perfCollector.SubmitRun(c.Perf) - return res - }, + + logPerf := TrackRequestPerf(c, perfCollector) + defer logPerf() + + defer LogContextErrors(c, res) + + return h(c) + } }, } - routes.POST("/login", Login) - routes.GET("/logout", Logout) - routes.ServeFiles("/public/*filepath", http.Dir("public")) - mainRoutes := routes - mainRoutes.BeforeHandlers = append(mainRoutes.BeforeHandlers, - CommonWebsiteDataWrapper, + mainRoutes.Middleware = func(h Handler) Handler { + return func(c *RequestContext) (res ResponseData) { + c.Conn = conn + + logPerf := TrackRequestPerf(c, perfCollector) + defer logPerf() + + defer LogContextErrors(c, res) + + ok, errRes := LoadCommonWebsiteData(c) + if !ok { + return errRes + } + + return h(c) + } + } + + routes.POST("^/login$", Login) + routes.GET("^/logout$", Logout) + routes.StdHandler("^/public/.*$", + http.StripPrefix("/public/", http.FileServer(http.Dir("public"))), ) - mainRoutes.GET("/", func(c *RequestContext) ResponseData { + mainRoutes.GET("^/$", func(c *RequestContext) ResponseData { if c.CurrentProject.IsHMN() { return Index(c) } else { @@ -69,34 +70,11 @@ func NewWebsiteRoutes(conn *pgxpool.Pool, perfCollector *perf.PerfCollector) htt panic("route not implemented") } }) - mainRoutes.GET("/feed", Feed) - mainRoutes.GET("/feed/:page", Feed) + mainRoutes.GET(`^/feed(/(?P.+)?)?$`, Feed) - mainRoutes.GET("/assets/project.css", ProjectCSS) + mainRoutes.GET("^/assets/project.css$", ProjectCSS) - router.NotFound = mainRoutes.ChainHandlers(FourOhFour) - - adminRoutes := routes - adminRoutes.BeforeHandlers = append(adminRoutes.BeforeHandlers, - func(c *RequestContext) (ok bool, res ResponseData) { - return false, ResponseData{ - StatusCode: http.StatusUnauthorized, - Body: bytes.NewBufferString("No one is allowed!\n"), - } - }, - ) - adminRoutes.AfterHandlers = append(adminRoutes.AfterHandlers, - func(c *RequestContext, res ResponseData) ResponseData { - res.Body.WriteString("Now go away. Sincerely, the after handler.\n") - return res - }, - ) - - adminRoutes.GET("/admin", func(c *RequestContext) ResponseData { - return ResponseData{ - Body: bytes.NewBufferString("Here are all the secrets.\n"), - } - }) + mainRoutes.AnyMethod("", FourOhFour) return router } @@ -156,7 +134,7 @@ func ProjectCSS(c *RequestContext) ResponseData { } var res ResponseData - res.Headers().Add("Content-Type", "text/css") + res.Header().Add("Content-Type", "text/css") err := res.WriteTemplate("project.css", templateData, c.Perf) if err != nil { return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS")) @@ -172,15 +150,7 @@ func FourOhFour(c *RequestContext) ResponseData { } } -func ErrorLoggingHandler(c *RequestContext, res ResponseData) ResponseData { - for _, err := range res.Errors { - c.Logger.Error().Err(err).Msg("error occurred during request") - } - - return res -} - -func CommonWebsiteDataWrapper(c *RequestContext) (bool, ResponseData) { +func LoadCommonWebsiteData(c *RequestContext) (bool, ResponseData) { c.Perf.StartBlock("MIDDLEWARE", "Load common website data") defer c.Perf.EndBlock() // get project @@ -240,3 +210,27 @@ func getCurrentUser(c *RequestContext, sessionId string) (*models.User, error) { return user, nil } + +func TrackRequestPerf(c *RequestContext, perfCollector *perf.PerfCollector) (after func()) { + c.Perf = perf.MakeNewRequestPerf(c.Route) + return func() { + c.Perf.EndRequest() + log := logging.Info() + blockStack := make([]time.Time, 0) + for _, block := range c.Perf.Blocks { + for len(blockStack) > 0 && block.End.After(blockStack[len(blockStack)-1]) { + blockStack = blockStack[:len(blockStack)-1] + } + log.Str(fmt.Sprintf("At %9.2fms", c.Perf.MsFromStart(&block)), fmt.Sprintf("%*.s[%s] %s (%.4fms)", len(blockStack)*2, "", block.Category, block.Description, block.DurationMs())) + blockStack = append(blockStack, block.End) + } + log.Msg(fmt.Sprintf("Served %s in %.4fms", c.Perf.Route, float64(c.Perf.End.Sub(c.Perf.Start).Nanoseconds())/1000/1000)) + perfCollector.SubmitRun(c.Perf) + } +} + +func LogContextErrors(c *RequestContext, res ResponseData) { + for _, err := range res.Errors { + c.Logger.Error().Err(err).Msg("error occurred during request") + } +}