Rework request handling

It's a lot simpler now, control flow is easier to work with in handlers,
and HMNHandler now natively implements http.Handler, simplifying our 404
setup by completely removing the need for MakeStdHandler.
This commit is contained in:
Ben Visness 2021-04-05 22:30:11 -05:00
parent a6cdbac4c7
commit 98da461d92
3 changed files with 162 additions and 143 deletions

View File

@ -8,7 +8,6 @@ import (
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates"
"github.com/julienschmidt/httprouter"
)
type LandingTemplateData struct {
@ -29,7 +28,7 @@ type LandingPagePost struct {
HasRead bool
}
func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
const maxPosts = 5
const numProjectsToGet = 7
@ -38,8 +37,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.HMNProjectID,
)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
}
defer iterProjects.Close()
@ -132,8 +130,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.CatTypeBlog,
)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
}
newsThread := newsThreadRow.(*newsThreadQuery)
_ = newsThread // TODO: NO
@ -141,8 +138,11 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
baseData := s.getBaseData(c)
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
err = c.WriteTemplate("index.html", s.getBaseData(c))
var res ResponseData
err = res.WriteTemplate("index.html", s.getBaseData(c))
if err != nil {
panic(err)
}
return res
}

View File

@ -3,7 +3,6 @@ package website
import (
"bytes"
"context"
"fmt"
"html"
"io"
"net/http"
@ -35,7 +34,11 @@ func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler {
}
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler)))
h := r.WrapHandler(handler)
r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
c := NewRequestContext(rw, req, p)
doRequest(rw, c, h)
})
}
func (r *HMNRouter) GET(route string, handler HMNHandler) {
@ -58,39 +61,29 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter {
return &result
}
type HMNHandler func(c *RequestContext, p httprouter.Params)
type HMNHandler func(c *RequestContext) ResponseData
type HMNHandlerWrapper func(h HMNHandler) HMNHandler
func MakeStdHandler(h HMNHandler, name string) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handleHmnHandler(name, h)(rw, req, nil)
})
func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
c := NewRequestContext(rw, req, nil)
doRequest(rw, c, h)
}
type RequestContext struct {
StatusCode int
Body *bytes.Buffer
Logger *zerolog.Logger
Req *http.Request
Errors []error
PathParams httprouter.Params
rw http.ResponseWriter
currentProject *models.Project
currentUser *models.User
// currentMember *models.Member
CurrentProject *models.Project
CurrentUser *models.User
// CurrentMember *models.Member
}
func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext {
logger := logging.With().Str("route", route).Logger()
func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params) *RequestContext {
return &RequestContext{
StatusCode: http.StatusOK,
Body: new(bytes.Buffer),
Logger: &logger,
Logger: logging.GlobalLogger(),
Req: req,
rw: rw,
PathParams: pathParams,
}
}
@ -102,14 +95,6 @@ func (c *RequestContext) URL() *url.URL {
return c.Req.URL
}
func (c *RequestContext) Headers() http.Header {
return c.rw.Header()
}
func (c *RequestContext) SetCookie(cookie *http.Cookie) {
c.rw.Header().Add("Set-Cookie", cookie.String())
}
func (c *RequestContext) GetFormValues() (url.Values, error) {
err := c.Req.ParseForm()
if err != nil {
@ -119,9 +104,39 @@ func (c *RequestContext) GetFormValues() (url.Values, error) {
return c.Req.PostForm, nil
}
type ResponseData struct {
StatusCode int
Body *bytes.Buffer
Errors []error
header http.Header
}
func (rd *ResponseData) Headers() http.Header {
if rd.header == nil {
rd.header = make(http.Header)
}
return rd.header
}
func (rd *ResponseData) Write(p []byte) (n int, err error) {
if rd.Body == nil {
rd.Body = new(bytes.Buffer)
}
return rd.Body.Write(p)
}
func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
rd.Headers().Add("Set-Cookie", cookie.String())
}
// The logic of this function is copy-pasted from the Go standard library.
// https://golang.org/pkg/net/http/#Redirect
func (c *RequestContext) Redirect(dest string, code int) {
func (c *RequestContext) Redirect(dest string, code int) ResponseData {
var res ResponseData
if u, err := url.Parse(dest); err == nil {
// If url was relative, make its path absolute by
// combining with request path.
@ -156,46 +171,36 @@ func (c *RequestContext) Redirect(dest string, code int) {
}
}
h := c.Headers()
// RFC 7231 notes that a short HTML body is usually included in
// the response because older user agents may not understand 301/307.
// Do it only if the request didn't already have a Content-Type header.
_, hadCT := h["Content-Type"]
// Escape stuff
destUrl, _ := url.Parse(dest)
dest = destUrl.String()
h.Set("Location", dest)
if !hadCT && (c.Req.Method == "GET" || c.Req.Method == "HEAD") {
h.Set("Content-Type", "text/html; charset=utf-8")
res.Headers().Set("Location", dest)
if c.Req.Method == "GET" || c.Req.Method == "HEAD" {
res.Headers().Set("Content-Type", "text/html; charset=utf-8")
}
c.StatusCode = code
res.StatusCode = code
// Shouldn't send the body for POST or HEAD; that leaves GET.
if !hadCT && c.Req.Method == "GET" {
body := "<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n"
fmt.Fprintln(c.Body, body)
if c.Req.Method == "GET" {
res.Write([]byte("<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n"))
}
return res
}
func (rd *ResponseData) WriteTemplate(name string, data interface{}) error {
return templates.Templates[name].Execute(rd, data)
}
func ErrorResponse(status int, errs ...error) ResponseData {
return ResponseData{
StatusCode: status,
Errors: errs,
}
}
func (c *RequestContext) WriteTemplate(name string, data interface{}) error {
return templates.Templates[name].Execute(c.Body, data)
}
func (c *RequestContext) AddErrors(errs ...error) {
c.Errors = append(c.Errors, errs...)
}
func (c *RequestContext) Errored(status int, errs ...error) {
c.StatusCode = status
c.AddErrors(errs...)
}
func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
return func(rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
c := newRequestContext(rw, r, route)
func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) {
defer func() {
/*
This panic recovery is the last resort. If you want to render
@ -207,9 +212,17 @@ func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
}
}()
h(c, p)
res := h(c)
rw.WriteHeader(c.StatusCode)
io.Copy(rw, c.Body)
if res.StatusCode == 0 {
res.StatusCode = http.StatusOK
}
for name, vals := range res.Headers() {
for _, val := range vals {
rw.Header().Add(name, val)
}
}
rw.WriteHeader(res.StatusCode)
io.Copy(rw, res.Body)
}

View File

@ -1,6 +1,7 @@
package website
import (
"bytes"
"context"
"errors"
"fmt"
@ -33,11 +34,12 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
}
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper)
mainRoutes.GET("/", func(c *RequestContext, p httprouter.Params) {
if c.currentProject.ID == models.HMNProjectID {
routes.Index(c, p)
mainRoutes.GET("/", func(c *RequestContext) ResponseData {
if c.CurrentProject.ID == models.HMNProjectID {
return routes.Index(c)
} else {
// TODO: Return the project landing page
panic("route not implemented")
}
})
mainRoutes.GET("/project/:id", routes.Project)
@ -48,29 +50,29 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
routes.ServeFiles("/public/*filepath", http.Dir("public"))
routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404")
routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour)
return routes
}
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
var templateUser *templates.User
if c.currentUser != nil {
if c.CurrentUser != nil {
templateUser = &templates.User{
Username: c.currentUser.Username,
Email: c.currentUser.Email,
IsSuperuser: c.currentUser.IsSuperuser,
IsStaff: c.currentUser.IsStaff,
Username: c.CurrentUser.Username,
Email: c.CurrentUser.Email,
IsSuperuser: c.CurrentUser.IsSuperuser,
IsStaff: c.CurrentUser.IsStaff,
}
}
return templates.BaseData{
Project: templates.Project{
Name: *c.currentProject.Name,
Subdomain: *c.currentProject.Slug,
Color: c.currentProject.Color1,
Name: *c.CurrentProject.Name,
Subdomain: *c.CurrentProject.Slug,
Color: c.CurrentProject.Color1,
IsHMN: c.currentProject.IsHMN(),
IsHMN: c.CurrentProject.IsHMN(),
HasBlog: true,
HasForum: true,
@ -104,24 +106,25 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
return defaultProject, nil
}
func (s *websiteRoutes) Project(c *RequestContext, p httprouter.Params) {
id := p.ByName("id")
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", p.ByName("id"))
func (s *websiteRoutes) Project(c *RequestContext) ResponseData {
id := c.PathParams.ByName("id")
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
var name string
err := row.Scan(&name)
if err != nil {
panic(err)
}
c.Body.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
var res ResponseData
res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
return res
}
func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color")
if color == "" {
c.StatusCode = http.StatusBadRequest
c.Body.Write([]byte("You must provide a 'color' parameter.\n"))
return
return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
}
templateData := struct {
@ -132,27 +135,28 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
Theme: "dark",
}
c.Headers().Add("Content-Type", "text/css")
err := c.WriteTemplate("project.css", templateData)
var res ResponseData
res.Headers().Add("Content-Type", "text/css")
err := res.WriteTemplate("project.css", templateData)
if err != nil {
c.Logger.Error().Err(err).Msg("failed to generate project CSS")
return
}
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
return res
}
func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
form, err := c.GetFormValues()
if err != nil {
c.Errored(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
return
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
}
username := form.Get("username")
password := form.Get("password")
if username == "" || password == "" {
c.Errored(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
}
redirect := form.Get("redirect")
@ -163,24 +167,23 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) {
c.StatusCode = http.StatusUnauthorized
return ResponseData{
StatusCode: http.StatusUnauthorized,
}
} else {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
}
return
}
user := userRow.(models.User)
user := userRow.(*models.User)
hashed, err := auth.ParsePasswordString(user.Password)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
}
passwordsMatch, err := auth.CheckPassword(password, hashed)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
}
if passwordsMatch {
@ -200,20 +203,19 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
session, err := auth.CreateSession(c.Context(), s.conn, username)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
}
c.SetCookie(auth.NewSessionCookie(session))
c.Redirect(redirect, http.StatusSeeOther)
return
res := c.Redirect(redirect, http.StatusSeeOther)
res.SetCookie(auth.NewSessionCookie(session))
return res
} else {
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
return
return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
}
}
func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
func (s *websiteRoutes) Logout(c *RequestContext) ResponseData {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
// clear the session from the db immediately, no expiration
@ -223,27 +225,33 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
}
}
c.SetCookie(auth.DeleteSessionCookie)
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
res.SetCookie(auth.DeleteSessionCookie)
return res
}
func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) {
c.StatusCode = http.StatusNotFound
c.Body.Write([]byte("go away\n"))
func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData {
return ResponseData{
StatusCode: http.StatusNotFound,
Body: bytes.NewBuffer([]byte("go away\n")),
}
}
func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
return func(c *RequestContext, p httprouter.Params) {
h(c, p)
return func(c *RequestContext) ResponseData {
res := h(c)
for _, err := range c.Errors {
for _, err := range res.Errors {
c.Logger.Error().Err(err).Msg("error occurred during request")
}
return res
}
}
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
return func(c *RequestContext, p httprouter.Params) {
return func(c *RequestContext) ResponseData {
// get project
{
slug := ""
@ -254,26 +262,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
c.currentProject = dbProject
c.CurrentProject = dbProject
}
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
if err != nil {
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
return
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
}
c.currentUser = user
c.CurrentUser = user
}
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
h(c, p)
return h(c)
}
}