diff --git a/src/website/landing.go b/src/website/landing.go index 6f54f018..ad2af6da 100644 --- a/src/website/landing.go +++ b/src/website/landing.go @@ -28,11 +28,11 @@ type LandingPagePost struct { HasRead bool } -func (s *websiteRoutes) Index(c *RequestContext) ResponseData { +func Index(c *RequestContext) ResponseData { const maxPosts = 5 const numProjectsToGet = 7 - iterProjects, err := db.Query(c.Context(), s.conn, models.Project{}, + iterProjects, err := db.Query(c.Context(), c.Conn, models.Project{}, "SELECT $columns FROM handmade_project WHERE flags = 0 OR id = $1", models.HMNProjectID, ) @@ -54,7 +54,7 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData { } memberId := 3 // TODO: NO - projectPostIter, err := db.Query(c.Context(), s.conn, ProjectPost{}, + projectPostIter, err := db.Query(c.Context(), c.Conn, ProjectPost{}, ` SELECT $columns FROM @@ -116,7 +116,7 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData { type newsThreadQuery struct { Thread models.Thread `db:"thread"` } - newsThreadRow, err := db.QueryOne(c.Context(), s.conn, newsThreadQuery{}, + newsThreadRow, err := db.QueryOne(c.Context(), c.Conn, newsThreadQuery{}, ` SELECT $columns FROM @@ -135,11 +135,11 @@ func (s *websiteRoutes) Index(c *RequestContext) ResponseData { newsThread := newsThreadRow.(*newsThreadQuery) _ = newsThread // TODO: NO - baseData := s.getBaseData(c) + baseData := getBaseData(c) baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more? var res ResponseData - err = res.WriteTemplate("index.html", s.getBaseData(c)) + err = res.WriteTemplate("index.html", getBaseData(c)) if err != nil { panic(err) } diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index 0479b064..d4d68bc9 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -13,56 +13,28 @@ import ( "git.handmade.network/hmn/hmn/src/logging" "git.handmade.network/hmn/hmn/src/models" "git.handmade.network/hmn/hmn/src/templates" + "github.com/jackc/pgx/v4/pgxpool" "github.com/julienschmidt/httprouter" "github.com/rs/zerolog" ) -type HMNRouter struct { - HttpRouter *httprouter.Router - Wrappers []HMNHandlerWrapper -} - -func (r *HMNRouter) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - r.HttpRouter.ServeHTTP(rw, req) -} - -func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler { - for i := len(r.Wrappers) - 1; i >= 0; i-- { - handler = r.Wrappers[i](handler) - } - return handler -} - -func (r *HMNRouter) Handle(method, route string, handler HMNHandler) { - h := r.WrapHandler(handler) - r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) { - c := NewRequestContext(rw, req, p) - doRequest(rw, c, h) - }) -} - -func (r *HMNRouter) GET(route string, handler HMNHandler) { - r.Handle(http.MethodGet, route, handler) -} - -func (r *HMNRouter) POST(route string, handler HMNHandler) { - r.Handle(http.MethodPost, route, handler) -} - -// TODO: More methods - -func (r *HMNRouter) ServeFiles(path string, root http.FileSystem) { - r.HttpRouter.ServeFiles(path, root) -} - -func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter { - result := *r - result.Wrappers = append(result.Wrappers, wrappers...) - return &result -} - +// The typical handler. Handles a request and returns data about the response. type HMNHandler func(c *RequestContext) ResponseData -type HMNHandlerWrapper func(h HMNHandler) HMNHandler + +// A special handler that runs before the primary handler. Intended to set +// information on the context for later handlers, or to give the request a +// means to bail out early if preconditions are not met (like auth). If `ok` +// is false, the request will immediately bail out, no further handlers will +// be run, and it will respond with the provided response data. +// +// The response data from this function will still be fed through any after +// handlers, to ensure that errors will get logged and whatnot. +type HMNBeforeHandler func(c *RequestContext) (ok bool, res ResponseData) + +// A special handler that runs after the primary handler and can modify the +// response information. Intended for error logging, error pages, +// cleanup, etc. +type HMNAfterHandler func(c *RequestContext, res ResponseData) ResponseData func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { c := NewRequestContext(rw, req, nil) @@ -74,6 +46,7 @@ type RequestContext struct { Req *http.Request PathParams httprouter.Params + Conn *pgxpool.Pool CurrentProject *models.Project CurrentUser *models.User // CurrentMember *models.Member @@ -104,34 +77,6 @@ func (c *RequestContext) GetFormValues() (url.Values, error) { return c.Req.PostForm, nil } -type ResponseData struct { - StatusCode int - Body *bytes.Buffer - Errors []error - - header http.Header -} - -func (rd *ResponseData) Headers() http.Header { - if rd.header == nil { - rd.header = make(http.Header) - } - - return rd.header -} - -func (rd *ResponseData) Write(p []byte) (n int, err error) { - if rd.Body == nil { - rd.Body = new(bytes.Buffer) - } - - return rd.Body.Write(p) -} - -func (rd *ResponseData) SetCookie(cookie *http.Cookie) { - rd.Headers().Add("Set-Cookie", cookie.String()) -} - // The logic of this function is copy-pasted from the Go standard library. // https://golang.org/pkg/net/http/#Redirect func (c *RequestContext) Redirect(dest string, code int) ResponseData { @@ -189,10 +134,88 @@ func (c *RequestContext) Redirect(dest string, code int) ResponseData { return res } +type ResponseData struct { + StatusCode int + Body *bytes.Buffer + Errors []error + + header http.Header +} + +func (rd *ResponseData) Headers() http.Header { + if rd.header == nil { + rd.header = make(http.Header) + } + + return rd.header +} + +func (rd *ResponseData) Write(p []byte) (n int, err error) { + if rd.Body == nil { + rd.Body = new(bytes.Buffer) + } + + return rd.Body.Write(p) +} + +func (rd *ResponseData) SetCookie(cookie *http.Cookie) { + rd.Headers().Add("Set-Cookie", cookie.String()) +} + func (rd *ResponseData) WriteTemplate(name string, data interface{}) error { return templates.Templates[name].Execute(rd, data) } +type RouteBuilder struct { + Router *httprouter.Router + BeforeHandlers []HMNBeforeHandler + AfterHandlers []HMNAfterHandler +} + +func (b RouteBuilder) ChainHandlers(h HMNHandler) HMNHandler { + return func(c *RequestContext) ResponseData { + beforeOk := true + var res ResponseData + for _, before := range b.BeforeHandlers { + if ok, errorRes := before(c); !ok { + beforeOk = false + res = errorRes + } + } + + if beforeOk { + res = h(c) + } + + for _, after := range b.AfterHandlers { + res = after(c, res) + } + return res + } +} + +func (b *RouteBuilder) Handle(method, route string, handler HMNHandler) { + h := b.ChainHandlers(handler) + b.Router.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) { + c := NewRequestContext(rw, req, p) + doRequest(rw, c, h) + }) +} + +func (b *RouteBuilder) GET(route string, handler HMNHandler) { + b.Handle(http.MethodGet, route, handler) +} + +func (b *RouteBuilder) POST(route string, handler HMNHandler) { + b.Handle(http.MethodPost, route, handler) +} + +// TODO: More methods + +func (b *RouteBuilder) ServeFiles(path string, root http.FileSystem) { + b.Router.ServeFiles(path, root) +} + func ErrorResponse(status int, errs ...error) ResponseData { return ResponseData{ StatusCode: status, diff --git a/src/website/routes.go b/src/website/routes.go index 2fe2af5e..777280e3 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -18,44 +18,67 @@ import ( "github.com/julienschmidt/httprouter" ) -type websiteRoutes struct { - *HMNRouter - - conn *pgxpool.Pool -} - func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { - routes := &websiteRoutes{ - HMNRouter: &HMNRouter{ - HttpRouter: httprouter.New(), - Wrappers: []HMNHandlerWrapper{ErrorLoggingWrapper}, + router := httprouter.New() + routes := RouteBuilder{ + Router: router, + BeforeHandlers: []HMNBeforeHandler{ + func(c *RequestContext) (bool, ResponseData) { + c.Conn = conn + return true, ResponseData{} + }, }, - conn: conn, + AfterHandlers: []HMNAfterHandler{ErrorLoggingHandler}, } - mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper) + routes.POST("/login", Login) + routes.GET("/logout", Logout) + routes.ServeFiles("/public/*filepath", http.Dir("public")) + + mainRoutes := routes + mainRoutes.BeforeHandlers = append(mainRoutes.BeforeHandlers, + CommonWebsiteDataWrapper, + ) + mainRoutes.GET("/", func(c *RequestContext) ResponseData { if c.CurrentProject.ID == models.HMNProjectID { - return routes.Index(c) + return Index(c) } else { // TODO: Return the project landing page panic("route not implemented") } }) - mainRoutes.GET("/project/:id", routes.Project) - mainRoutes.GET("/assets/project.css", routes.ProjectCSS) + mainRoutes.GET("/project/:id", Project) + mainRoutes.GET("/assets/project.css", ProjectCSS) - routes.POST("/login", routes.Login) - routes.GET("/logout", routes.Logout) + router.NotFound = mainRoutes.ChainHandlers(FourOhFour) - routes.ServeFiles("/public/*filepath", http.Dir("public")) + adminRoutes := routes + adminRoutes.BeforeHandlers = append(adminRoutes.BeforeHandlers, + func(c *RequestContext) (ok bool, res ResponseData) { + return false, ResponseData{ + StatusCode: http.StatusUnauthorized, + Body: bytes.NewBufferString("No one is allowed!\n"), + } + }, + ) + adminRoutes.AfterHandlers = append(adminRoutes.AfterHandlers, + func(c *RequestContext, res ResponseData) ResponseData { + res.Body.WriteString("Now go away. Sincerely, the after handler.\n") + return res + }, + ) - routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour) + adminRoutes.GET("/admin", func(c *RequestContext) ResponseData { + return ResponseData{ + Body: bytes.NewBufferString("Here are all the secrets.\n"), + } + }) - return routes + return router } -func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData { +func getBaseData(c *RequestContext) templates.BaseData { var templateUser *templates.User if c.CurrentUser != nil { templateUser = &templates.User{ @@ -106,9 +129,9 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (* return defaultProject, nil } -func (s *websiteRoutes) Project(c *RequestContext) ResponseData { +func Project(c *RequestContext) ResponseData { id := c.PathParams.ByName("id") - row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id")) + row := c.Conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id")) var name string err := row.Scan(&name) if err != nil { @@ -121,7 +144,7 @@ func (s *websiteRoutes) Project(c *RequestContext) ResponseData { return res } -func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData { +func ProjectCSS(c *RequestContext) ResponseData { color := c.URL().Query().Get("color") if color == "" { return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n")) @@ -145,7 +168,7 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData { return res } -func (s *websiteRoutes) Login(c *RequestContext) ResponseData { +func Login(c *RequestContext) ResponseData { // TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks. form, err := c.GetFormValues() @@ -164,7 +187,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData { redirect = "/" } - userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username) + userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username) if err != nil { if errors.Is(err, db.ErrNoMatchingRows) { return ResponseData{ @@ -191,7 +214,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData { if hashed.IsOutdated() { newHashed, err := auth.HashPassword(password) if err == nil { - err := auth.UpdatePassword(c.Context(), s.conn, username, newHashed) + err := auth.UpdatePassword(c.Context(), c.Conn, username, newHashed) if err != nil { c.Logger.Error().Err(err).Msg("failed to update user's password") } @@ -201,7 +224,7 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData { // If errors happen here, we can still continue with logging them in } - session, err := auth.CreateSession(c.Context(), s.conn, username) + session, err := auth.CreateSession(c.Context(), c.Conn, username) if err != nil { return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session")) } @@ -215,11 +238,11 @@ func (s *websiteRoutes) Login(c *RequestContext) ResponseData { } } -func (s *websiteRoutes) Logout(c *RequestContext) ResponseData { +func Logout(c *RequestContext) ResponseData { sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) if err == nil { // clear the session from the db immediately, no expiration - err := auth.DeleteSession(c.Context(), s.conn, sessionCookie.Value) + err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value) if err != nil { logging.Error().Err(err).Msg("failed to delete session on logout") } @@ -231,64 +254,58 @@ func (s *websiteRoutes) Logout(c *RequestContext) ResponseData { return res } -func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData { +func FourOhFour(c *RequestContext) ResponseData { return ResponseData{ StatusCode: http.StatusNotFound, - Body: bytes.NewBuffer([]byte("go away\n")), + Body: bytes.NewBufferString("go away\n"), } } -func ErrorLoggingWrapper(h HMNHandler) HMNHandler { - return func(c *RequestContext) ResponseData { - res := h(c) - - for _, err := range res.Errors { - c.Logger.Error().Err(err).Msg("error occurred during request") - } - - return res +func ErrorLoggingHandler(c *RequestContext, res ResponseData) ResponseData { + for _, err := range res.Errors { + c.Logger.Error().Err(err).Msg("error occurred during request") } + + return res } -func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler { - return func(c *RequestContext) ResponseData { - // get project - { - slug := "" - hostParts := strings.SplitN(c.Req.Host, ".", 3) - if len(hostParts) >= 3 { - slug = hostParts[0] - } - - dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug) - if err != nil { - return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) - } - - c.CurrentProject = dbProject +func CommonWebsiteDataWrapper(c *RequestContext) (bool, ResponseData) { + // get project + { + slug := "" + hostParts := strings.SplitN(c.Req.Host, ".", 3) + if len(hostParts) >= 3 { + slug = hostParts[0] } - sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) - if err == nil { - user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value) - if err != nil { - return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) - } - - c.CurrentUser = user + dbProject, err := FetchProjectBySlug(c.Context(), c.Conn, slug) + if err != nil { + return false, ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project")) } - // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. - return h(c) + c.CurrentProject = dbProject } + + sessionCookie, err := c.Req.Cookie(auth.SessionCookieName) + if err == nil { + user, err := getCurrentUserAndMember(c, sessionCookie.Value) + if err != nil { + return false, ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member")) + } + + c.CurrentUser = user + } + // http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here. + + return true, ResponseData{} } // Given a session id, fetches user and member data from the database. Will return nil for // both if neither can be found, and will only return an error if it's serious. // // TODO: actually return members :) -func (s *websiteRoutes) getCurrentUserAndMember(ctx context.Context, sessionId string) (*models.User, error) { - session, err := auth.GetSession(ctx, s.conn, sessionId) +func getCurrentUserAndMember(c *RequestContext, sessionId string) (*models.User, error) { + session, err := auth.GetSession(c.Context(), c.Conn, sessionId) if err != nil { if errors.Is(err, auth.ErrNoSession) { return nil, nil @@ -297,7 +314,7 @@ func (s *websiteRoutes) getCurrentUserAndMember(ctx context.Context, sessionId s } } - userRow, err := db.QueryOne(ctx, s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username) + userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", session.Username) if err != nil { if errors.Is(err, db.ErrNoMatchingRows) { logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")