106 lines
2.3 KiB
Go
106 lines
2.3 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
|
|
"git.petrovv.com/nikola/go-migrate"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/joho/godotenv"
|
|
)
|
|
|
|
func main() {
|
|
// Parse flags
|
|
schemaDir := flag.String("dir", "schema", "Directory containing desired schema SQL files")
|
|
cmd := flag.String("cmd", "", "Command: generate or apply")
|
|
flag.Parse()
|
|
|
|
if *cmd == "" {
|
|
fmt.Println("Usage:")
|
|
fmt.Println(" go run main.go -cmd=generate -dir=schema - Compare schema and generate migrations")
|
|
fmt.Println(" go run main.go -cmd=apply -dir=schema - Apply pending migrations")
|
|
fmt.Println("\nFlags:")
|
|
flag.PrintDefaults()
|
|
return
|
|
}
|
|
|
|
switch *cmd {
|
|
case "generate":
|
|
generateMigrations(*schemaDir)
|
|
case "apply":
|
|
applyMigrations(*schemaDir)
|
|
default:
|
|
fmt.Println("Unknown command:", *cmd)
|
|
fmt.Println("Use 'generate' or 'apply'")
|
|
}
|
|
}
|
|
|
|
func generateMigrations(schemaDir string) {
|
|
_ = godotenv.Load()
|
|
db := getDBConnection()
|
|
defer db.Close()
|
|
ctx := context.Background()
|
|
|
|
migrations, err := migrate.GetMigrations(ctx, db, schemaDir)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
if len(migrations) == 0 {
|
|
fmt.Println("Database schema is up to date")
|
|
return
|
|
}
|
|
|
|
fmt.Println("-- Migration SQL")
|
|
for _, m := range migrations {
|
|
fmt.Println(m)
|
|
}
|
|
}
|
|
|
|
func applyMigrations(schemaDir string) {
|
|
_ = godotenv.Load()
|
|
db := getDBConnection()
|
|
defer db.Close()
|
|
ctx := context.Background()
|
|
|
|
migrations, err := migrate.GetMigrations(ctx, db, schemaDir)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
if len(migrations) == 0 {
|
|
fmt.Println("Database schema is up to date")
|
|
return
|
|
}
|
|
|
|
fmt.Println("Applying migrations...")
|
|
for _, m := range migrations {
|
|
fmt.Printf("Executing: %s\n", m)
|
|
_, err := db.Exec(ctx, m)
|
|
if err != nil {
|
|
log.Fatalf("failed to execute %s: %v", m, err)
|
|
}
|
|
}
|
|
|
|
fmt.Println("All migrations applied successfully")
|
|
}
|
|
|
|
func getDBConnection() *pgxpool.Pool {
|
|
dbUser := os.Getenv("DB_USER")
|
|
dbPass := os.Getenv("DB_PASS")
|
|
dbHost := os.Getenv("DB_HOST")
|
|
dbPort := os.Getenv("DB_PORT")
|
|
dbName := os.Getenv("DB_NAME")
|
|
|
|
dsn := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dbUser, dbPass, dbHost, dbPort, dbName)
|
|
|
|
pool, err := pgxpool.New(context.Background(), dsn)
|
|
if err != nil {
|
|
log.Fatalf("Unable to connect to database: %v", err)
|
|
}
|
|
return pool
|
|
}
|