Rework request handling
It's a lot simpler now, control flow is easier to work with in handlers, and HMNHandler now natively implements http.Handler, simplifying our 404 setup by completely removing the need for MakeStdHandler.
This commit is contained in:
parent
a6cdbac4c7
commit
98da461d92
|
@ -8,7 +8,6 @@ import (
|
||||||
"git.handmade.network/hmn/hmn/src/models"
|
"git.handmade.network/hmn/hmn/src/models"
|
||||||
"git.handmade.network/hmn/hmn/src/oops"
|
"git.handmade.network/hmn/hmn/src/oops"
|
||||||
"git.handmade.network/hmn/hmn/src/templates"
|
"git.handmade.network/hmn/hmn/src/templates"
|
||||||
"github.com/julienschmidt/httprouter"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LandingTemplateData struct {
|
type LandingTemplateData struct {
|
||||||
|
@ -29,7 +28,7 @@ type LandingPagePost struct {
|
||||||
HasRead bool
|
HasRead bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) Index(c *RequestContext) ResponseData {
|
||||||
const maxPosts = 5
|
const maxPosts = 5
|
||||||
const numProjectsToGet = 7
|
const numProjectsToGet = 7
|
||||||
|
|
||||||
|
@ -38,8 +37,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
|
||||||
models.HMNProjectID,
|
models.HMNProjectID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get projects for home page"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer iterProjects.Close()
|
defer iterProjects.Close()
|
||||||
|
|
||||||
|
@ -132,8 +130,7 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
|
||||||
models.CatTypeBlog,
|
models.CatTypeBlog,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch latest news post"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
newsThread := newsThreadRow.(*newsThreadQuery)
|
newsThread := newsThreadRow.(*newsThreadQuery)
|
||||||
_ = newsThread // TODO: NO
|
_ = newsThread // TODO: NO
|
||||||
|
@ -141,8 +138,11 @@ func (s *websiteRoutes) Index(c *RequestContext, p httprouter.Params) {
|
||||||
baseData := s.getBaseData(c)
|
baseData := s.getBaseData(c)
|
||||||
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
|
baseData.BodyClasses = append(baseData.BodyClasses, "hmdev", "landing") // TODO: Is "hmdev" necessary any more?
|
||||||
|
|
||||||
err = c.WriteTemplate("index.html", s.getBaseData(c))
|
var res ResponseData
|
||||||
|
err = res.WriteTemplate("index.html", s.getBaseData(c))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package website
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"html"
|
"html"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -35,7 +34,11 @@ func (r *HMNRouter) WrapHandler(handler HMNHandler) HMNHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
|
func (r *HMNRouter) Handle(method, route string, handler HMNHandler) {
|
||||||
r.HttpRouter.Handle(method, route, handleHmnHandler(route, r.WrapHandler(handler)))
|
h := r.WrapHandler(handler)
|
||||||
|
r.HttpRouter.Handle(method, route, func(rw http.ResponseWriter, req *http.Request, p httprouter.Params) {
|
||||||
|
c := NewRequestContext(rw, req, p)
|
||||||
|
doRequest(rw, c, h)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *HMNRouter) GET(route string, handler HMNHandler) {
|
func (r *HMNRouter) GET(route string, handler HMNHandler) {
|
||||||
|
@ -58,39 +61,29 @@ func (r *HMNRouter) WithWrappers(wrappers ...HMNHandlerWrapper) *HMNRouter {
|
||||||
return &result
|
return &result
|
||||||
}
|
}
|
||||||
|
|
||||||
type HMNHandler func(c *RequestContext, p httprouter.Params)
|
type HMNHandler func(c *RequestContext) ResponseData
|
||||||
type HMNHandlerWrapper func(h HMNHandler) HMNHandler
|
type HMNHandlerWrapper func(h HMNHandler) HMNHandler
|
||||||
|
|
||||||
func MakeStdHandler(h HMNHandler, name string) http.Handler {
|
func (h HMNHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
c := NewRequestContext(rw, req, nil)
|
||||||
handleHmnHandler(name, h)(rw, req, nil)
|
doRequest(rw, c, h)
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestContext struct {
|
type RequestContext struct {
|
||||||
StatusCode int
|
|
||||||
Body *bytes.Buffer
|
|
||||||
Logger *zerolog.Logger
|
Logger *zerolog.Logger
|
||||||
Req *http.Request
|
Req *http.Request
|
||||||
Errors []error
|
PathParams httprouter.Params
|
||||||
|
|
||||||
rw http.ResponseWriter
|
CurrentProject *models.Project
|
||||||
|
CurrentUser *models.User
|
||||||
currentProject *models.Project
|
// CurrentMember *models.Member
|
||||||
currentUser *models.User
|
|
||||||
// currentMember *models.Member
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext {
|
func NewRequestContext(rw http.ResponseWriter, req *http.Request, pathParams httprouter.Params) *RequestContext {
|
||||||
logger := logging.With().Str("route", route).Logger()
|
|
||||||
|
|
||||||
return &RequestContext{
|
return &RequestContext{
|
||||||
StatusCode: http.StatusOK,
|
Logger: logging.GlobalLogger(),
|
||||||
Body: new(bytes.Buffer),
|
|
||||||
Logger: &logger,
|
|
||||||
Req: req,
|
Req: req,
|
||||||
|
PathParams: pathParams,
|
||||||
rw: rw,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,14 +95,6 @@ func (c *RequestContext) URL() *url.URL {
|
||||||
return c.Req.URL
|
return c.Req.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RequestContext) Headers() http.Header {
|
|
||||||
return c.rw.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RequestContext) SetCookie(cookie *http.Cookie) {
|
|
||||||
c.rw.Header().Add("Set-Cookie", cookie.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RequestContext) GetFormValues() (url.Values, error) {
|
func (c *RequestContext) GetFormValues() (url.Values, error) {
|
||||||
err := c.Req.ParseForm()
|
err := c.Req.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -119,9 +104,39 @@ func (c *RequestContext) GetFormValues() (url.Values, error) {
|
||||||
return c.Req.PostForm, nil
|
return c.Req.PostForm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ResponseData struct {
|
||||||
|
StatusCode int
|
||||||
|
Body *bytes.Buffer
|
||||||
|
Errors []error
|
||||||
|
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rd *ResponseData) Headers() http.Header {
|
||||||
|
if rd.header == nil {
|
||||||
|
rd.header = make(http.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rd.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rd *ResponseData) Write(p []byte) (n int, err error) {
|
||||||
|
if rd.Body == nil {
|
||||||
|
rd.Body = new(bytes.Buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rd.Body.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rd *ResponseData) SetCookie(cookie *http.Cookie) {
|
||||||
|
rd.Headers().Add("Set-Cookie", cookie.String())
|
||||||
|
}
|
||||||
|
|
||||||
// The logic of this function is copy-pasted from the Go standard library.
|
// The logic of this function is copy-pasted from the Go standard library.
|
||||||
// https://golang.org/pkg/net/http/#Redirect
|
// https://golang.org/pkg/net/http/#Redirect
|
||||||
func (c *RequestContext) Redirect(dest string, code int) {
|
func (c *RequestContext) Redirect(dest string, code int) ResponseData {
|
||||||
|
var res ResponseData
|
||||||
|
|
||||||
if u, err := url.Parse(dest); err == nil {
|
if u, err := url.Parse(dest); err == nil {
|
||||||
// If url was relative, make its path absolute by
|
// If url was relative, make its path absolute by
|
||||||
// combining with request path.
|
// combining with request path.
|
||||||
|
@ -156,46 +171,36 @@ func (c *RequestContext) Redirect(dest string, code int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h := c.Headers()
|
|
||||||
|
|
||||||
// RFC 7231 notes that a short HTML body is usually included in
|
|
||||||
// the response because older user agents may not understand 301/307.
|
|
||||||
// Do it only if the request didn't already have a Content-Type header.
|
|
||||||
_, hadCT := h["Content-Type"]
|
|
||||||
|
|
||||||
// Escape stuff
|
// Escape stuff
|
||||||
destUrl, _ := url.Parse(dest)
|
destUrl, _ := url.Parse(dest)
|
||||||
dest = destUrl.String()
|
dest = destUrl.String()
|
||||||
|
|
||||||
h.Set("Location", dest)
|
res.Headers().Set("Location", dest)
|
||||||
if !hadCT && (c.Req.Method == "GET" || c.Req.Method == "HEAD") {
|
if c.Req.Method == "GET" || c.Req.Method == "HEAD" {
|
||||||
h.Set("Content-Type", "text/html; charset=utf-8")
|
res.Headers().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
}
|
}
|
||||||
c.StatusCode = code
|
res.StatusCode = code
|
||||||
|
|
||||||
// Shouldn't send the body for POST or HEAD; that leaves GET.
|
// Shouldn't send the body for POST or HEAD; that leaves GET.
|
||||||
if !hadCT && c.Req.Method == "GET" {
|
if c.Req.Method == "GET" {
|
||||||
body := "<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n"
|
res.Write([]byte("<a href=\"" + html.EscapeString(dest) + "\">" + http.StatusText(code) + "</a>.\n"))
|
||||||
fmt.Fprintln(c.Body, body)
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rd *ResponseData) WriteTemplate(name string, data interface{}) error {
|
||||||
|
return templates.Templates[name].Execute(rd, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorResponse(status int, errs ...error) ResponseData {
|
||||||
|
return ResponseData{
|
||||||
|
StatusCode: status,
|
||||||
|
Errors: errs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RequestContext) WriteTemplate(name string, data interface{}) error {
|
func doRequest(rw http.ResponseWriter, c *RequestContext, h HMNHandler) {
|
||||||
return templates.Templates[name].Execute(c.Body, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RequestContext) AddErrors(errs ...error) {
|
|
||||||
c.Errors = append(c.Errors, errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RequestContext) Errored(status int, errs ...error) {
|
|
||||||
c.StatusCode = status
|
|
||||||
c.AddErrors(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
|
|
||||||
return func(rw http.ResponseWriter, r *http.Request, p httprouter.Params) {
|
|
||||||
c := newRequestContext(rw, r, route)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
/*
|
/*
|
||||||
This panic recovery is the last resort. If you want to render
|
This panic recovery is the last resort. If you want to render
|
||||||
|
@ -207,9 +212,17 @@ func handleHmnHandler(route string, h HMNHandler) httprouter.Handle {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
h(c, p)
|
res := h(c)
|
||||||
|
|
||||||
rw.WriteHeader(c.StatusCode)
|
if res.StatusCode == 0 {
|
||||||
io.Copy(rw, c.Body)
|
res.StatusCode = http.StatusOK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for name, vals := range res.Headers() {
|
||||||
|
for _, val := range vals {
|
||||||
|
rw.Header().Add(name, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rw.WriteHeader(res.StatusCode)
|
||||||
|
io.Copy(rw, res.Body)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package website
|
package website
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -33,11 +34,12 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper)
|
mainRoutes := routes.WithWrappers(routes.CommonWebsiteDataWrapper)
|
||||||
mainRoutes.GET("/", func(c *RequestContext, p httprouter.Params) {
|
mainRoutes.GET("/", func(c *RequestContext) ResponseData {
|
||||||
if c.currentProject.ID == models.HMNProjectID {
|
if c.CurrentProject.ID == models.HMNProjectID {
|
||||||
routes.Index(c, p)
|
return routes.Index(c)
|
||||||
} else {
|
} else {
|
||||||
// TODO: Return the project landing page
|
// TODO: Return the project landing page
|
||||||
|
panic("route not implemented")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
mainRoutes.GET("/project/:id", routes.Project)
|
mainRoutes.GET("/project/:id", routes.Project)
|
||||||
|
@ -48,29 +50,29 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
|
||||||
|
|
||||||
routes.ServeFiles("/public/*filepath", http.Dir("public"))
|
routes.ServeFiles("/public/*filepath", http.Dir("public"))
|
||||||
|
|
||||||
routes.HttpRouter.NotFound = MakeStdHandler(mainRoutes.WrapHandler(routes.FourOhFour), "404")
|
routes.HttpRouter.NotFound = mainRoutes.WrapHandler(routes.FourOhFour)
|
||||||
|
|
||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
|
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
|
||||||
var templateUser *templates.User
|
var templateUser *templates.User
|
||||||
if c.currentUser != nil {
|
if c.CurrentUser != nil {
|
||||||
templateUser = &templates.User{
|
templateUser = &templates.User{
|
||||||
Username: c.currentUser.Username,
|
Username: c.CurrentUser.Username,
|
||||||
Email: c.currentUser.Email,
|
Email: c.CurrentUser.Email,
|
||||||
IsSuperuser: c.currentUser.IsSuperuser,
|
IsSuperuser: c.CurrentUser.IsSuperuser,
|
||||||
IsStaff: c.currentUser.IsStaff,
|
IsStaff: c.CurrentUser.IsStaff,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return templates.BaseData{
|
return templates.BaseData{
|
||||||
Project: templates.Project{
|
Project: templates.Project{
|
||||||
Name: *c.currentProject.Name,
|
Name: *c.CurrentProject.Name,
|
||||||
Subdomain: *c.currentProject.Slug,
|
Subdomain: *c.CurrentProject.Slug,
|
||||||
Color: c.currentProject.Color1,
|
Color: c.CurrentProject.Color1,
|
||||||
|
|
||||||
IsHMN: c.currentProject.IsHMN(),
|
IsHMN: c.CurrentProject.IsHMN(),
|
||||||
|
|
||||||
HasBlog: true,
|
HasBlog: true,
|
||||||
HasForum: true,
|
HasForum: true,
|
||||||
|
@ -104,24 +106,25 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
|
||||||
return defaultProject, nil
|
return defaultProject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) Project(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) Project(c *RequestContext) ResponseData {
|
||||||
id := p.ByName("id")
|
id := c.PathParams.ByName("id")
|
||||||
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", p.ByName("id"))
|
row := s.conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
|
||||||
var name string
|
var name string
|
||||||
err := row.Scan(&name)
|
err := row.Scan(&name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Body.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
|
var res ResponseData
|
||||||
|
res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) ProjectCSS(c *RequestContext) ResponseData {
|
||||||
color := c.URL().Query().Get("color")
|
color := c.URL().Query().Get("color")
|
||||||
if color == "" {
|
if color == "" {
|
||||||
c.StatusCode = http.StatusBadRequest
|
return ErrorResponse(http.StatusBadRequest, NewSafeError(nil, "You must provide a 'color' parameter.\n"))
|
||||||
c.Body.Write([]byte("You must provide a 'color' parameter.\n"))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
templateData := struct {
|
templateData := struct {
|
||||||
|
@ -132,27 +135,28 @@ func (s *websiteRoutes) ProjectCSS(c *RequestContext, p httprouter.Params) {
|
||||||
Theme: "dark",
|
Theme: "dark",
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Headers().Add("Content-Type", "text/css")
|
var res ResponseData
|
||||||
err := c.WriteTemplate("project.css", templateData)
|
res.Headers().Add("Content-Type", "text/css")
|
||||||
|
err := res.WriteTemplate("project.css", templateData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error().Err(err).Msg("failed to generate project CSS")
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to generate project CSS"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) Login(c *RequestContext) ResponseData {
|
||||||
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
|
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
|
||||||
|
|
||||||
form, err := c.GetFormValues()
|
form, err := c.GetFormValues()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
|
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
username := form.Get("username")
|
username := form.Get("username")
|
||||||
password := form.Get("password")
|
password := form.Get("password")
|
||||||
if username == "" || password == "" {
|
if username == "" || password == "" {
|
||||||
c.Errored(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
|
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
|
||||||
}
|
}
|
||||||
|
|
||||||
redirect := form.Get("redirect")
|
redirect := form.Get("redirect")
|
||||||
|
@ -163,24 +167,23 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
|
||||||
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
|
userRow, err := db.QueryOne(c.Context(), s.conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNoMatchingRows) {
|
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||||
c.StatusCode = http.StatusUnauthorized
|
return ResponseData{
|
||||||
|
StatusCode: http.StatusUnauthorized,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
user := userRow.(models.User)
|
user := userRow.(*models.User)
|
||||||
|
|
||||||
hashed, err := auth.ParsePasswordString(user.Password)
|
hashed, err := auth.ParsePasswordString(user.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
passwordsMatch, err := auth.CheckPassword(password, hashed)
|
passwordsMatch, err := auth.CheckPassword(password, hashed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if passwordsMatch {
|
if passwordsMatch {
|
||||||
|
@ -200,20 +203,19 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
|
||||||
|
|
||||||
session, err := auth.CreateSession(c.Context(), s.conn, username)
|
session, err := auth.CreateSession(c.Context(), s.conn, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.NewSessionCookie(session))
|
res := c.Redirect(redirect, http.StatusSeeOther)
|
||||||
c.Redirect(redirect, http.StatusSeeOther)
|
res.SetCookie(auth.NewSessionCookie(session))
|
||||||
return
|
|
||||||
|
return res
|
||||||
} else {
|
} else {
|
||||||
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
|
return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) Logout(c *RequestContext) ResponseData {
|
||||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// clear the session from the db immediately, no expiration
|
// clear the session from the db immediately, no expiration
|
||||||
|
@ -223,27 +225,33 @@ func (s *websiteRoutes) Logout(c *RequestContext, p httprouter.Params) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.DeleteSessionCookie)
|
res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
|
||||||
c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
|
res.SetCookie(auth.DeleteSessionCookie)
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) FourOhFour(c *RequestContext, p httprouter.Params) {
|
func (s *websiteRoutes) FourOhFour(c *RequestContext) ResponseData {
|
||||||
c.StatusCode = http.StatusNotFound
|
return ResponseData{
|
||||||
c.Body.Write([]byte("go away\n"))
|
StatusCode: http.StatusNotFound,
|
||||||
|
Body: bytes.NewBuffer([]byte("go away\n")),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
|
func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
|
||||||
return func(c *RequestContext, p httprouter.Params) {
|
return func(c *RequestContext) ResponseData {
|
||||||
h(c, p)
|
res := h(c)
|
||||||
|
|
||||||
for _, err := range c.Errors {
|
for _, err := range res.Errors {
|
||||||
c.Logger.Error().Err(err).Msg("error occurred during request")
|
c.Logger.Error().Err(err).Msg("error occurred during request")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
||||||
return func(c *RequestContext, p httprouter.Params) {
|
return func(c *RequestContext) ResponseData {
|
||||||
// get project
|
// get project
|
||||||
{
|
{
|
||||||
slug := ""
|
slug := ""
|
||||||
|
@ -254,26 +262,24 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
||||||
|
|
||||||
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug)
|
dbProject, err := FetchProjectBySlug(c.Context(), s.conn, slug)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch current project"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.currentProject = dbProject
|
c.CurrentProject = dbProject
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
|
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
|
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.currentUser = user
|
c.CurrentUser = user
|
||||||
}
|
}
|
||||||
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
|
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
|
||||||
|
|
||||||
h(c, p)
|
return h(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue