Get the feed largely complete
This commit is contained in:
parent
ed6ce26697
commit
e7ff342842
|
@ -8570,13 +8570,13 @@ input[type=submit] {
|
|||
.forum-narrow .forum-narrow-hide {
|
||||
display: none; }
|
||||
|
||||
.post-bg-alternate:nth-of-type(odd) {
|
||||
background-color: #f0f0f0;
|
||||
background-color: var(--forum-even-background); }
|
||||
|
||||
.thread {
|
||||
color: black;
|
||||
color: var(--fg-font-color); }
|
||||
.forum .thread:nth-of-type(odd),
|
||||
.feed .thread:nth-of-type(odd) {
|
||||
background-color: #f0f0f0;
|
||||
background-color: var(--forum-even-background); }
|
||||
.profile .thread {
|
||||
padding-left: 15px; }
|
||||
.thread .title {
|
||||
|
|
41
src/db/db.go
41
src/db/db.go
|
@ -212,3 +212,44 @@ func QueryOne(ctx context.Context, conn *pgxpool.Pool, destExample interface{},
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func QueryScalar(ctx context.Context, conn *pgxpool.Pool, query string, args ...interface{}) (interface{}, error) {
|
||||
rows, err := conn.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
vals, err := rows.Values()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if len(vals) != 1 {
|
||||
return nil, oops.New(nil, "you must query exactly one field with QueryScalar, not %v", len(vals))
|
||||
}
|
||||
|
||||
return vals[0], nil
|
||||
}
|
||||
|
||||
return nil, ErrNoMatchingRows
|
||||
}
|
||||
|
||||
func QueryInt(ctx context.Context, conn *pgxpool.Pool, query string, args ...interface{}) (int, error) {
|
||||
result, err := QueryScalar(ctx, conn, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
switch r := result.(type) {
|
||||
case int:
|
||||
return r, nil
|
||||
case int32:
|
||||
return int(r), nil
|
||||
case int64:
|
||||
return int(r), nil
|
||||
default:
|
||||
return 0, oops.New(nil, "QueryInt got a non-int result: %v", result)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
)
|
||||
|
||||
type CategoryType int
|
||||
|
||||
const (
|
||||
|
@ -23,5 +30,50 @@ type Category struct {
|
|||
Kind CategoryType `db:"kind"`
|
||||
Color1 string `db:"color_1"`
|
||||
Color2 string `db:"color_2"`
|
||||
Depth int `db:"depth"`
|
||||
Depth int `db:"depth"` // TODO: What is this?
|
||||
}
|
||||
|
||||
func (c *Category) GetParents(ctx context.Context, conn *pgxpool.Pool) []Category {
|
||||
type breadcrumbRow struct {
|
||||
Cat Category `db:"cats"`
|
||||
}
|
||||
rows, err := db.Query(ctx, conn, breadcrumbRow{},
|
||||
`
|
||||
WITH RECURSIVE cats AS (
|
||||
SELECT *
|
||||
FROM handmade_category AS cat
|
||||
WHERE cat.id = $1
|
||||
UNION ALL
|
||||
SELECT parentcat.*
|
||||
FROM
|
||||
handmade_category AS parentcat
|
||||
JOIN cats ON cats.parent_id = parentcat.id
|
||||
)
|
||||
SELECT $columns FROM cats;
|
||||
`,
|
||||
c.ID,
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var result []Category
|
||||
for _, irow := range rows.ToSlice()[1:] {
|
||||
row := irow.(*breadcrumbRow)
|
||||
result = append(result, row.Cat)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// func GetCategoryUrls(cats ...*Category) map[int]string {
|
||||
|
||||
// }
|
||||
|
||||
// func makeCategoryUrl(cat *Category, subdomain string) string {
|
||||
// switch cat.Kind {
|
||||
// case CatTypeBlog:
|
||||
// case CatTypeForum:
|
||||
// }
|
||||
// return hmnurl.ProjectUrl("/flooger", nil, subdomain)
|
||||
// }
|
||||
|
|
|
@ -12,7 +12,7 @@ type Post struct {
|
|||
AuthorID *int `db:"author_id"`
|
||||
CategoryID int `db:"category_id"`
|
||||
ParentID *int `db:"parent_id"`
|
||||
ThreadID *int `db:"thread_id"`
|
||||
ThreadID *int `db:"thread_id"` // TODO: This is only null for posts that are actually static pages. Which probably shouldn't be posts anyway. Plz make not null thanks
|
||||
CurrentID int `db:"current_id"`
|
||||
|
||||
Depth int `db:"depth"`
|
||||
|
|
|
@ -17,21 +17,15 @@
|
|||
display: none;
|
||||
}
|
||||
|
||||
.post-bg-alternate:nth-of-type(odd) {
|
||||
@include usevar('background-color', 'forum-even-background');
|
||||
}
|
||||
|
||||
.thread {
|
||||
@include usevar('color', 'fg-font-color');
|
||||
|
||||
@extend .ma0;
|
||||
|
||||
.forum &:nth-of-type(odd),
|
||||
.feed &:nth-of-type(odd),
|
||||
{
|
||||
@include usevar('background-color', 'forum-even-background');
|
||||
}
|
||||
|
||||
&.more {
|
||||
// background-color: transparent;
|
||||
}
|
||||
|
||||
.profile & {
|
||||
padding-left:15px;
|
||||
}
|
||||
|
|
|
@ -41,6 +41,8 @@ func ProjectToTemplate(p *models.Project) Project {
|
|||
}
|
||||
|
||||
func UserToTemplate(u *models.User) User {
|
||||
// TODO: Handle deleted users. Maybe not here, but if not, at call sites of this function.
|
||||
|
||||
avatar := ""
|
||||
if u.Avatar != nil {
|
||||
avatar = hmnurl.StaticUrl(*u.Avatar, nil)
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
{{ template "base.html" . }}
|
||||
|
||||
{{ define "extrahead" }}
|
||||
{{/* TODO
|
||||
<script type="text/javascript" src="{% static 'util.js' %}?v={% cachebust %}"></script>
|
||||
*/}}
|
||||
{{ end }}
|
||||
|
||||
{{ define "content" }}
|
||||
<div class="content-block">
|
||||
<div class="optionbar">
|
||||
<div class="options">
|
||||
<a class="button" href="{{ url "/atom" }}"><span class="icon big">4</span> RSS Feed</span></a>
|
||||
{{ if .User }}
|
||||
<a class="button" href="{{ url "/markread" }}"><span class="big">✓</span> Mark all posts on site as read</a>
|
||||
{{ end }}
|
||||
</div>
|
||||
<div class="options">
|
||||
{{ template "pagination.html" .Pagination }}
|
||||
</div>
|
||||
</div>
|
||||
{{ range .Posts }}
|
||||
{{ template "post_list_item.html" . }}
|
||||
{{ end }}
|
||||
<div class="optionbar bottom">
|
||||
<div>
|
||||
</div>
|
||||
<div class="options">
|
||||
{{ template "pagination.html" .Pagination }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{{ end }}
|
|
@ -0,0 +1,31 @@
|
|||
<div class="pagination">
|
||||
{{ if gt .Current 1 }}
|
||||
<a class="button" href="{{ .PreviousUrl }}">Prev</a>
|
||||
{{ end }}
|
||||
{{ if gt .Total 1 }}
|
||||
{{ if gt .Current 1 }}
|
||||
{{ if gt .Current 2 }}
|
||||
<a class="page button" href="{{ .FirstUrl }}">1</a>
|
||||
{{ end }}
|
||||
{{ if gt .Current 3 }}
|
||||
<a class="page button"> ... </a>
|
||||
{{ end }}
|
||||
<a class="page button" href="{{ .PreviousUrl }}">{{ sub .Current 1 }}</a>
|
||||
{{ end }}
|
||||
<a class="page button current">{{ .Current }}</a>
|
||||
{{ if lt .Current .Total }}
|
||||
<a class="page button" href="{{ .NextUrl }}">{{ add .Current 1 }}</a>
|
||||
{{ if lt .Current (sub .Total 2) }}
|
||||
<a class="page button"> ... </a>
|
||||
{{ end }}
|
||||
{{ if lt .Current (sub .Total 1) }}
|
||||
<a class="page button" href="{{ .LastUrl }}">{{ .Total }}</a>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ else }}
|
||||
|
||||
{{ end }}
|
||||
{{ if lt .Current .Total }}
|
||||
<a class="button" href="{{ .NextUrl }}">Next</a>
|
||||
{{ end }}
|
||||
</div>
|
|
@ -4,7 +4,7 @@ This template is intended to display a single post or thread in the context of a
|
|||
It should be called with PostListItem.
|
||||
*/}}
|
||||
|
||||
<div class="post-list-item flex items-center ph3 pv2 {{ if .Unread }}unread{{ else }}read{{ end }}">
|
||||
<div class="post-list-item flex items-center ph3 pv2 {{ if .Unread }}unread{{ else }}read{{ end }} {{ .Classes }}">
|
||||
<img class="avatar-icon mr2" src="{{ .User.AvatarUrl }}">
|
||||
<div class="flex-grow-1 overflow-hidden">
|
||||
<div class="breadcrumbs">
|
||||
|
@ -17,6 +17,11 @@ It should be called with PostListItem.
|
|||
<div class="details">
|
||||
<a class="user" href="{{ .User.ProfileUrl }}">{{ .User.Name }}</a> — <span class="datetime">{{ relativedate .Date }}</span>
|
||||
</div>
|
||||
{{ with .Content }}
|
||||
<div class="mt2">
|
||||
{{ . }}
|
||||
</div>
|
||||
{{ end }}
|
||||
</div>
|
||||
<div class="goto">
|
||||
<a href="{{ .Url }}">»</a>
|
||||
|
|
|
@ -87,8 +87,20 @@ type PostListItem struct {
|
|||
User User
|
||||
Date time.Time
|
||||
Unread bool
|
||||
Classes string
|
||||
Content string
|
||||
}
|
||||
|
||||
type Breadcrumb struct {
|
||||
Name, Url string
|
||||
}
|
||||
|
||||
type Pagination struct {
|
||||
Current int
|
||||
Total int
|
||||
|
||||
FirstUrl string
|
||||
LastUrl string
|
||||
PreviousUrl string
|
||||
NextUrl string
|
||||
}
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
"git.handmade.network/hmn/hmn/src/models"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
"git.handmade.network/hmn/hmn/src/templates"
|
||||
)
|
||||
|
||||
type FeedData struct {
|
||||
templates.BaseData
|
||||
|
||||
Posts []templates.PostListItem
|
||||
Pagination templates.Pagination
|
||||
}
|
||||
|
||||
func Feed(c *RequestContext) ResponseData {
|
||||
const postsPerPage = 30
|
||||
|
||||
numPosts, err := db.QueryInt(c.Context(), c.Conn,
|
||||
`
|
||||
SELECT COUNT(*)
|
||||
FROM
|
||||
handmade_post AS post
|
||||
JOIN handmade_category AS cat ON cat.id = post.category_id
|
||||
WHERE
|
||||
cat.kind IN ($1, $2, $3, $4)
|
||||
AND NOT moderated
|
||||
`,
|
||||
models.CatTypeForum,
|
||||
models.CatTypeBlog,
|
||||
models.CatTypeWiki,
|
||||
models.CatTypeLibraryResource,
|
||||
) // TODO(inarray)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to get count of feed posts"))
|
||||
}
|
||||
|
||||
numPages := int(math.Ceil(float64(numPosts) / 30))
|
||||
|
||||
page := 1
|
||||
pageString := c.PathParams.ByName("page")
|
||||
if pageString != "" {
|
||||
if pageParsed, err := strconv.Atoi(pageString); err == nil {
|
||||
page = pageParsed
|
||||
} else {
|
||||
return c.Redirect("/feed", http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
if page < 1 || numPages < page {
|
||||
return c.Redirect("/feed", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
howManyPostsToSkip := (page - 1) * postsPerPage
|
||||
|
||||
pagination := templates.Pagination{
|
||||
Current: page,
|
||||
Total: numPages,
|
||||
|
||||
// TODO: urls
|
||||
}
|
||||
|
||||
var currentUserId *int
|
||||
if c.CurrentUser != nil {
|
||||
currentUserId = &c.CurrentUser.ID
|
||||
}
|
||||
|
||||
type feedPostQuery struct {
|
||||
Post models.Post `db:"post"`
|
||||
Thread models.Thread `db:"thread"`
|
||||
Cat models.Category `db:"cat"`
|
||||
Proj models.Project `db:"proj"`
|
||||
User models.User `db:"auth_user"`
|
||||
ThreadLastReadTime *time.Time `db:"tlri.lastread"`
|
||||
CatLastReadTime *time.Time `db:"clri.lastread"`
|
||||
}
|
||||
posts, err := db.Query(c.Context(), c.Conn, feedPostQuery{},
|
||||
`
|
||||
SELECT $columns
|
||||
FROM
|
||||
handmade_post AS post
|
||||
JOIN handmade_thread AS thread ON thread.id = post.thread_id
|
||||
JOIN handmade_category AS cat ON cat.id = thread.category_id
|
||||
JOIN handmade_project AS proj ON proj.id = cat.project_id
|
||||
LEFT OUTER JOIN handmade_threadlastreadinfo AS tlri ON (
|
||||
tlri.thread_id = thread.id
|
||||
AND tlri.user_id = $1
|
||||
)
|
||||
LEFT OUTER JOIN handmade_categorylastreadinfo AS clri ON (
|
||||
clri.category_id = cat.id
|
||||
AND clri.user_id = $1
|
||||
)
|
||||
LEFT OUTER JOIN auth_user ON post.author_id = auth_user.id
|
||||
WHERE
|
||||
cat.kind IN ($2, $3, $4, $5)
|
||||
AND post.moderated = FALSE
|
||||
AND post.thread_id IS NOT NULL
|
||||
ORDER BY postdate DESC
|
||||
LIMIT $6 OFFSET $7
|
||||
`,
|
||||
currentUserId,
|
||||
models.CatTypeForum,
|
||||
models.CatTypeBlog,
|
||||
models.CatTypeWiki,
|
||||
models.CatTypeLibraryResource,
|
||||
postsPerPage,
|
||||
howManyPostsToSkip,
|
||||
) // TODO(inarray)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to fetch feed posts"))
|
||||
}
|
||||
|
||||
var postItems []templates.PostListItem
|
||||
for _, iPostResult := range posts.ToSlice() {
|
||||
postResult := iPostResult.(*feedPostQuery)
|
||||
|
||||
hasRead := false
|
||||
if postResult.ThreadLastReadTime != nil && postResult.ThreadLastReadTime.After(postResult.Post.PostDate) {
|
||||
hasRead = true
|
||||
} else if postResult.CatLastReadTime != nil && postResult.CatLastReadTime.After(postResult.Post.PostDate) {
|
||||
hasRead = true
|
||||
}
|
||||
|
||||
parents := postResult.Cat.GetParents(c.Context(), c.Conn)
|
||||
logging.Debug().Interface("parents", parents).Msg("")
|
||||
|
||||
var breadcrumbs []templates.Breadcrumb
|
||||
breadcrumbs = append(breadcrumbs, templates.Breadcrumb{
|
||||
Name: *postResult.Proj.Name,
|
||||
Url: "nargle", // TODO
|
||||
})
|
||||
for i := len(parents) - 1; i >= 0; i-- {
|
||||
breadcrumbs = append(breadcrumbs, templates.Breadcrumb{
|
||||
Name: *parents[i].Name,
|
||||
Url: "nargle", // TODO
|
||||
})
|
||||
}
|
||||
|
||||
postItems = append(postItems, templates.PostListItem{
|
||||
Title: postResult.Thread.Title,
|
||||
Url: templates.PostUrl(postResult.Post, postResult.Cat.Kind, postResult.Proj.Subdomain()),
|
||||
User: templates.UserToTemplate(&postResult.User),
|
||||
Date: postResult.Post.PostDate,
|
||||
Breadcrumbs: breadcrumbs,
|
||||
Unread: !hasRead,
|
||||
Classes: "post-bg-alternate", // TODO: Should this be the default, and the home page can suppress it?
|
||||
Content: postResult.Post.Preview,
|
||||
})
|
||||
}
|
||||
|
||||
baseData := getBaseData(c)
|
||||
baseData.BodyClasses = append(baseData.BodyClasses, "feed")
|
||||
|
||||
var res ResponseData
|
||||
res.WriteTemplate("feed.html", FeedData{
|
||||
BaseData: baseData,
|
||||
|
||||
Posts: postItems,
|
||||
Pagination: pagination,
|
||||
})
|
||||
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package website
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"git.handmade.network/hmn/hmn/src/auth"
|
||||
"git.handmade.network/hmn/hmn/src/db"
|
||||
"git.handmade.network/hmn/hmn/src/logging"
|
||||
"git.handmade.network/hmn/hmn/src/models"
|
||||
"git.handmade.network/hmn/hmn/src/oops"
|
||||
)
|
||||
|
||||
func Login(c *RequestContext) ResponseData {
|
||||
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
|
||||
|
||||
form, err := c.GetFormValues()
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
|
||||
}
|
||||
|
||||
username := form.Get("username")
|
||||
password := form.Get("password")
|
||||
if username == "" || password == "" {
|
||||
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
|
||||
}
|
||||
|
||||
redirect := form.Get("redirect")
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||
return ResponseData{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
}
|
||||
} else {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||
}
|
||||
}
|
||||
user := userRow.(*models.User)
|
||||
|
||||
hashed, err := auth.ParsePasswordString(user.Password)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
|
||||
}
|
||||
|
||||
passwordsMatch, err := auth.CheckPassword(password, hashed)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
|
||||
}
|
||||
|
||||
if passwordsMatch {
|
||||
// re-hash and save the user's password if necessary
|
||||
if hashed.IsOutdated() {
|
||||
newHashed, err := auth.HashPassword(password)
|
||||
if err == nil {
|
||||
err := auth.UpdatePassword(c.Context(), c.Conn, username, newHashed)
|
||||
if err != nil {
|
||||
c.Logger.Error().Err(err).Msg("failed to update user's password")
|
||||
}
|
||||
} else {
|
||||
c.Logger.Error().Err(err).Msg("failed to re-hash password")
|
||||
}
|
||||
// If errors happen here, we can still continue with logging them in
|
||||
}
|
||||
|
||||
session, err := auth.CreateSession(c.Context(), c.Conn, username)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
|
||||
}
|
||||
|
||||
res := c.Redirect(redirect, http.StatusSeeOther)
|
||||
res.SetCookie(auth.NewSessionCookie(session))
|
||||
|
||||
return res
|
||||
} else {
|
||||
return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
|
||||
}
|
||||
}
|
||||
|
||||
func Logout(c *RequestContext) ResponseData {
|
||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||
if err == nil {
|
||||
// clear the session from the db immediately, no expiration
|
||||
err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value)
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Msg("failed to delete session on logout")
|
||||
}
|
||||
}
|
||||
|
||||
res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
|
||||
res.SetCookie(auth.DeleteSessionCookie)
|
||||
|
||||
return res
|
||||
}
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -49,7 +48,9 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
|
|||
panic("route not implemented")
|
||||
}
|
||||
})
|
||||
mainRoutes.GET("/project/:id", Project)
|
||||
mainRoutes.GET("/feed", Feed)
|
||||
mainRoutes.GET("/feed/:page", Feed)
|
||||
|
||||
mainRoutes.GET("/assets/project.css", ProjectCSS)
|
||||
|
||||
router.NotFound = mainRoutes.ChainHandlers(FourOhFour)
|
||||
|
@ -119,21 +120,6 @@ func FetchProjectBySlug(ctx context.Context, conn *pgxpool.Pool, slug string) (*
|
|||
return defaultProject, nil
|
||||
}
|
||||
|
||||
func Project(c *RequestContext) ResponseData {
|
||||
id := c.PathParams.ByName("id")
|
||||
row := c.Conn.QueryRow(context.Background(), "SELECT name FROM handmade_project WHERE id = $1", c.PathParams.ByName("id"))
|
||||
var name string
|
||||
err := row.Scan(&name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var res ResponseData
|
||||
res.Write([]byte(fmt.Sprintf("(%s) %s\n", id, name)))
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func ProjectCSS(c *RequestContext) ResponseData {
|
||||
color := c.URL().Query().Get("color")
|
||||
if color == "" {
|
||||
|
@ -158,92 +144,6 @@ func ProjectCSS(c *RequestContext) ResponseData {
|
|||
return res
|
||||
}
|
||||
|
||||
func Login(c *RequestContext) ResponseData {
|
||||
// TODO: Update this endpoint to give uniform responses on errors and to be resilient to timing attacks.
|
||||
|
||||
form, err := c.GetFormValues()
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "request must contain form data"))
|
||||
}
|
||||
|
||||
username := form.Get("username")
|
||||
password := form.Get("password")
|
||||
if username == "" || password == "" {
|
||||
return ErrorResponse(http.StatusBadRequest, NewSafeError(err, "you must provide both a username and password"))
|
||||
}
|
||||
|
||||
redirect := form.Get("redirect")
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
userRow, err := db.QueryOne(c.Context(), c.Conn, models.User{}, "SELECT $columns FROM auth_user WHERE username = $1", username)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||
return ResponseData{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
}
|
||||
} else {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to look up user by username"))
|
||||
}
|
||||
}
|
||||
user := userRow.(*models.User)
|
||||
|
||||
hashed, err := auth.ParsePasswordString(user.Password)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to parse password string"))
|
||||
}
|
||||
|
||||
passwordsMatch, err := auth.CheckPassword(password, hashed)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to check password against hash"))
|
||||
}
|
||||
|
||||
if passwordsMatch {
|
||||
// re-hash and save the user's password if necessary
|
||||
if hashed.IsOutdated() {
|
||||
newHashed, err := auth.HashPassword(password)
|
||||
if err == nil {
|
||||
err := auth.UpdatePassword(c.Context(), c.Conn, username, newHashed)
|
||||
if err != nil {
|
||||
c.Logger.Error().Err(err).Msg("failed to update user's password")
|
||||
}
|
||||
} else {
|
||||
c.Logger.Error().Err(err).Msg("failed to re-hash password")
|
||||
}
|
||||
// If errors happen here, we can still continue with logging them in
|
||||
}
|
||||
|
||||
session, err := auth.CreateSession(c.Context(), c.Conn, username)
|
||||
if err != nil {
|
||||
return ErrorResponse(http.StatusInternalServerError, oops.New(err, "failed to create session"))
|
||||
}
|
||||
|
||||
res := c.Redirect(redirect, http.StatusSeeOther)
|
||||
res.SetCookie(auth.NewSessionCookie(session))
|
||||
|
||||
return res
|
||||
} else {
|
||||
return c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to standalone login page with error
|
||||
}
|
||||
}
|
||||
|
||||
func Logout(c *RequestContext) ResponseData {
|
||||
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||
if err == nil {
|
||||
// clear the session from the db immediately, no expiration
|
||||
err := auth.DeleteSession(c.Context(), c.Conn, sessionCookie.Value)
|
||||
if err != nil {
|
||||
logging.Error().Err(err).Msg("failed to delete session on logout")
|
||||
}
|
||||
}
|
||||
|
||||
res := c.Redirect("/", http.StatusSeeOther) // TODO: Redirect to the page the user was currently on, or if not authorized to view that page, immediately to the home page.
|
||||
res.SetCookie(auth.DeleteSessionCookie)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func FourOhFour(c *RequestContext) ResponseData {
|
||||
return ResponseData{
|
||||
StatusCode: http.StatusNotFound,
|
||||
|
|
Loading…
Reference in New Issue