Further clean up the request handling after talking with Asaf

This commit is contained in:
Ben Visness 2021-04-06 00:06:19 -05:00
parent 98da461d92
commit 7a01ddae66
3 changed files with 189 additions and 149 deletions

View File

@ -28,11 +28,11 @@ type LandingPagePost struct {
HasRead bool HasRead bool
} }
func (s *websiteRoutes) Index(c *RequestContext) ResponseData { func Index(c *RequestContext) ResponseData {
const maxPosts = 5 const maxPosts = 5
const numProjectsToGet = 7 const numProjectsToGet = 7
iterProjects, err := db.Query(c.Context(), s.conn, models.Project{}, iterProjects, err := db.Query(c.Context(), c.Conn, models.Project{},
"SELECT $columns FROM handmade_project WHERE flags = 0 OR id = $1", "SELECT $columns FROM handmade_project WHERE flags = 0 OR id = $1",
models.HMNProjectID, models.HMNProjectID,
) )
@ -54,7 +54,7 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
} }
memberId := 3 // TODO: NO memberId := 3 // TODO: NO
projectPostIter, err := db.Query(c.Context(), s.conn, ProjectPost{}, projectPostIter, err := db.Query(c.Context(), c.Conn, ProjectPost{},
` `
SELECT $columns SELECT $columns
FROM FROM
@ -116,7 +116,7 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
type newsThreadQuery struct { type newsThreadQuery struct {
Thread models.Thread `db:"thread"` Thread models.Thread `db:"thread"`
} }
newsThreadRow, err := db.QueryOne(c.Context(), s.conn, newsThreadQuery{}, newsThreadRow, err := db.QueryOne(c.Context(), c.Conn, newsThreadQuery{},
` `
SELECT $columns SELECT $columns
FROM FROM
@ -135,11 +135,11 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
newsThread := newsThreadRow.(*newsThreadQuery) newsThread := newsThreadRow.(*newsThreadQuery)
_ = newsThread // TODO: NO _ = newsThread // TODO: NO
baseData := s.getBaseData(c) baseData := getBaseData(c)
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more? baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
var res ResponseData var res ResponseData
err = res.WriteTemplate("index.html", s.getBaseData(c)) err = res.WriteTemplate("index.html", getBaseData(c))
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -13,56 +13,28 @@ import (
"git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/logging"
"git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/templates" "git.handmade.network/hmn/hmn/src/templates"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
type HMNRouter struct { // The typical handler. Handles a request and returns data about the response.
HttpRouter *httprouter.Router
Wrappers []HMNHandlerWrapper
}
func (r *HMNRouter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.HttpRouter.ServeHTTP(rw, req)
}
func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler {
for i := len(r.Wrappers) - 1; i >= 0; i-- {
handler = r.Wrappers[i](handler)
}
return handler
}
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
h := r.WrapHandler(handler)
r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
c := NewRequestContext(rw, req, p)
doRequest(rw, c, h)
})
}
func (r *HMNRouter) GET(route string, handler HMNHandler) {
r.Handle(http.MethodGet, route, handler)
}
func (r *HMNRouter) POST(route string, handler HMNHandler) {
r.Handle(http.MethodPost, route, handler)
}
// TODO: More methods
func (r *HMNRouter) ServeFiles(path string, root http.FileSystem) {
r.HttpRouter.ServeFiles(path, root)
}
func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter {
result := *r
result.Wrappers = append(result.Wrappers, wrappers...)
return &result
}
type HMNHandler func(c *RequestContext) ResponseData type HMNHandler func(c *RequestContext) ResponseData
type HMNHandlerWrapper func(h HMNHandler) HMNHandler
// A special handler that runs before the primary handler. Intended to set
// information on the context for later handlers, or to give the request a
// means to bail out early if preconditions are not met (like auth). If `ok`
// is false, the request will immediately bail out, no further handlers will
// be run, and it will respond with the provided response data.
//
// The response data from this function will still be fed through any after
// handlers, to ensure that errors will get logged and whatnot.
type HMNBeforeHandler func(c *RequestContext) (ok bool, res ResponseData)
// A special handler that runs after the primary handler and can modify the
// response information. Intended for error logging, error pages,
// cleanup, etc.
type HMNAfterHandler func(c *RequestContext, res ResponseData) ResponseData
func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
c := NewRequestContext(rw, req, nil) c := NewRequestContext(rw, req, nil)
@ -74,6 +46,7 @@ type RequestContext struct {
Req *http.Request Req *http.Request
PathParams httprouter.Params PathParams httprouter.Params
Conn *pgxpool.Pool
CurrentProject *models.Project CurrentProject *models.Project
CurrentUser *models.User CurrentUser *models.User
// CurrentMember *models.Member // CurrentMember *models.Member
@ -104,34 +77,6 @@ func (c *RequestContext) GetFormValues() (url.Values, error) {
return c.Req.PostForm, nil return c.Req.PostForm, nil
} }
type ResponseData struct {
StatusCode int
Body *bytes.Buffer
Errors []error
header http.Header
}
func (rd *ResponseData) Headers() http.Header {
if rd.header == nil {
rd.header = make(http.Header)
}
return rd.header
}
func (rd *ResponseData) Write(p []byte) (n int, err error) {
if rd.Body == nil {
rd.Body = new(bytes.Buffer)
}
return rd.Body.Write(p)
}
func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
rd.Headers().Add("Set-Cookie", cookie.String())
}
// The logic of this function is copy-pasted from the Go standard library. // The logic of this function is copy-pasted from the Go standard library.
// https://golang.org/pkg/net/http/#Redirect // https://golang.org/pkg/net/http/#Redirect
func (c *RequestContext) Redirect(dest string, code int) ResponseData { func (c *RequestContext) Redirect(dest string, code int) ResponseData {
@ -189,10 +134,88 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData {
return res return res
} }
type ResponseData struct {
StatusCode int
Body *bytes.Buffer
Errors []error
header http.Header
}
func (rd *ResponseData) Headers() http.Header {
if rd.header == nil {
rd.header = make(http.Header)
}
return rd.header
}
func (rd *ResponseData) Write(p []byte) (n int, err error) {
if rd.Body == nil {
rd.Body = new(bytes.Buffer)
}
return rd.Body.Write(p)
}
func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
rd.Headers().Add("Set-Cookie", cookie.String())
}
func (rd *ResponseData) WriteTemplate(name string, data interface{}) error { func (rd *ResponseData) WriteTemplate(name string, data interface{}) error {
return templates.Templates[name].Execute(rd, data) return templates.Templates[name].Execute(rd, data)
} }
type RouteBuilder struct {
Router *httprouter.Router
BeforeHandlers []HMNBeforeHandler
AfterHandlers []HMNAfterHandler
}
func (b RouteBuilder) ChainHandlers(h HMNHandler) HMNHandler {
return func(c *RequestContext) ResponseData {
beforeOk := true
var res ResponseData
for _, before := range b.BeforeHandlers {
if ok, errorRes := before(c); !ok {
beforeOk = false
res = errorRes
}
}
if beforeOk {
res = h(c)
}
for _, after := range b.AfterHandlers {
res = after(c, res)
}
return res
}
}
func (b *RouteBuilder) Handle(method, route string, handler HMNHandler) {
h := b.ChainHandlers(handler)
b.Router.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
c := NewRequestContext(rw, req, p)
doRequest(rw, c, h)
})
}
func (b *RouteBuilder) GET(route string, handler HMNHandler) {
b.Handle(http.MethodGet, route, handler)
}
func (b *RouteBuilder) POST(route string, handler HMNHandler) {
b.Handle(http.MethodPost, route, handler)
}
// TODO: More methods
func (b *RouteBuilder) ServeFiles(path string, root http.FileSystem) {
b.Router.ServeFiles(path, root)
}
func ErrorResponse(status int, errs ...error) ResponseData { func ErrorResponse(status int, errs ...error) ResponseData {
return ResponseData{ return ResponseData{
StatusCode: status, StatusCode: status,

View File

@ -18,44 +18,67 @@ import (
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
) )
type websiteRoutes struct {
*HMNRouter
conn *pgxpool.Pool
}
func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
routes := &websiteRoutes{ router := httprouter.New()
HMNRouter: &HMNRouter{ routes := RouteBuilder{
HttpRouter: httprouter.New(), Router: router,
Wrappers: []HMNHandlerWrapper{ErrorLoggingWrapper}, BeforeHandlers: []HMNBeforeHandler{
func(c *RequestContext) (bool, ResponseData) {
c.Conn = conn
return true, ResponseData{}
}, },
conn: conn, },
AfterHandlers: []HMNAfterHandler{ErrorLoggingHandler},
} }
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper) routes.POST("/login", Login)
routes.GET("/logout", Logout)
routes.ServeFiles("/public/*filepath", http.Dir("public"))
mainRoutes := routes
mainRoutes.BeforeHandlers = append(mainRoutes.BeforeHandlers,
CommonWebsiteDataWrapper,
)
mainRoutes.GET("/", func(c *RequestContext) ResponseData { mainRoutes.GET("/", func(c *RequestContext) ResponseData {
if c.CurrentProject.ID == models.HMNProjectID { if c.CurrentProject.ID == models.HMNProjectID {
return routes.Index(c) return Index(c)
} else { } else {
// TODO: Return the project landing page // TODO: Return the project landing page
panic("route not implemented") panic("route not implemented")
} }
}) })
mainRoutes.GET("/project/:id", routes.Project) mainRoutes.GET("/project/:id", Project)
mainRoutes.GET("/assets/project.css", routes.ProjectCSS) mainRoutes.GET("/assets/project.css", ProjectCSS)
routes.POST("/login", routes.Login) router.NotFound = mainRoutes.ChainHandlers(FourOhFour)
routes.GET("/logout", routes.Logout)
routes.ServeFiles("/public/*filepath", http.Dir("public")) adminRoutes := routes
adminRoutes.BeforeHandlers = append(adminRoutes.BeforeHandlers,
func(c *RequestContext) (ok bool, res ResponseData) {
return false, ResponseData{
StatusCode: http.StatusUnauthorized,
Body: bytes.NewBufferString("No one is allowed!\n"),
}
},
)
adminRoutes.AfterHandlers = append(adminRoutes.AfterHandlers,
func(c *RequestContext, res ResponseData) ResponseData {
res.Body.WriteString("Now go away. Sincerely, the after handler.\n")
return res
},
)
routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour) adminRoutes.GET("/admin", func(c *RequestContext) ResponseData {
return ResponseData{
Body: bytes.NewBufferString("Here are all the secrets.\n"),
}
})
return routes return router
} }
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData { func getBaseData(c *RequestContext) templates.BaseData {
var templateUser *templates.User var templateUser *templates.User
if c.CurrentUser != nil { if c.CurrentUser != nil {
templateUser = &templates.User{ templateUser = &templates.User{
@ -106,9 +129,9 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
return defaultProject, nil return defaultProject, nil
} }
func (s *websiteRoutes) Project(c *RequestContext) ResponseData { func Project(c *RequestContext) ResponseData {
id := c.PathParams.ByName("id") id := c.PathParams.ByName("id")
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id")) row := c.Conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
var name string var name string
err := row.Scan(&name) err := row.Scan(&name)
if err != nil { if err != nil {
@ -121,7 +144,7 @@ func (s *websiteRoutes) Project(c *RequestContext) ResponseData {
return res return res
} }
func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData { func ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color") color := c.URL().Query().Get("color")
if color == "" { if color == "" {
return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
@ -145,7 +168,7 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData {
return res return res
} }
func (s *websiteRoutes) Login(c *RequestContext) ResponseData { func Login(c *RequestContext) ResponseData {
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks. // TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
form, err := c.GetFormValues() form, err := c.GetFormValues()
@ -164,7 +187,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
redirect = "/" redirect = "/"
} }
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username) userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) { if errors.Is(err, db.ErrNoMatchingRows) {
return ResponseData{ return ResponseData{
@ -191,7 +214,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
if hashed.IsOutdated() { if hashed.IsOutdated() {
newHashed, err := auth.HashPassword(password) newHashed, err := auth.HashPassword(password)
if err == nil { if err == nil {
err := auth.UpdatePassword(c.Context(), s.conn, username, newHashed) err := auth.UpdatePassword(c.Context(), c.Conn, username, newHashed)
if err != nil { if err != nil {
c.Logger.Error().Err(err).Msg("failed to update user's password") c.Logger.Error().Err(err).Msg("failed to update user's password")
} }
@ -201,7 +224,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
// If errors happen here, we can still continue with logging them in // If errors happen here, we can still continue with logging them in
} }
session, err := auth.CreateSession(c.Context(), s.conn, username) session, err := auth.CreateSession(c.Context(), c.Conn, username)
if err != nil { if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session")) return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
} }
@ -215,11 +238,11 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
} }
} }
func (s *websiteRoutes) Logout(c *RequestContext) ResponseData { func Logout(c *RequestContext) ResponseData {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil { if err == nil {
// clear the session from the db immediately, no expiration // clear the session from the db immediately, no expiration
err := auth.DeleteSession(c.Context(), s.conn, sessionCookie.Value) err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value)
if err != nil { if err != nil {
logging.Error().Err(err).Msg("failed to delete session on logout") logging.Error().Err(err).Msg("failed to delete session on logout")
} }
@ -231,27 +254,22 @@ func (s *websiteRoutes) Logout(c *RequestContext) ResponseData {
return res return res
} }
func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData { func FourOhFour(c *RequestContext) ResponseData {
return ResponseData{ return ResponseData{
StatusCode: http.StatusNotFound, StatusCode: http.StatusNotFound,
Body: bytes.NewBuffer([]byte("go away\n")), Body: bytes.NewBufferString("go away\n"),
} }
} }
func ErrorLoggingWrapper(h HMNHandler) HMNHandler { func ErrorLoggingHandler(c *RequestContext, res ResponseData) ResponseData {
return func(c *RequestContext) ResponseData {
res := h(c)
for _, err := range res.Errors { for _, err := range res.Errors {
c.Logger.Error().Err(err).Msg("error occurred during request") c.Logger.Error().Err(err).Msg("error occurred during request")
} }
return res return res
} }
}
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler { func CommonWebsiteDataWrapper(c *RequestContext) (bool, ResponseData) {
return func(c *RequestContext) ResponseData {
// get project // get project
{ {
slug := "" slug := ""
@ -260,9 +278,9 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
slug = hostParts[0] slug = hostParts[0]
} }
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug) dbProject, err := FetchProjectBySlug(c.Context(), c.Conn, slug)
if err != nil { if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) return false, ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
} }
c.CurrentProject = dbProject c.CurrentProject = dbProject
@ -270,25 +288,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil { if err == nil {
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value) user, err := getCurrentUserAndMember(c, sessionCookie.Value)
if err != nil { if err != nil {
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) return false, ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
} }
c.CurrentUser = user c.CurrentUser = user
} }
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
return h(c) return true, ResponseData{}
}
} }
// Given a session id, fetches user and member data from the database. Will return nil for // Given a session id, fetches user and member data from the database. Will return nil for
// both if neither can be found, and will only return an error if it's serious. // both if neither can be found, and will only return an error if it's serious.
// //
// TODO: actually return members :) // TODO: actually return members :)
func (s *websiteRoutes) getCurrentUserAndMember(ctx context.Context, sessionId string) (*models.User, error) { func getCurrentUserAndMember(c *RequestContext, sessionId string) (*models.User, error) {
session, err := auth.GetSession(ctx, s.conn, sessionId) session, err := auth.GetSession(c.Context(), c.Conn, sessionId)
if err != nil { if err != nil {
if errors.Is(err, auth.ErrNoSession) { if errors.Is(err, auth.ErrNoSession) {
return nil, nil return nil, nil
@ -297,7 +314,7 @@ func (s *websiteRoutes) getCurrentUserAndMember(ctx context.Context, sessionId s
} }
} }
userRow, err := db.QueryOne(ctx, s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username) userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) { if errors.Is(err, db.ErrNoMatchingRows) {
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found") logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")