Get login working
This commit is contained in:
parent
56cd737203
commit
cdfe02726c
|
@ -0,0 +1,79 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.handmade.network/hmn/hmn/src/config"
|
||||||
|
"git.handmade.network/hmn/hmn/src/db"
|
||||||
|
"git.handmade.network/hmn/hmn/src/models"
|
||||||
|
"git.handmade.network/hmn/hmn/src/oops"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
const SessionCookieName = "HMNSession"
|
||||||
|
|
||||||
|
const sessionDuration = time.Hour * 24 * 14
|
||||||
|
|
||||||
|
func makeSessionId() string {
|
||||||
|
idBytes := make([]byte, 40)
|
||||||
|
_, err := io.ReadFull(rand.Reader, idBytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(idBytes)[:40]
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNoSession = errors.New("no session found")
|
||||||
|
|
||||||
|
func GetSession(ctx context.Context, conn *pgxpool.Pool, id string) (*models.Session, error) {
|
||||||
|
var sess models.Session
|
||||||
|
err := db.QueryOneToStruct(ctx, conn, &sess, "SELECT $columns FROM sessions WHERE id = $1", id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||||
|
return nil, ErrNoSession
|
||||||
|
} else {
|
||||||
|
return nil, oops.New(err, "failed to get session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateSession(ctx context.Context, conn *pgxpool.Pool, username string) (*models.Session, error) {
|
||||||
|
session := models.Session{
|
||||||
|
ID: makeSessionId(),
|
||||||
|
Username: username,
|
||||||
|
ExpiresAt: time.Now().Add(sessionDuration),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := conn.Exec(ctx,
|
||||||
|
"INSERT INTO sessions (id, username, expires_at) VALUES ($1, $2, $3)",
|
||||||
|
session.ID, session.Username, session.ExpiresAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oops.New(err, "failed to persist session")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionCookie(session *models.Session) *http.Cookie {
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: SessionCookieName,
|
||||||
|
Value: session.ID,
|
||||||
|
|
||||||
|
Domain: config.Config.Auth.CookieDomain,
|
||||||
|
Expires: time.Now().Add(sessionDuration),
|
||||||
|
|
||||||
|
Secure: config.Config.Auth.CookieSecure,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteDefaultMode,
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,49 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/src/config"
|
|
||||||
"git.handmade.network/hmn/hmn/src/oops"
|
|
||||||
)
|
|
||||||
|
|
||||||
const AuthCookieName = "HMNToken"
|
|
||||||
|
|
||||||
type Token struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: ENCRYPT THIS
|
|
||||||
|
|
||||||
func EncodeToken(token Token) string {
|
|
||||||
tokenBytes, _ := json.Marshal(token)
|
|
||||||
return string(tokenBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeToken(tokenStr string) (Token, error) {
|
|
||||||
var token Token
|
|
||||||
err := json.Unmarshal([]byte(tokenStr), &token)
|
|
||||||
if err != nil {
|
|
||||||
// TODO: Is this worthy of an oops error, or should this just be a value handled silently by code?
|
|
||||||
return Token{}, oops.New(err, "failed to unmarshal token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewAuthCookie(username string) *http.Cookie {
|
|
||||||
return &http.Cookie{
|
|
||||||
Name: AuthCookieName,
|
|
||||||
Value: EncodeToken(Token{
|
|
||||||
Username: username,
|
|
||||||
}),
|
|
||||||
|
|
||||||
Domain: config.Config.CookieDomain,
|
|
||||||
// TODO: Path?
|
|
||||||
|
|
||||||
// Secure: true,
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteDefaultMode,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -19,8 +19,7 @@ type HMNConfig struct {
|
||||||
Addr string
|
Addr string
|
||||||
BaseUrl string
|
BaseUrl string
|
||||||
Postgres PostgresConfig
|
Postgres PostgresConfig
|
||||||
CookieDomain string
|
Auth AuthConfig
|
||||||
TokenSecret string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PostgresConfig struct {
|
type PostgresConfig struct {
|
||||||
|
@ -32,6 +31,11 @@ type PostgresConfig struct {
|
||||||
LogLevel pgx.LogLevel
|
LogLevel pgx.LogLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthConfig struct {
|
||||||
|
CookieDomain string
|
||||||
|
CookieSecure bool
|
||||||
|
}
|
||||||
|
|
||||||
func (info PostgresConfig) DSN() string {
|
func (info PostgresConfig) DSN() string {
|
||||||
return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName)
|
return fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s", info.User, info.Password, info.Hostname, info.Port, info.DbName)
|
||||||
}
|
}
|
||||||
|
|
|
@ -246,7 +246,7 @@ func MakeMigration(name, description string) {
|
||||||
|
|
||||||
safeVersion := strings.ReplaceAll(types.MigrationVersion(now).String(), ":", "")
|
safeVersion := strings.ReplaceAll(types.MigrationVersion(now).String(), ":", "")
|
||||||
filename := fmt.Sprintf("%v_%v.go", safeVersion, name)
|
filename := fmt.Sprintf("%v_%v.go", safeVersion, name)
|
||||||
path := filepath.Join("migration", "migrations", filename)
|
path := filepath.Join("src", "migration", "migrations", filename)
|
||||||
|
|
||||||
err := os.WriteFile(path, []byte(result), 0644)
|
err := os.WriteFile(path, []byte(result), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.handmade.network/hmn/hmn/src/migration/types"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registerMigration(AddSessionTable{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type AddSessionTable struct{}
|
||||||
|
|
||||||
|
func (m AddSessionTable) Version() types.MigrationVersion {
|
||||||
|
return types.MigrationVersion(time.Date(2021, 3, 26, 3, 38, 34, 0, time.UTC))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m AddSessionTable) Name() string {
|
||||||
|
return "AddSessionTable"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m AddSessionTable) Description() string {
|
||||||
|
return "Adds a session table to replace the Django session table"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m AddSessionTable) Up(tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(context.Background(), `
|
||||||
|
CREATE TABLE sessions (
|
||||||
|
id VARCHAR(40) PRIMARY KEY,
|
||||||
|
username VARCHAR(150) NOT NULL,
|
||||||
|
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m AddSessionTable) Down(tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(context.Background(), `
|
||||||
|
DROP TABLE sessions;
|
||||||
|
`)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
Username string `db:"username"`
|
||||||
|
ExpiresAt time.Time `db:"expires_at"`
|
||||||
|
}
|
|
@ -67,6 +67,8 @@ type RequestContext struct {
|
||||||
rw http.ResponseWriter
|
rw http.ResponseWriter
|
||||||
|
|
||||||
currentProject *models.Project
|
currentProject *models.Project
|
||||||
|
currentUser *models.User
|
||||||
|
// currentMember *models.Member
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext {
|
func newRequestContext(rw http.ResponseWriter, req *http.Request, route string) *RequestContext {
|
||||||
|
|
|
@ -45,6 +45,16 @@ func NewWebsiteRoutes(conn *pgxpool.Pool) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
|
func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
|
||||||
|
var templateUser *templates.User
|
||||||
|
if c.currentUser != nil {
|
||||||
|
templateUser = &templates.User{
|
||||||
|
Username: c.currentUser.Username,
|
||||||
|
Email: c.currentUser.Email,
|
||||||
|
IsSuperuser: c.currentUser.IsSuperuser,
|
||||||
|
IsStaff: c.currentUser.IsStaff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return templates.BaseData{
|
return templates.BaseData{
|
||||||
Project: templates.Project{
|
Project: templates.Project{
|
||||||
Name: c.currentProject.Name,
|
Name: c.currentProject.Name,
|
||||||
|
@ -58,6 +68,7 @@ func (s *websiteRoutes) getBaseData(c *RequestContext) templates.BaseData {
|
||||||
HasWiki: true,
|
HasWiki: true,
|
||||||
HasLibrary: true,
|
HasLibrary: true,
|
||||||
},
|
},
|
||||||
|
User: templateUser,
|
||||||
Theme: "dark",
|
Theme: "dark",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -171,8 +182,13 @@ func (s *websiteRoutes) Login(c *RequestContext, p httprouter.Params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if passwordsMatch {
|
if passwordsMatch {
|
||||||
logging.Debug().Str("cookie", auth.NewAuthCookie(username).String()).Msg("logged in")
|
session, err := auth.CreateSession(c.Context(), s.conn, username)
|
||||||
c.SetCookie(auth.NewAuthCookie(username))
|
if err != nil {
|
||||||
|
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to create session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie(auth.NewSessionCookie(session))
|
||||||
c.Redirect(redirect, http.StatusSeeOther)
|
c.Redirect(redirect, http.StatusSeeOther)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
|
@ -193,6 +209,8 @@ func ErrorLoggingWrapper(h HMNHandler) HMNHandler {
|
||||||
|
|
||||||
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
||||||
return func(c *RequestContext, p httprouter.Params) {
|
return func(c *RequestContext, p httprouter.Params) {
|
||||||
|
// get project
|
||||||
|
{
|
||||||
slug := ""
|
slug := ""
|
||||||
hostParts := strings.SplitN(c.Req.Host, ".", 3)
|
hostParts := strings.SplitN(c.Req.Host, ".", 3)
|
||||||
if len(hostParts) >= 3 {
|
if len(hostParts) >= 3 {
|
||||||
|
@ -206,7 +224,50 @@ func (s *websiteRoutes) CommonWebsiteDataWrapper(h HMNHandler) HMNHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.currentProject = dbProject
|
c.currentProject = dbProject
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionCookie, err := c.Req.Cookie(auth.SessionCookieName)
|
||||||
|
if err == nil {
|
||||||
|
user, err := s.getCurrentUserAndMember(c.Context(), sessionCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
c.Errored(http.StatusInternalServerError, oops.New(err, "failed to get current user and member"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.currentUser = user
|
||||||
|
}
|
||||||
|
// http.ErrNoCookie is the only error Cookie ever returns, so no further handling to do here.
|
||||||
|
|
||||||
h(c, p)
|
h(c, p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Given a session id, fetches user and member data from the database. Will return nil for
|
||||||
|
// both if neither can be found, and will only return an error if it's serious.
|
||||||
|
//
|
||||||
|
// TODO: actually return members :)
|
||||||
|
func (s *websiteRoutes) getCurrentUserAndMember(ctx context.Context, sessionId string) (*models.User, error) {
|
||||||
|
session, err := auth.GetSession(ctx, s.conn, sessionId)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, auth.ErrNoSession) {
|
||||||
|
return nil, nil
|
||||||
|
} else {
|
||||||
|
return nil, oops.New(err, "failed to get current session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var user models.User
|
||||||
|
err = db.QueryOneToStruct(ctx, s.conn, &user, "SELECT $columns FROM auth_user WHERE username = $1", session.Username)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrNoMatchingRows) {
|
||||||
|
logging.Debug().Str("username", session.Username).Msg("returning no current user for this request because the user for the session couldn't be found")
|
||||||
|
return nil, nil // user was deleted or something
|
||||||
|
} else {
|
||||||
|
return nil, oops.New(err, "failed to get user for session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Also get the member model
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue