diff --git a/src/website/landing.go b/src/website/landing.go index da1a1236..6f54f018 100644 --- a/src/website/landing.go +++ b/src/website/landing.go @@ -8,7 +8,6 @@ import ( "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/oops" "git.handmade.network/hmn/hmn/src/templates" - "github.com/julienschmidt/httprouter" ) type LandingTemplateData struct { @@ -29,7 +28,7 @@ type LandingPagePost struct { HasRead bool } -func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) { +func (s *websiteRoutes) Index(c *RequestContext) ResponseData { const maxPosts = 5 const numProjectsToGet = 7 @@ -38,8 +37,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) { models.HMNProjectID, ) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page")) } defer iterProjects.Close() @@ -132,8 +130,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) { models.CatTypeBlog, ) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post")) } newsThread := newsThreadRow.(*newsThreadQuery) _ = newsThread // TODO: NO @@ -141,8 +138,11 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) { baseData := s.getBaseData(c) baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more? - err = c.WriteTemplate("index.html", s.getBaseData(c)) + var res ResponseData + err = res.WriteTemplate("index.html", s.getBaseData(c)) if err != nil { panic(err) } + + return res } diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index dc06c6b9..0479b064 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -3,7 +3,6 @@ package website import ( "bytes" "context" - "fmt" "html" "io" "net/http" @@ -35,7 +34,11 @@ func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler { } func (r *HMNRouter) Handle(method, route string, handler HMNHandler) { - r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler))) + h := r.WrapHandler(handler) + r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) { + c := NewRequestContext(rw, req, p) + doRequest(rw, c, h) + }) } func (r *HMNRouter) GET(route string, handler HMNHandler) { @@ -58,39 +61,29 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter { return &result } -type HMNHandler func(c *RequestContext, p httprouter.Params) +type HMNHandler func(c *RequestContext) ResponseData type HMNHandlerWrapper func(h HMNHandler) HMNHandler -func MakeStdHandler(h HMNHandler, name string) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - handleHmnHandler(name, h)(rw, req, nil) - }) +func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + c := NewRequestContext(rw, req, nil) + doRequest(rw, c, h) } type RequestContext struct { - StatusCode int - Body *bytes.Buffer Logger *zerolog.Logger Req *http.Request - Errors []error + PathParams httprouter.Params - rw http.ResponseWriter - - currentProject *models.Project - currentUser *models.User - // currentMember *models.Member + CurrentProject *models.Project + CurrentUser *models.User + // CurrentMember *models.Member } -func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext { - logger := logging.With().Str("route", route).Logger() - +func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params) *RequestContext { return &RequestContext{ - StatusCode: http.StatusOK, - Body: new(bytes.Buffer), - Logger: &logger, + Logger: logging.GlobalLogger(), Req: req, - - rw: rw, + PathParams: pathParams, } } @@ -102,14 +95,6 @@ func (c *RequestContext) URL() *url.URL { return c.Req.URL } -func (c *RequestContext) Headers() http.Header { - return c.rw.Header() -} - -func (c *RequestContext) SetCookie(cookie *http.Cookie) { - c.rw.Header().Add("Set-Cookie", cookie.String()) -} - func (c *RequestContext) GetFormValues() (url.Values, error) { err := c.Req.ParseForm() if err != nil { @@ -119,9 +104,39 @@ func (c *RequestContext) GetFormValues() (url.Values, error) { return c.Req.PostForm, nil } +type ResponseData struct { + StatusCode int + Body *bytes.Buffer + Errors []error + + header http.Header +} + +func (rd *ResponseData) Headers() http.Header { + if rd.header == nil { + rd.header = make(http.Header) + } + + return rd.header +} + +func (rd *ResponseData) Write(p []byte) (n int, err error) { + if rd.Body == nil { + rd.Body = new(bytes.Buffer) + } + + return rd.Body.Write(p) +} + +func (rd *ResponseData) SetCookie(cookie *http.Cookie) { + rd.Headers().Add("Set-Cookie", cookie.String()) +} + // The logic of this function is copy-pasted from the Go standard library. // https://golang.org/pkg/net/http/#Redirect -func (c *RequestContext) Redirect(dest string, code int) { +func (c *RequestContext) Redirect(dest string, code int) ResponseData { + var res ResponseData + if u, err := url.Parse(dest); err == nil { // If url was relative, make its path absolute by // combining with request path. @@ -156,60 +171,58 @@ func (c *RequestContext) Redirect(dest string, code int) { } } - h := c.Headers() - - // RFC 7231 notes that a short HTML body is usually included in - // the response because older user agents may not understand 301/307. - // Do it only if the request didn't already have a Content-Type header. - _, hadCT := h["Content-Type"] - // Escape stuff destUrl, _ := url.Parse(dest) dest = destUrl.String() - h.Set("Location", dest) - if !hadCT && (c.Req.Method == "GET" || c.Req.Method == "HEAD") { - h.Set("Content-Type", "text/html; charset=utf-8") + res.Headers().Set("Location", dest) + if c.Req.Method == "GET" || c.Req.Method == "HEAD" { + res.Headers().Set("Content-Type", "text/html; charset=utf-8") } - c.StatusCode = code + res.StatusCode = code // Shouldn't send the body for POST or HEAD; that leaves GET. - if !hadCT && c.Req.Method == "GET" { - body := "" + http.StatusText(code) + ".\n" - fmt.Fprintln(c.Body, body) + if c.Req.Method == "GET" { + res.Write([]byte("" + http.StatusText(code) + ".\n")) + } + + return res +} + +func (rd *ResponseData) WriteTemplate(name string, data interface{}) error { + return templates.Templates[name].Execute(rd, data) +} + +func ErrorResponse(status int, errs ...error) ResponseData { + return ResponseData{ + StatusCode: status, + Errors: errs, } } -func (c *RequestContext) WriteTemplate(name string, data interface{}) error { - return templates.Templates[name].Execute(c.Body, data) -} +func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) { + defer func() { + /* + This panic recovery is the last resort. If you want to render + an error page or something, make it a request wrapper. + */ + if recovered := recover(); recovered != nil { + rw.WriteHeader(http.StatusInternalServerError) + logging.LogPanicValue(c.Logger, recovered, "request panicked and was not handled") + } + }() -func (c *RequestContext) AddErrors(errs ...error) { - c.Errors = append(c.Errors, errs...) -} + res := h(c) -func (c *RequestContext) Errored(status int, errs ...error) { - c.StatusCode = status - c.AddErrors(errs...) -} - -func handleHmnHandler(route string, h HMNHandler) httprouter.Handle { - return func(rw http.ResponseWriter, r *http.Request, p httprouter.Params) { - c := newRequestContext(rw, r, route) - defer func() { - /* - This panic recovery is the last resort. If you want to render - an error page or something, make it a request wrapper. - */ - if recovered := recover(); recovered != nil { - rw.WriteHeader(http.StatusInternalServerError) - logging.LogPanicValue(c.Logger, recovered, "request panicked and was not handled") - } - }() - - h(c, p) - - rw.WriteHeader(c.StatusCode) - io.Copy(rw, c.Body) + if res.StatusCode == 0 { + res.StatusCode = http.StatusOK } + + for name, vals := range res.Headers() { + for _, val := range vals { + rw.Header().Add(name, val) + } + } + rw.WriteHeader(res.StatusCode) + io.Copy(rw, res.Body) } diff --git a/src/website/routes.go b/src/website/routes.go index 272ebcb2..2fe2af5e 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -1,6 +1,7 @@ package website import ( + "bytes" "context" "errors" "fmt" @@ -33,11 +34,12 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { } mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper) - mainRoutes.GET("/", func(c *RequestContext, p httprouter.Params) { - if c.currentProject.ID == models.HMNProjectID { - routes.Index(c, p) + mainRoutes.GET("/", func(c *RequestContext) ResponseData { + if c.CurrentProject.ID == models.HMNProjectID { + return routes.Index(c) } else { // TODO: Return the project landing page + panic("route not implemented") } }) mainRoutes.GET("/project/:id", routes.Project) @@ -48,29 +50,29 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { routes.ServeFiles("/public/*filepath", http.Dir("public")) - routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404") + routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour) return routes } func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData { var templateUser *templates.User - if c.currentUser != nil { + if c.CurrentUser != nil { templateUser = &templates.User{ - Username: c.currentUser.Username, - Email: c.currentUser.Email, - IsSuperuser: c.currentUser.IsSuperuser, - IsStaff: c.currentUser.IsStaff, + Username: c.CurrentUser.Username, + Email: c.CurrentUser.Email, + IsSuperuser: c.CurrentUser.IsSuperuser, + IsStaff: c.CurrentUser.IsStaff, } } return templates.BaseData{ Project: templates.Project{ - Name: *c.currentProject.Name, - Subdomain: *c.currentProject.Slug, - Color: c.currentProject.Color1, + Name: *c.CurrentProject.Name, + Subdomain: *c.CurrentProject.Slug, + Color: c.CurrentProject.Color1, - IsHMN: c.currentProject.IsHMN(), + IsHMN: c.CurrentProject.IsHMN(), HasBlog: true, HasForum: true, @@ -104,24 +106,25 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* return defaultProject, nil } -func (s *websiteRoutes) Project(c *RequestContext, p httprouter.Params) { - id := p.ByName("id") - row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", p.ByName("id")) +func (s *websiteRoutes) Project(c *RequestContext) ResponseData { + id := c.PathParams.ByName("id") + row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id")) var name string err := row.Scan(&name) if err != nil { panic(err) } - c.Body.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name))) + var res ResponseData + res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name))) + + return res } -func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) { +func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData { color := c.URL().Query().Get("color") if color == "" { - c.StatusCode = http.StatusBadRequest - c.Body.Write([]byte("You must provide a 'color' parameter.\n")) - return + return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) } templateData := struct { @@ -132,27 +135,28 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) { Theme: "dark", } - c.Headers().Add("Content-Type", "text/css") - err := c.WriteTemplate("project.css", templateData) + var res ResponseData + res.Headers().Add("Content-Type", "text/css") + err := res.WriteTemplate("project.css", templateData) if err != nil { - c.Logger.Error().Err(err).Msg("failed to generate project CSS") - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS")) } + + return res } -func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) { +func (s *websiteRoutes) Login(c *RequestContext) ResponseData { // TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks. form, err := c.GetFormValues() if err != nil { - c.Errored(http.StatusBadRequest, NewSafeError(err, "request must contain form data")) - return + return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data")) } username := form.Get("username") password := form.Get("password") if username == "" || password == "" { - c.Errored(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password")) + return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password")) } redirect := form.Get("redirect") @@ -163,24 +167,23 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) { userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username) if err != nil { if errors.Is(err, db.ErrNoMatchingRows) { - c.StatusCode = http.StatusUnauthorized + return ResponseData{ + StatusCode: http.StatusUnauthorized, + } } else { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username")) } - return } - user := userRow.(models.User) + user := userRow.(*models.User) hashed, err := auth.ParsePasswordString(user.Password) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to parse password string")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string")) } passwordsMatch, err := auth.CheckPassword(password, hashed) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to check password against hash")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash")) } if passwordsMatch { @@ -200,20 +203,19 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) { session, err := auth.CreateSession(c.Context(), s.conn, username) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session")) } - c.SetCookie(auth.NewSessionCookie(session)) - c.Redirect(redirect, http.StatusSeeOther) - return + res := c.Redirect(redirect, http.StatusSeeOther) + res.SetCookie(auth.NewSessionCookie(session)) + + return res } else { - c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error - return + return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error } } -func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) { +func (s *websiteRoutes) Logout(c *RequestContext) ResponseData { sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) if err == nil { // clear the session from the db immediately, no expiration @@ -223,27 +225,33 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) { } } - c.SetCookie(auth.DeleteSessionCookie) - c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page. + res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page. + res.SetCookie(auth.DeleteSessionCookie) + + return res } -func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) { - c.StatusCode = http.StatusNotFound - c.Body.Write([]byte("go away\n")) +func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData { + return ResponseData{ + StatusCode: http.StatusNotFound, + Body: bytes.NewBuffer([]byte("go away\n")), + } } func ErrorLoggingWrapper(h HMNHandler) HMNHandler { - return func(c *RequestContext, p httprouter.Params) { - h(c, p) + return func(c *RequestContext) ResponseData { + res := h(c) - for _, err := range c.Errors { + for _, err := range res.Errors { c.Logger.Error().Err(err).Msg("error occurred during request") } + + return res } } func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler { - return func(c *RequestContext, p httprouter.Params) { + return func(c *RequestContext) ResponseData { // get project { slug := "" @@ -254,26 +262,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler { dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) } - c.currentProject = dbProject + c.CurrentProject = dbProject } sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) if err == nil { user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value) if err != nil { - c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) - return + return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) } - c.currentUser = user + c.CurrentUser = user } // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. - h(c, p) + return h(c) } }