diff --git a/src/website/landing.go b/src/website/landing.go
index da1a1236..6f54f018 100644
--- a/src/website/landing.go
+++ b/src/website/landing.go
@@ -8,7 +8,6 @@ import (
"git.handmade.network/hmn/hmn/src/models"
"git.handmade.network/hmn/hmn/src/oops"
"git.handmade.network/hmn/hmn/src/templates"
- "github.com/julienschmidt/httprouter"
)
type LandingTemplateData struct {
@@ -29,7 +28,7 @@ type LandingPagePost struct {
HasRead bool
}
-func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
+func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
const maxPosts = 5
const numProjectsToGet = 7
@@ -38,8 +37,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.HMNProjectID,
)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
}
defer iterProjects.Close()
@@ -132,8 +130,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
models.CatTypeBlog,
)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
}
newsThread := newsThreadRow.(*newsThreadQuery)
_ = newsThread // TODO: NO
@@ -141,8 +138,11 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
baseData := s.getBaseData(c)
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
- err = c.WriteTemplate("index.html", s.getBaseData(c))
+ var res ResponseData
+ err = res.WriteTemplate("index.html", s.getBaseData(c))
if err != nil {
panic(err)
}
+
+ return res
}
diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go
index dc06c6b9..0479b064 100644
--- a/src/website/requesthandling.go
+++ b/src/website/requesthandling.go
@@ -3,7 +3,6 @@ package website
import (
"bytes"
"context"
- "fmt"
"html"
"io"
"net/http"
@@ -35,7 +34,11 @@ func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler {
}
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
- r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler)))
+ h := r.WrapHandler(handler)
+ r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
+ c := NewRequestContext(rw, req, p)
+ doRequest(rw, c, h)
+ })
}
func (r *HMNRouter) GET(route string, handler HMNHandler) {
@@ -58,39 +61,29 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter {
return &result
}
-type HMNHandler func(c *RequestContext, p httprouter.Params)
+type HMNHandler func(c *RequestContext) ResponseData
type HMNHandlerWrapper func(h HMNHandler) HMNHandler
-func MakeStdHandler(h HMNHandler, name string) http.Handler {
- return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
- handleHmnHandler(name, h)(rw, req, nil)
- })
+func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ c := NewRequestContext(rw, req, nil)
+ doRequest(rw, c, h)
}
type RequestContext struct {
- StatusCode int
- Body *bytes.Buffer
Logger *zerolog.Logger
Req *http.Request
- Errors []error
+ PathParams httprouter.Params
- rw http.ResponseWriter
-
- currentProject *models.Project
- currentUser *models.User
- // currentMember *models.Member
+ CurrentProject *models.Project
+ CurrentUser *models.User
+ // CurrentMember *models.Member
}
-func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext {
- logger := logging.With().Str("route", route).Logger()
-
+func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params) *RequestContext {
return &RequestContext{
- StatusCode: http.StatusOK,
- Body: new(bytes.Buffer),
- Logger: &logger,
+ Logger: logging.GlobalLogger(),
Req: req,
-
- rw: rw,
+ PathParams: pathParams,
}
}
@@ -102,14 +95,6 @@ func (c *RequestContext) URL() *url.URL {
return c.Req.URL
}
-func (c *RequestContext) Headers() http.Header {
- return c.rw.Header()
-}
-
-func (c *RequestContext) SetCookie(cookie *http.Cookie) {
- c.rw.Header().Add("Set-Cookie", cookie.String())
-}
-
func (c *RequestContext) GetFormValues() (url.Values, error) {
err := c.Req.ParseForm()
if err != nil {
@@ -119,9 +104,39 @@ func (c *RequestContext) GetFormValues() (url.Values, error) {
return c.Req.PostForm, nil
}
+type ResponseData struct {
+ StatusCode int
+ Body *bytes.Buffer
+ Errors []error
+
+ header http.Header
+}
+
+func (rd *ResponseData) Headers() http.Header {
+ if rd.header == nil {
+ rd.header = make(http.Header)
+ }
+
+ return rd.header
+}
+
+func (rd *ResponseData) Write(p []byte) (n int, err error) {
+ if rd.Body == nil {
+ rd.Body = new(bytes.Buffer)
+ }
+
+ return rd.Body.Write(p)
+}
+
+func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
+ rd.Headers().Add("Set-Cookie", cookie.String())
+}
+
// The logic of this function is copy-pasted from the Go standard library.
// https://golang.org/pkg/net/http/#Redirect
-func (c *RequestContext) Redirect(dest string, code int) {
+func (c *RequestContext) Redirect(dest string, code int) ResponseData {
+ var res ResponseData
+
if u, err := url.Parse(dest); err == nil {
// If url was relative, make its path absolute by
// combining with request path.
@@ -156,60 +171,58 @@ func (c *RequestContext) Redirect(dest string, code int) {
}
}
- h := c.Headers()
-
- // RFC 7231 notes that a short HTML body is usually included in
- // the response because older user agents may not understand 301/307.
- // Do it only if the request didn't already have a Content-Type header.
- _, hadCT := h["Content-Type"]
-
// Escape stuff
destUrl, _ := url.Parse(dest)
dest = destUrl.String()
- h.Set("Location", dest)
- if !hadCT && (c.Req.Method == "GET" || c.Req.Method == "HEAD") {
- h.Set("Content-Type", "text/html; charset=utf-8")
+ res.Headers().Set("Location", dest)
+ if c.Req.Method == "GET" || c.Req.Method == "HEAD" {
+ res.Headers().Set("Content-Type", "text/html; charset=utf-8")
}
- c.StatusCode = code
+ res.StatusCode = code
// Shouldn't send the body for POST or HEAD; that leaves GET.
- if !hadCT && c.Req.Method == "GET" {
- body := "" + http.StatusText(code) + ".\n"
- fmt.Fprintln(c.Body, body)
+ if c.Req.Method == "GET" {
+ res.Write([]byte("" + http.StatusText(code) + ".\n"))
+ }
+
+ return res
+}
+
+func (rd *ResponseData) WriteTemplate(name string, data interface{}) error {
+ return templates.Templates[name].Execute(rd, data)
+}
+
+func ErrorResponse(status int, errs ...error) ResponseData {
+ return ResponseData{
+ StatusCode: status,
+ Errors: errs,
}
}
-func (c *RequestContext) WriteTemplate(name string, data interface{}) error {
- return templates.Templates[name].Execute(c.Body, data)
-}
+func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) {
+ defer func() {
+ /*
+ This panic recovery is the last resort. If you want to render
+ an error page or something, make it a request wrapper.
+ */
+ if recovered := recover(); recovered != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ logging.LogPanicValue(c.Logger, recovered, "request panicked and was not handled")
+ }
+ }()
-func (c *RequestContext) AddErrors(errs ...error) {
- c.Errors = append(c.Errors, errs...)
-}
+ res := h(c)
-func (c *RequestContext) Errored(status int, errs ...error) {
- c.StatusCode = status
- c.AddErrors(errs...)
-}
-
-func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
- return func(rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
- c := newRequestContext(rw, r, route)
- defer func() {
- /*
- This panic recovery is the last resort. If you want to render
- an error page or something, make it a request wrapper.
- */
- if recovered := recover(); recovered != nil {
- rw.WriteHeader(http.StatusInternalServerError)
- logging.LogPanicValue(c.Logger, recovered, "request panicked and was not handled")
- }
- }()
-
- h(c, p)
-
- rw.WriteHeader(c.StatusCode)
- io.Copy(rw, c.Body)
+ if res.StatusCode == 0 {
+ res.StatusCode = http.StatusOK
}
+
+ for name, vals := range res.Headers() {
+ for _, val := range vals {
+ rw.Header().Add(name, val)
+ }
+ }
+ rw.WriteHeader(res.StatusCode)
+ io.Copy(rw, res.Body)
}
diff --git a/src/website/routes.go b/src/website/routes.go
index 272ebcb2..2fe2af5e 100644
--- a/src/website/routes.go
+++ b/src/website/routes.go
@@ -1,6 +1,7 @@
package website
import (
+ "bytes"
"context"
"errors"
"fmt"
@@ -33,11 +34,12 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
}
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper)
- mainRoutes.GET("/", func(c *RequestContext, p httprouter.Params) {
- if c.currentProject.ID == models.HMNProjectID {
- routes.Index(c, p)
+ mainRoutes.GET("/", func(c *RequestContext) ResponseData {
+ if c.CurrentProject.ID == models.HMNProjectID {
+ return routes.Index(c)
} else {
// TODO: Return the project landing page
+ panic("route not implemented")
}
})
mainRoutes.GET("/project/:id", routes.Project)
@@ -48,29 +50,29 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
routes.ServeFiles("/public/*filepath", http.Dir("public"))
- routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404")
+ routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour)
return routes
}
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
var templateUser *templates.User
- if c.currentUser != nil {
+ if c.CurrentUser != nil {
templateUser = &templates.User{
- Username: c.currentUser.Username,
- Email: c.currentUser.Email,
- IsSuperuser: c.currentUser.IsSuperuser,
- IsStaff: c.currentUser.IsStaff,
+ Username: c.CurrentUser.Username,
+ Email: c.CurrentUser.Email,
+ IsSuperuser: c.CurrentUser.IsSuperuser,
+ IsStaff: c.CurrentUser.IsStaff,
}
}
return templates.BaseData{
Project: templates.Project{
- Name: *c.currentProject.Name,
- Subdomain: *c.currentProject.Slug,
- Color: c.currentProject.Color1,
+ Name: *c.CurrentProject.Name,
+ Subdomain: *c.CurrentProject.Slug,
+ Color: c.CurrentProject.Color1,
- IsHMN: c.currentProject.IsHMN(),
+ IsHMN: c.CurrentProject.IsHMN(),
HasBlog: true,
HasForum: true,
@@ -104,24 +106,25 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
return defaultProject, nil
}
-func (s *websiteRoutes) Project(c *RequestContext, p httprouter.Params) {
- id := p.ByName("id")
- row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", p.ByName("id"))
+func (s *websiteRoutes) Project(c *RequestContext) ResponseData {
+ id := c.PathParams.ByName("id")
+ row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
var name string
err := row.Scan(&name)
if err != nil {
panic(err)
}
- c.Body.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
+ var res ResponseData
+ res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
+
+ return res
}
-func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
+func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData {
color := c.URL().Query().Get("color")
if color == "" {
- c.StatusCode = http.StatusBadRequest
- c.Body.Write([]byte("You must provide a 'color' parameter.\n"))
- return
+ return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
}
templateData := struct {
@@ -132,27 +135,28 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
Theme: "dark",
}
- c.Headers().Add("Content-Type", "text/css")
- err := c.WriteTemplate("project.css", templateData)
+ var res ResponseData
+ res.Headers().Add("Content-Type", "text/css")
+ err := res.WriteTemplate("project.css", templateData)
if err != nil {
- c.Logger.Error().Err(err).Msg("failed to generate project CSS")
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
}
+
+ return res
}
-func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
+func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
form, err := c.GetFormValues()
if err != nil {
- c.Errored(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
- return
+ return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
}
username := form.Get("username")
password := form.Get("password")
if username == "" || password == "" {
- c.Errored(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
+ return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
}
redirect := form.Get("redirect")
@@ -163,24 +167,23 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
if err != nil {
if errors.Is(err, db.ErrNoMatchingRows) {
- c.StatusCode = http.StatusUnauthorized
+ return ResponseData{
+ StatusCode: http.StatusUnauthorized,
+ }
} else {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
}
- return
}
- user := userRow.(models.User)
+ user := userRow.(*models.User)
hashed, err := auth.ParsePasswordString(user.Password)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
}
passwordsMatch, err := auth.CheckPassword(password, hashed)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
}
if passwordsMatch {
@@ -200,20 +203,19 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
session, err := auth.CreateSession(c.Context(), s.conn, username)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
}
- c.SetCookie(auth.NewSessionCookie(session))
- c.Redirect(redirect, http.StatusSeeOther)
- return
+ res := c.Redirect(redirect, http.StatusSeeOther)
+ res.SetCookie(auth.NewSessionCookie(session))
+
+ return res
} else {
- c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
- return
+ return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
}
}
-func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
+func (s *websiteRoutes) Logout(c *RequestContext) ResponseData {
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
// clear the session from the db immediately, no expiration
@@ -223,27 +225,33 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
}
}
- c.SetCookie(auth.DeleteSessionCookie)
- c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
+ res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
+ res.SetCookie(auth.DeleteSessionCookie)
+
+ return res
}
-func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) {
- c.StatusCode = http.StatusNotFound
- c.Body.Write([]byte("go away\n"))
+func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData {
+ return ResponseData{
+ StatusCode: http.StatusNotFound,
+ Body: bytes.NewBuffer([]byte("go away\n")),
+ }
}
func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
- return func(c *RequestContext, p httprouter.Params) {
- h(c, p)
+ return func(c *RequestContext) ResponseData {
+ res := h(c)
- for _, err := range c.Errors {
+ for _, err := range res.Errors {
c.Logger.Error().Err(err).Msg("error occurred during request")
}
+
+ return res
}
}
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
- return func(c *RequestContext, p httprouter.Params) {
+ return func(c *RequestContext) ResponseData {
// get project
{
slug := ""
@@ -254,26 +262,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
}
- c.currentProject = dbProject
+ c.CurrentProject = dbProject
}
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
if err == nil {
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
if err != nil {
- c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
- return
+ return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
}
- c.currentUser = user
+ c.CurrentUser = user
}
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
- h(c, p)
+ return h(c)
}
}