Rework request handling

It's a lot simpler now, control flow is easier to work with in handlers,
and HMNHandler now natively implements http.Handler, simplifying our 404
setup by completely removing the need for MakeStdHandler.
This commit is contained in:
Ben Visness 2021-04-05 22:30:11 -05:00
parent a6cdbac4c7
commit 98da461d92
3 changed files with 162 additions and 143 deletions

View File

@ -8,7 +8,6 @@ import (
"git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops" "git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates" "git.handmade.network/hmn/hmn/src/templates"
"github.com/julienschmidt/httprouter"
) )
type LandingTemplateData struct { type LandingTemplateData struct {
@ -29,7 +28,7 @@ type LandingPagePost struct {
HasRead bool HasRead bool
} }
func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
const maxPosts = 5 const maxPosts = 5
const numProjectsToGet = 7 const numProjectsToGet = 7
@ -38,8 +37,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.HMNProjectID, models.HMNProjectID,
) )
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
return
} }
defer iterProjects.Close() defer iterProjects.Close()
@ -132,8 +130,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.CatTypeBlog, models.CatTypeBlog,
) )
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
return
} }
newsThread := newsThreadRow.(*newsThreadQuery) newsThread := newsThreadRow.(*newsThreadQuery)
_ = newsThread // TODO: NO _ = newsThread // TODO: NO
@ -141,8 +138,11 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
baseData := s.getBaseData(c) baseData := s.getBaseData(c)
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more? baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
err = c.WriteTemplate("index.html", s.getBaseData(c)) var res ResponseData
err = res.WriteTemplate("index.html", s.getBaseData(c))
if err != nil { if err != nil {
panic(err) panic(err)
} }
return res
} }

View File

@ -3,7 +3,6 @@ package website
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"html" "html"
"io" "io"
"net/http" "net/http"
@ -35,7 +34,11 @@ func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler {
} }
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) { func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler))) h := r.WrapHandler(handler)
r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
c := NewRequestContext(rw, req, p)
doRequest(rw, c, h)
})
} }
func (r *HMNRouter) GET(route string, handler HMNHandler) { func (r *HMNRouter) GET(route string, handler HMNHandler) {
@ -58,39 +61,29 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter {
return &result return &result
} }
type HMNHandler func(c *RequestContext, p httprouter.Params) type HMNHandler func(c *RequestContext) ResponseData
type HMNHandlerWrapper func(h HMNHandler) HMNHandler type HMNHandlerWrapper func(h HMNHandler) HMNHandler
func MakeStdHandler(h HMNHandler, name string) http.Handler { func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { c := NewRequestContext(rw, req, nil)
handleHmnHandler(name, h)(rw, req, nil) doRequest(rw, c, h)
})
} }
type RequestContext struct { type RequestContext struct {
StatusCode int
Body *bytes.Buffer
Logger *zerolog.Logger Logger *zerolog.Logger
Req *http.Request Req *http.Request
Errors []error PathParams httprouter.Params
rw http.ResponseWriter CurrentProject *models.Project
CurrentUser *models.User
currentProject *models.Project // CurrentMember *models.Member
currentUser *models.User
// currentMember *models.Member
} }
func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext { func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params) *RequestContext {
logger := logging.With().Str("route", route).Logger()
return &RequestContext{ return &RequestContext{
StatusCode: http.StatusOK, Logger: logging.GlobalLogger(),
Body: new(bytes.Buffer),
Logger: &logger,
Req: req, Req: req,
PathParams: pathParams,
rw: rw,
} }
} }
@ -102,14 +95,6 @@ func (c *RequestContext) URL() *url.URL {
return c.Req.URL return c.Req.URL
} }
func (c *RequestContext) Headers() http.Header {
return c.rw.Header()
}
func (c *RequestContext) SetCookie(cookie *http.Cookie) {
c.rw.Header().Add("Set-Cookie", cookie.String())
}
func (c *RequestContext) GetFormValues() (url.Values, error) { func (c *RequestContext) GetFormValues() (url.Values, error) {
err := c.Req.ParseForm() err := c.Req.ParseForm()
if err != nil { if err != nil {
@ -119,9 +104,39 @@ func (c *RequestContext) GetFormValues() (url.Values, error) {
return c.Req.PostForm, nil return c.Req.PostForm, nil
} }
type ResponseData struct {
StatusCode int
Body *bytes.Buffer
Errors []error
header http.Header
}
func (rd *ResponseData) Headers() http.Header {
if rd.header == nil {
rd.header = make(http.Header)
}
return rd.header
}
func (rd *ResponseData) Write(p []byte) (n int, err error) {
if rd.Body == nil {
rd.Body = new(bytes.Buffer)
}
return rd.Body.Write(p)
}
func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
rd.Headers().Add("Set-Cookie", cookie.String())
}
// The logic of this function is copy-pasted from the Go standard library. // The logic of this function is copy-pasted from the Go standard library.
// https://golang.org/pkg/net/http/#Redirect // https://golang.org/pkg/net/http/#Redirect
func (c *RequestContext) Redirect(dest string, code int) { func (c *RequestContext) Redirect(dest string, code int) ResponseData {
var res ResponseData
if u, err := url.Parse(dest); err == nil { if u, err := url.Parse(dest); err == nil {
// If url was relative, make its path absolute by // If url was relative, make its path absolute by
// combining with request path. // combining with request path.
@ -156,46 +171,36 @@ func (c *RequestContext) Redirect(dest string, code int) {
} }
} }
h := c.Headers()
// RFC 7231 notes that a short HTML body is usually included in
// the response because older user agents may not understand 301/307.
// Do it only if the request didn't already have a Content-Type header.
_, hadCT := h["Content-Type"]
// Escape stuff // Escape stuff
destUrl, _ := url.Parse(dest) destUrl, _ := url.Parse(dest)
dest = destUrl.String() dest = destUrl.String()
h.Set("Location", dest) res.Headers().Set("Location", dest)
if !hadCT && (c.Req.Method == "GET" || c.Req.Method == "HEAD") { if c.Req.Method == "GET" || c.Req.Method == "HEAD" {
h.Set("Content-Type", "text/html; charset=utf-8") res.Headers().Set("Content-Type", "text/html; charset=utf-8")
} }
c.StatusCode = code res.StatusCode = code
// Shouldn't send the body for POST or HEAD; that leaves GET. // Shouldn't send the body for POST or HEAD; that leaves GET.
if !hadCT && c.Req.Method == "GET" { if c.Req.Method == "GET" {
body := "<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n" res.Write([]byte("<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n"))
fmt.Fprintln(c.Body, body) }
return res
}
func (rd *ResponseData) WriteTemplate(name string, data interface{}) error {
return templates.Templates[name].Execute(rd, data)
}
func ErrorResponse(status int, errs ...error) ResponseData {
return ResponseData{
StatusCode: status,
Errors: errs,
} }
} }
func (c *RequestContext) WriteTemplate(name string, data interface{}) error { func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) {
return templates.Templates[name].Execute(c.Body, data)
}
func (c *RequestContext) AddErrors(errs ...error) {
c.Errors = append(c.Errors, errs...)
}
func (c *RequestContext) Errored(status int, errs ...error) {
c.StatusCode = status
c.AddErrors(errs...)
}
func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
return func(rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
c := newRequestContext(rw, r, route)
defer func() { defer func() {
/* /*
This panic recovery is the last resort. If you want to render This panic recovery is the last resort. If you want to render
@ -207,9 +212,17 @@ func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
} }
}() }()
h(c, p) res := h(c)
rw.WriteHeader(c.StatusCode) if res.StatusCode == 0 {
io.Copy(rw, c.Body) res.StatusCode = http.StatusOK
} }
for name, vals := range res.Headers() {
for _, val := range vals {
rw.Header().Add(name, val)
}
}
rw.WriteHeader(res.StatusCode)
io.Copy(rw, res.Body)
} }

View File

@ -1,6 +1,7 @@
package website package website
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -33,11 +34,12 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
} }
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper) mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper)
mainRoutes.GET("/", func(c *RequestContext, p httprouter.Params) { mainRoutes.GET("/", func(c *RequestContext) ResponseData {
if c.currentProject.ID == models.HMNProjectID { if c.CurrentProject.ID == models.HMNProjectID {
routes.Index(c, p) return routes.Index(c)
} else { } else {
// TODO: Return the project landing page // TODO: Return the project landing page
panic("route not implemented")
} }
}) })
mainRoutes.GET("/project/:id", routes.Project) mainRoutes.GET("/project/:id", routes.Project)
@ -48,29 +50,29 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
routes.ServeFiles("/public/*filepath", http.Dir("public")) routes.ServeFiles("/public/*filepath", http.Dir("public"))
routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404") routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour)
return routes return routes
} }
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData { func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
var templateUser *templates.User var templateUser *templates.User
if c.currentUser != nil { if c.CurrentUser != nil {
templateUser = &templates.User{ templateUser = &templates.User{
Username: c.currentUser.Username, Username: c.CurrentUser.Username,
Email: c.currentUser.Email, Email: c.CurrentUser.Email,
IsSuperuser: c.currentUser.IsSuperuser, IsSuperuser: c.CurrentUser.IsSuperuser,
IsStaff: c.currentUser.IsStaff, IsStaff: c.CurrentUser.IsStaff,
} }
} }
return templates.BaseData{ return templates.BaseData{
Project: templates.Project{ Project: templates.Project{
Name: *c.currentProject.Name, Name: *c.CurrentProject.Name,
Subdomain: *c.currentProject.Slug, Subdomain: *c.CurrentProject.Slug,
Color: c.currentProject.Color1, Color: c.CurrentProject.Color1,
IsHMN: c.currentProject.IsHMN(), IsHMN: c.CurrentProject.IsHMN(),
HasBlog: true, HasBlog: true,
HasForum: true, HasForum: true,
@ -104,24 +106,25 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
return defaultProject, nil return defaultProject, nil
} }
func (s *websiteRoutes) Project(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) Project(c *RequestContext) ResponseData {
id := p.ByName("id") id := c.PathParams.ByName("id")
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", p.ByName("id")) row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
var name string var name string
err := row.Scan(&name) err := row.Scan(&name)
if err != nil { if err != nil {
panic(err) panic(err)
} }
c.Body.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name))) var res ResponseData
res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
return res
} }
func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color") color := c.URL().Query().Get("color")
if color == "" { if color == "" {
c.StatusCode = http.StatusBadRequest return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
c.Body.Write([]byte("You must provide a 'color' parameter.\n"))
return
} }
templateData := struct { templateData := struct {
@ -132,27 +135,28 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
Theme: "dark", Theme: "dark",
} }
c.Headers().Add("Content-Type", "text/css") var res ResponseData
err := c.WriteTemplate("project.css", templateData) res.Headers().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData)
if err != nil { if err != nil {
c.Logger.Error().Err(err).Msg("failed to generate project CSS") return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
return
} }
return res
} }
func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks. // TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
form, err := c.GetFormValues() form, err := c.GetFormValues()
if err != nil { if err != nil {
c.Errored(http.StatusBadRequest, NewSafeError(err, "request must contain form data")) return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
return
} }
username := form.Get("username") username := form.Get("username")
password := form.Get("password") password := form.Get("password")
if username == "" || password == "" { if username == "" || password == "" {
c.Errored(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password")) return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
} }
redirect := form.Get("redirect") redirect := form.Get("redirect")
@ -163,24 +167,23 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username) userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) { if errors.Is(err, db.ErrNoMatchingRows) {
c.StatusCode = http.StatusUnauthorized return ResponseData{
StatusCode: http.StatusUnauthorized,
}
} else { } else {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
} }
return
} }
user := userRow.(models.User) user := userRow.(*models.User)
hashed, err := auth.ParsePasswordString(user.Password) hashed, err := auth.ParsePasswordString(user.Password)
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to parse password string")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
return
} }
passwordsMatch, err := auth.CheckPassword(password, hashed) passwordsMatch, err := auth.CheckPassword(password, hashed)
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to check password against hash")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
return
} }
if passwordsMatch { if passwordsMatch {
@ -200,20 +203,19 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
session, err := auth.CreateSession(c.Context(), s.conn, username) session, err := auth.CreateSession(c.Context(), s.conn, username)
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
return
} }
c.SetCookie(auth.NewSessionCookie(session)) res := c.Redirect(redirect, http.StatusSeeOther)
c.Redirect(redirect, http.StatusSeeOther) res.SetCookie(auth.NewSessionCookie(session))
return
return res
} else { } else {
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
return
} }
} }
func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) Logout(c *RequestContext) ResponseData {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil { if err == nil {
// clear the session from the db immediately, no expiration // clear the session from the db immediately, no expiration
@ -223,27 +225,33 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
} }
} }
c.SetCookie(auth.DeleteSessionCookie) res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page. res.SetCookie(auth.DeleteSessionCookie)
return res
} }
func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) { func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData {
c.StatusCode = http.StatusNotFound return ResponseData{
c.Body.Write([]byte("go away\n")) StatusCode: http.StatusNotFound,
Body: bytes.NewBuffer([]byte("go away\n")),
}
} }
func ErrorLoggingWrapper(h HMNHandler) HMNHandler { func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
return func(c *RequestContext, p httprouter.Params) { return func(c *RequestContext) ResponseData {
h(c, p) res := h(c)
for _, err := range c.Errors { for _, err := range res.Errors {
c.Logger.Error().Err(err).Msg("error occurred during request") c.Logger.Error().Err(err).Msg("error occurred during request")
} }
return res
} }
} }
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler { func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
return func(c *RequestContext, p httprouter.Params) { return func(c *RequestContext) ResponseData {
// get project // get project
{ {
slug := "" slug := ""
@ -254,26 +262,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug) dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug)
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
return
} }
c.currentProject = dbProject c.CurrentProject = dbProject
} }
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil { if err == nil {
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value) user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
if err != nil { if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
return
} }
c.currentUser = user c.CurrentUser = user
} }
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
h(c, p) return h(c)
} }
} }