diff --git a/src/website/requesthandling.go b/src/website/requesthandling.go index 85f9fc3b..dc06c6b9 100644 --- a/src/website/requesthandling.go +++ b/src/website/requesthandling.go @@ -27,11 +27,15 @@ func (r *HMNRouter) ServeHTTP(rw http.ResponseWriter, req *http.Request) { r.HttpRouter.ServeHTTP(rw, req) } -func (r *HMNRouter) Handle(method, route string, handler HMNHandler) { +func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler { for i := len(r.Wrappers) - 1; i >= 0; i-- { handler = r.Wrappers[i](handler) } - r.HttpRouter.Handle(method, route, handleHmnHandler(route, handler)) + return handler +} + +func (r *HMNRouter) Handle(method, route string, handler HMNHandler) { + r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler))) } func (r *HMNRouter) GET(route string, handler HMNHandler) { @@ -57,6 +61,12 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter { type HMNHandler func(c *RequestContext, p httprouter.Params) type HMNHandlerWrapper func(h HMNHandler) HMNHandler +func MakeStdHandler(h HMNHandler, name string) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + handleHmnHandler(name, h)(rw, req, nil) + }) +} + type RequestContext struct { StatusCode int Body *bytes.Buffer diff --git a/src/website/routes.go b/src/website/routes.go index fc92ad93..272ebcb2 100644 --- a/src/website/routes.go +++ b/src/website/routes.go @@ -48,6 +48,8 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler { routes.ServeFiles("/public/*filepath", http.Dir("public")) + routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404") + return routes } @@ -225,6 +227,11 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) { c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page. } +func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) { + c.StatusCode = http.StatusNotFound + c.Body.Write([]byte("go away\n")) +} + func ErrorLoggingWrapper(h HMNHandler) HMNHandler { return func(c *RequestContext, p httprouter.Params) { h(c, p)