Add migration system and initial migration
This commit is contained in:
parent
35291d609f
commit
c69bf04785
|
@ -1,25 +0,0 @@
|
||||||
package migration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
registerMigration(CreateEverything{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type CreateEverything struct{}
|
|
||||||
|
|
||||||
func (m CreateEverything) Date() time.Time {
|
|
||||||
return time.Date(2021, 3, 9, 6, 53, 0, 0, time.UTC)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m CreateEverything) Up(conn *pgx.Conn) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m CreateEverything) Down(conn *pgx.Conn) {
|
|
||||||
|
|
||||||
}
|
|
|
@ -2,45 +2,257 @@ package migration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.handmade.network/hmn/hmn/db"
|
"git.handmade.network/hmn/hmn/db"
|
||||||
|
"git.handmade.network/hmn/hmn/migration/migrations"
|
||||||
|
"git.handmade.network/hmn/hmn/migration/types"
|
||||||
"git.handmade.network/hmn/hmn/website"
|
"git.handmade.network/hmn/hmn/website"
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
var migrations map[time.Time]Migration = make(map[time.Time]Migration)
|
var listMigrations bool
|
||||||
|
|
||||||
type Migration interface {
|
|
||||||
Date() time.Time
|
|
||||||
Up(conn *pgx.Conn)
|
|
||||||
Down(conn *pgx.Conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func registerMigration(m Migration) {
|
|
||||||
migrations[m.Date()] = m
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
migrateCommand := &cobra.Command{
|
migrateCommand := &cobra.Command{
|
||||||
Use: "migrate",
|
Use: "migrate [target migration id]",
|
||||||
Short: "Run database migrations",
|
Short: "Run database migrations",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
Migrate()
|
if listMigrations {
|
||||||
|
ListMigrations()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetVersion := time.Time{}
|
||||||
|
if len(args) > 0 {
|
||||||
|
var err error
|
||||||
|
targetVersion, err = time.Parse(time.RFC3339, args[0])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("ERROR: bad version string: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Migrate(types.MigrationVersion(targetVersion))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
migrateCommand.Flags().BoolVar(&listMigrations, "list", false, "List available migrations")
|
||||||
|
|
||||||
|
makeMigrationCommand := &cobra.Command{
|
||||||
|
Use: "makemigration <name> <description>...",
|
||||||
|
Short: "Create a new database migration file",
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
if len(args) < 2 {
|
||||||
|
fmt.Println("You must provide a name and a description.")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
name := args[0]
|
||||||
|
description := strings.Join(args[1:], " ")
|
||||||
|
|
||||||
|
MakeMigration(name, description)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
website.WebsiteCommand.AddCommand(migrateCommand)
|
website.WebsiteCommand.AddCommand(migrateCommand)
|
||||||
|
website.WebsiteCommand.AddCommand(makeMigrationCommand)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Migrate() {
|
func getSortedMigrationVersions() []types.MigrationVersion {
|
||||||
|
var allVersions []types.MigrationVersion
|
||||||
|
for migrationTime, _ := range migrations.All {
|
||||||
|
allVersions = append(allVersions, migrationTime)
|
||||||
|
}
|
||||||
|
sort.Slice(allVersions, func(i, j int) bool {
|
||||||
|
return allVersions[i].Before(allVersions[j])
|
||||||
|
})
|
||||||
|
|
||||||
|
return allVersions
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCurrentVersion(conn *pgx.Conn) (types.MigrationVersion, error) {
|
||||||
|
var currentVersion time.Time
|
||||||
|
row := conn.QueryRow(context.Background(), "SELECT version FROM hmn_migration")
|
||||||
|
err := row.Scan(¤tVersion)
|
||||||
|
if err != nil {
|
||||||
|
return types.MigrationVersion{}, err
|
||||||
|
}
|
||||||
|
currentVersion = currentVersion.UTC()
|
||||||
|
|
||||||
|
return types.MigrationVersion(currentVersion), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListMigrations() {
|
||||||
conn := db.NewConn()
|
conn := db.NewConn()
|
||||||
defer conn.Close(context.Background())
|
defer conn.Close(context.Background())
|
||||||
|
|
||||||
// check for existence of database??
|
currentVersion, _ := getCurrentVersion(conn)
|
||||||
|
for _, version := range getSortedMigrationVersions() {
|
||||||
|
migration := migrations.All[version]
|
||||||
|
indicator := " "
|
||||||
|
if version.Equal(currentVersion) {
|
||||||
|
indicator = "✔ "
|
||||||
|
}
|
||||||
|
fmt.Printf("%s%v (%s: %s)\n", indicator, version, migration.Name(), migration.Description())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// check for migration data, create it if missing
|
func Migrate(targetVersion types.MigrationVersion) {
|
||||||
|
conn := db.NewConn()
|
||||||
|
defer conn.Close(context.Background())
|
||||||
|
|
||||||
|
// create migration table
|
||||||
|
_, err := conn.Exec(context.Background(), `
|
||||||
|
CREATE TABLE IF NOT EXISTS hmn_migration (
|
||||||
|
version TIMESTAMP WITH TIME ZONE
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to create migration table: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure there is a row
|
||||||
|
row := conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM hmn_migration")
|
||||||
|
var numRows int
|
||||||
|
err = row.Scan(&numRows)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if numRows < 1 {
|
||||||
|
_, err := conn.Exec(context.Background(), "INSERT INTO hmn_migration (version) VALUES ($1)", time.Time{})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to insert initial migration row: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// run migrations
|
// run migrations
|
||||||
|
currentVersion, err := getCurrentVersion(conn)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to get current version: %w", err))
|
||||||
|
}
|
||||||
|
if currentVersion.IsZero() {
|
||||||
|
fmt.Println("This is the first time you have run database migrations.")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Current version: %s\n", currentVersion.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
allVersions := getSortedMigrationVersions()
|
||||||
|
if targetVersion.IsZero() {
|
||||||
|
targetVersion = allVersions[len(allVersions)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
currentIndex := -1
|
||||||
|
targetIndex := -1
|
||||||
|
for i, version := range allVersions {
|
||||||
|
if currentVersion.Equal(version) {
|
||||||
|
currentIndex = i
|
||||||
|
}
|
||||||
|
if targetVersion.Equal(version) {
|
||||||
|
targetIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetIndex < 0 {
|
||||||
|
fmt.Printf("ERROR: Could not find migration with version %v\n", targetVersion)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentIndex < targetIndex {
|
||||||
|
// roll forward
|
||||||
|
for i := currentIndex + 1; i <= targetIndex; i++ {
|
||||||
|
version := allVersions[i]
|
||||||
|
fmt.Printf("Applying migration %v\n", version)
|
||||||
|
migration := migrations.All[version]
|
||||||
|
|
||||||
|
tx, err := conn.Begin(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to start transaction: %w", err))
|
||||||
|
}
|
||||||
|
defer tx.Rollback(context.Background())
|
||||||
|
|
||||||
|
err = migration.Up(tx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("MIGRATION FAILED for migration %v.\n", version)
|
||||||
|
fmt.Printf("Error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(context.Background(), "UPDATE hmn_migration SET version = $1", version)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to update version in migrations table: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to commit transaction: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if currentIndex > targetIndex {
|
||||||
|
// roll back
|
||||||
|
for i := currentIndex; i > targetIndex; i-- {
|
||||||
|
version := allVersions[i]
|
||||||
|
previousVersion := types.MigrationVersion{}
|
||||||
|
if i > 0 {
|
||||||
|
previousVersion = allVersions[i-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := conn.Begin(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to start transaction: %w", err))
|
||||||
|
}
|
||||||
|
defer tx.Rollback(context.Background())
|
||||||
|
|
||||||
|
fmt.Printf("Rolling back migration %v\n", version)
|
||||||
|
migration := migrations.All[version]
|
||||||
|
err = migration.Down(tx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("MIGRATION FAILED for migration %v.\n", version)
|
||||||
|
fmt.Printf("Error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(context.Background(), "UPDATE hmn_migration SET version = $1", previousVersion)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to update version in migrations table: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to commit transaction: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("Already migrated; nothing to do.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed migrationTemplate.txt
|
||||||
|
var migrationTemplate string
|
||||||
|
|
||||||
|
func MakeMigration(name, description string) {
|
||||||
|
result := migrationTemplate
|
||||||
|
result = strings.ReplaceAll(result, "%NAME%", name)
|
||||||
|
result = strings.ReplaceAll(result, "%DESCRIPTION%", fmt.Sprintf("%#v", description))
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
nowConstructor := fmt.Sprintf("time.Date(%d, %d, %d, %d, %d, %d, 0, time.UTC)", now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second())
|
||||||
|
result = strings.ReplaceAll(result, "%DATE%", nowConstructor)
|
||||||
|
|
||||||
|
safeVersion := strings.ReplaceAll(types.MigrationVersion(now).String(), ":", "")
|
||||||
|
filename := fmt.Sprintf("%v_%v.go", safeVersion, name)
|
||||||
|
path := filepath.Join("migration", "migrations", filename)
|
||||||
|
|
||||||
|
err := os.WriteFile(path, []byte(result), 0644)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to write migration file: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Successfully created migration file:")
|
||||||
|
fmt.Println(path)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.handmade.network/hmn/hmn/migration/types"
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registerMigration(%NAME%{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type %NAME% struct{}
|
||||||
|
|
||||||
|
func (m %NAME%) Version() types.MigrationVersion {
|
||||||
|
return types.MigrationVersion(%DATE%)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m %NAME%) Name() string {
|
||||||
|
return "%NAME%"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m %NAME%) Description() string {
|
||||||
|
return %DESCRIPTION%
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m %NAME%) Up(tx pgx.Tx) error {
|
||||||
|
panic("Implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m %NAME%) Down(tx pgx.Tx) error {
|
||||||
|
panic("Implement me")
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,9 @@
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import "git.handmade.network/hmn/hmn/migration/types"
|
||||||
|
|
||||||
|
var All map[types.MigrationVersion]types.Migration = make(map[types.MigrationVersion]types.Migration)
|
||||||
|
|
||||||
|
func registerMigration(m types.Migration) {
|
||||||
|
All[m.Version()] = m
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Migration interface {
|
||||||
|
Version() MigrationVersion
|
||||||
|
Name() string
|
||||||
|
Description() string
|
||||||
|
Up(conn pgx.Tx) error
|
||||||
|
Down(conn pgx.Tx) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type MigrationVersion time.Time
|
||||||
|
|
||||||
|
func (v MigrationVersion) String() string {
|
||||||
|
return time.Time(v).Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v MigrationVersion) Before(other MigrationVersion) bool {
|
||||||
|
return time.Time(v).Before(time.Time(other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v MigrationVersion) Equal(other MigrationVersion) bool {
|
||||||
|
return time.Time(v).Equal(time.Time(other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v MigrationVersion) IsZero() bool {
|
||||||
|
return time.Time(v).IsZero()
|
||||||
|
}
|
Loading…
Reference in New Issue