package migration

import (
	"context"
	_ "embed"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"git.handmade.network/hmn/hmn/db"
	"git.handmade.network/hmn/hmn/migration/migrations"
	"git.handmade.network/hmn/hmn/migration/types"
	"git.handmade.network/hmn/hmn/website"
	"github.com/jackc/pgx/v4"
	"github.com/spf13/cobra"
)

// listMigrations is bound to the `migrate --list` flag.
var listMigrations bool

// init registers the migrate and makemigration subcommands on the website
// root command.
func init() {
	migrateCommand := &cobra.Command{
		Use:   "migrate [target migration id]",
		Short: "Run database migrations",
		Run: func(cmd *cobra.Command, args []string) {
			if listMigrations {
				ListMigrations()
				return
			}

			// A zero target tells Migrate to go to the newest known migration.
			targetVersion := time.Time{}
			if len(args) > 0 {
				var err error
				targetVersion, err = time.Parse(time.RFC3339, args[0])
				if err != nil {
					fmt.Printf("ERROR: bad version string: %v\n", err)
					os.Exit(1)
				}
			}
			Migrate(types.MigrationVersion(targetVersion))
		},
	}
	migrateCommand.Flags().BoolVar(&listMigrations, "list", false, "List available migrations")

	makeMigrationCommand := &cobra.Command{
		Use:   "makemigration ...",
		Short: "Create a new database migration file",
		Run: func(cmd *cobra.Command, args []string) {
			if len(args) < 2 {
				fmt.Println("You must provide a name and a description.")
				os.Exit(1)
			}

			name := args[0]
			description := strings.Join(args[1:], " ")

			MakeMigration(name, description)
		},
	}

	website.WebsiteCommand.AddCommand(migrateCommand)
	website.WebsiteCommand.AddCommand(makeMigrationCommand)
}

// getSortedMigrationVersions returns the keys of migrations.All in ascending
// (oldest-first) order.
func getSortedMigrationVersions() []types.MigrationVersion {
	allVersions := make([]types.MigrationVersion, 0, len(migrations.All))
	for version := range migrations.All {
		allVersions = append(allVersions, version)
	}

	sort.Slice(allVersions, func(i, j int) bool {
		return allVersions[i].Before(allVersions[j])
	})

	return allVersions
}

// getCurrentVersion reads the version recorded in the hmn_migration table and
// normalizes it to UTC so it compares cleanly against migration versions.
// It returns an error if the query or scan fails (for example, when the table
// does not exist or holds no row).
func getCurrentVersion(conn *pgx.Conn) (types.MigrationVersion, error) {
	var currentVersion time.Time
	row := conn.QueryRow(context.Background(), "SELECT version FROM hmn_migration")
	err := row.Scan(&currentVersion)
	if err != nil {
		return types.MigrationVersion{}, err
	}

	currentVersion = currentVersion.UTC()

	return types.MigrationVersion(currentVersion), nil
}

// ListMigrations prints every known migration, oldest first, marking the one
// recorded as the database's current version with a check mark.
func ListMigrations() {
	ctx := context.Background()
	conn := db.NewConn()
	defer conn.Close(ctx)

	// Best-effort: if the migration table doesn't exist yet, the zero version
	// simply means no migration gets the "current" marker.
	currentVersion, _ := getCurrentVersion(conn)
	for _, version := range getSortedMigrationVersions() {
		migration := migrations.All[version]
		indicator := "  "
		if version.Equal(currentVersion) {
			indicator = "✔ "
		}
		fmt.Printf("%s%v (%s: %s)\n", indicator, version, migration.Name(), migration.Description())
	}
}

// Migrate moves the database schema to targetVersion. It applies Up
// migrations in ascending order, or Down migrations in descending order, as
// needed. A zero targetVersion means the newest known migration.
//
// Each migration runs in its own transaction together with the hmn_migration
// version update, so a failed migration leaves the database at the last
// successfully applied version. A failing migration, an unknown target, or an
// unrecognized current version is reported on stdout. Database infrastructure
// failures panic.
func Migrate(targetVersion types.MigrationVersion) {
	ctx := context.Background()
	conn := db.NewConn()
	defer conn.Close(ctx)

	ensureMigrationTable(ctx, conn)

	currentVersion, err := getCurrentVersion(conn)
	if err != nil {
		panic(fmt.Errorf("failed to get current version: %w", err))
	}

	if currentVersion.IsZero() {
		fmt.Println("This is the first time you have run database migrations.")
	} else {
		fmt.Printf("Current version: %s\n", currentVersion.String())
	}

	allVersions := getSortedMigrationVersions()
	if len(allVersions) == 0 {
		fmt.Println("No migrations are defined; nothing to do.")
		return
	}
	if targetVersion.IsZero() {
		targetVersion = allVersions[len(allVersions)-1]
	}

	currentIndex := -1
	targetIndex := -1
	for i, version := range allVersions {
		if currentVersion.Equal(version) {
			currentIndex = i
		}
		if targetVersion.Equal(version) {
			targetIndex = i
		}
	}

	if targetIndex < 0 {
		fmt.Printf("ERROR: Could not find migration with version %v\n", targetVersion)
		return
	}
	// Only the zero version legitimately maps to "nothing applied" (index -1).
	// A non-zero version we don't recognize (e.g. a migration that was since
	// removed or renamed) must not fall through, or every migration would be
	// re-applied from the beginning.
	if currentIndex < 0 && !currentVersion.IsZero() {
		fmt.Printf("ERROR: Current database version %v does not match any known migration\n", currentVersion)
		return
	}

	switch {
	case currentIndex < targetIndex:
		// roll forward
		for i := currentIndex + 1; i <= targetIndex; i++ {
			version := allVersions[i]
			fmt.Printf("Applying migration %v\n", version)
			migration := migrations.All[version]
			err := runMigrationStep(ctx, conn, version, func(tx pgx.Tx) error {
				return migration.Up(tx)
			})
			if err != nil {
				fmt.Printf("MIGRATION FAILED for migration %v.\n", version)
				fmt.Printf("Error: %v\n", err)
				return
			}
		}
	case currentIndex > targetIndex:
		// roll back
		for i := currentIndex; i > targetIndex; i-- {
			version := allVersions[i]
			// Once allVersions[i] is undone, the database sits at the
			// preceding migration, or at the zero version if there is none.
			previousVersion := types.MigrationVersion{}
			if i > 0 {
				previousVersion = allVersions[i-1]
			}

			fmt.Printf("Rolling back migration %v\n", version)
			migration := migrations.All[version]
			err := runMigrationStep(ctx, conn, previousVersion, func(tx pgx.Tx) error {
				return migration.Down(tx)
			})
			if err != nil {
				fmt.Printf("MIGRATION FAILED for migration %v.\n", version)
				fmt.Printf("Error: %v\n", err)
				return
			}
		}
	default:
		fmt.Println("Already migrated; nothing to do.")
	}
}

// ensureMigrationTable creates the hmn_migration table if it is missing and
// makes sure it holds a row. On first run it seeds that row with the zero
// time, which Migrate treats as "no migrations applied". Failures panic.
func ensureMigrationTable(ctx context.Context, conn *pgx.Conn) {
	_, err := conn.Exec(ctx, `
		CREATE TABLE IF NOT EXISTS hmn_migration (
			version TIMESTAMP WITH TIME ZONE
		)
	`)
	if err != nil {
		panic(fmt.Errorf("failed to create migration table: %w", err))
	}

	var numRows int
	err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM hmn_migration").Scan(&numRows)
	if err != nil {
		panic(err)
	}
	if numRows < 1 {
		_, err := conn.Exec(ctx, "INSERT INTO hmn_migration (version) VALUES ($1)", time.Time{})
		if err != nil {
			panic(fmt.Errorf("failed to insert initial migration row: %w", err))
		}
	}
}

// runMigrationStep runs step inside its own transaction. If step succeeds, it
// records newVersion in hmn_migration in that same transaction and commits.
//
// An error returned by step is passed back unchanged, and the transaction is
// rolled back. Failing to begin the transaction, update the version, or
// commit panics.
func runMigrationStep(ctx context.Context, conn *pgx.Conn, newVersion types.MigrationVersion, step func(tx pgx.Tx) error) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		panic(fmt.Errorf("failed to start transaction: %w", err))
	}
	// Deferring here (not in the caller's loop) closes each step's transaction
	// before the next one starts. After a successful Commit this is a no-op.
	defer tx.Rollback(ctx)

	if err := step(tx); err != nil {
		return err
	}

	if _, err := tx.Exec(ctx, "UPDATE hmn_migration SET version = $1", newVersion); err != nil {
		panic(fmt.Errorf("failed to update version in migrations table: %w", err))
	}

	if err := tx.Commit(ctx); err != nil {
		panic(fmt.Errorf("failed to commit transaction: %w", err))
	}

	return nil
}

//go:embed migrationTemplate.txt
var migrationTemplate string

// MakeMigration writes a new migration source file generated from
// migrationTemplate. It fills in the %NAME%, %DESCRIPTION% and %DATE%
// placeholders, using the current UTC time (to the second) as the migration
// version. The file goes to migration/migrations/<version>_<name>.go,
// resolved relative to the current working directory. A write failure panics.
func MakeMigration(name, description string) {
	result := migrationTemplate
	result = strings.ReplaceAll(result, "%NAME%", name)
	// %#v renders the description as a quoted Go string literal.
	result = strings.ReplaceAll(result, "%DESCRIPTION%", fmt.Sprintf("%#v", description))

	// The generated time.Date has second precision, so truncate now to keep
	// the filename's version identical to the one embedded in the file.
	now := time.Now().UTC().Truncate(time.Second)
	nowConstructor := fmt.Sprintf("time.Date(%d, %d, %d, %d, %d, %d, 0, time.UTC)",
		now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second())
	result = strings.ReplaceAll(result, "%DATE%", nowConstructor)

	// Colons are not permitted in filenames on every platform (e.g. Windows).
	safeVersion := strings.ReplaceAll(types.MigrationVersion(now).String(), ":", "")
	filename := fmt.Sprintf("%v_%v.go", safeVersion, name)
	path := filepath.Join("migration", "migrations", filename)

	err := os.WriteFile(path, []byte(result), 0644)
	if err != nil {
		panic(fmt.Errorf("failed to write migration file: %w", err))
	}

	fmt.Println("Successfully created migration file:")
	fmt.Println(path)
}