1007 lines
25 KiB
Go
1007 lines
25 KiB
Go
package go_schema_migrate
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// Column describes one table column, either as declared in the desired SQL
// files or as reported by the live database.
type Column struct {
	// Name and Type are the column identifier and its SQL type text exactly
	// as parsed; equivalent spellings are reconciled only at comparison time.
	Name, Type string
	// Nullable is false when the column is declared NOT NULL or belongs to
	// the primary key.
	Nullable bool
	// Default is the raw DEFAULT expression, or "" when the column has none.
	Default string
}
|
|
|
|
// Table represents a database table
|
|
type Table struct {
|
|
Name string
|
|
Columns map[string]*Column
|
|
PrimaryKey []string
|
|
Indexes map[string]*Index
|
|
ForeignKeys map[string]*ForeignKey
|
|
}
|
|
|
|
// Index describes a secondary (non-primary-key) index on a table.
type Index struct {
	Name    string
	Columns []string // key columns or expressions, in index order
	Unique  bool
}
|
|
|
|
// ForeignKey describes a foreign-key constraint: Columns of the owning table
// reference RefColumns of RefTable.
type ForeignKey struct {
	Name       string   // constraint name
	Columns    []string // referencing columns on the owning table
	RefTable   string   // referenced table
	RefColumns []string // referenced columns, paired positionally with Columns
	OnDelete   string   // ON DELETE action as parsed (e.g. CASCADE); "" if none
}
|
|
|
|
// Schema represents the database schema
|
|
type Schema struct {
|
|
Types map[string]string
|
|
Tables map[string]*Table
|
|
}
|
|
|
|
// SchemaFile is one .sql file read from the schema directory: its base file
// name and its complete contents.
type SchemaFile struct {
	Name, SQL string
}
|
|
|
|
// isSerialType reports whether t (compared case-insensitively) names one of
// PostgreSQL's serial pseudo-types, whose defaults are implicit sequences.
func isSerialType(t string) bool {
	switch strings.ToUpper(t) {
	case "SERIAL", "SERIAL4",
		"BIGSERIAL", "SERIAL8",
		"SMALLSERIAL", "SERIAL2":
		return true
	default:
		return false
	}
}
|
|
|
|
// normalizeType maps a column type to a canonical upper-case spelling so that
// equivalent PostgreSQL spellings compare equal. It:
//   - collapses internal whitespace;
//   - maps serial pseudo-types to their underlying integer type;
//   - maps common aliases (INT, INT8, BOOL, FLOAT8, VARCHAR, TIMESTAMPTZ, …)
//     to the names information_schema.columns reports.
//
// Anything else is returned upper-cased and otherwise unchanged.
func normalizeType(t string) string {
	upper := strings.Join(strings.Fields(strings.ToUpper(t)), " ")
	switch upper {
	case "BIGSERIAL", "SERIAL8", "INT8":
		return "BIGINT"
	case "SERIAL", "SERIAL4", "INT", "INT4":
		return "INTEGER"
	case "SMALLSERIAL", "SERIAL2", "INT2":
		return "SMALLINT"
	case "DOUBLE", "FLOAT", "FLOAT8":
		return "DOUBLE PRECISION"
	case "FLOAT4":
		return "REAL"
	case "BOOL":
		return "BOOLEAN"
	case "DECIMAL":
		return "NUMERIC"
	case "VARCHAR":
		return "CHARACTER VARYING"
	case "CHAR":
		return "CHARACTER"
	case "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ":
		return "TIMESTAMP WITH TIME ZONE"
	case "TIMESTAMP WITHOUT TIME ZONE":
		return "TIMESTAMP"
	case "TIME WITH TIME ZONE", "TIMETZ":
		return "TIME WITH TIME ZONE"
	case "TIME WITHOUT TIME ZONE":
		return "TIME"
	default:
		return upper
	}
}
|
|
|
|
// typeCastPattern matches a PostgreSQL cast suffix such as ::text,
// ::character varying, ::timestamp with time zone, ::public.my_enum,
// ::"MyType" or ::text[]. It is compiled once at package init.
var typeCastPattern = regexp.MustCompile(`(?i)::(?:"[^"]*"|[a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)(?:\s+(?:varying|precision|with(?:out)?\s+time\s+zone))?(?:\[\])*`)

// normalizeDefault canonicalizes a DEFAULT expression for comparison. It:
//   - strips type casts (PostgreSQL reports 'x'::text where the schema file
//     says 'x');
//   - collapses whitespace runs to single spaces;
//   - upper-cases everything outside single-quoted literals, so keywords
//     and function names compare case-insensitively while literal values
//     such as 'Active' vs 'active' still differ.
//
// An empty input returns "".
func normalizeDefault(d string) string {
	if d == "" {
		return ""
	}
	d = typeCastPattern.ReplaceAllString(d, "")
	d = strings.Join(strings.Fields(d), " ")

	// Splitting on the quote character puts text outside literals at even
	// indices and literal contents at odd indices; '' escapes produce empty
	// segments and keep the parity correct.
	segments := strings.Split(d, "'")
	for i := 0; i < len(segments); i += 2 {
		segments[i] = strings.ToUpper(segments[i])
	}
	return strings.Join(segments, "'")
}
|
|
|
|
func columnsMatch(a, b *Column) bool {
|
|
if a == nil || b == nil {
|
|
return false
|
|
}
|
|
return normalizeType(a.Type) == normalizeType(b.Type) && a.Nullable == b.Nullable && normalizeDefault(a.Default) == normalizeDefault(b.Default)
|
|
}
|
|
|
|
// lastIndexOfMatchingParen scans s from byte offset start and returns the
// index of the ')' that brings the parenthesis depth back to zero, i.e. the
// partner of the first '(' at or after start. Parentheses inside single-quoted
// literals or double-quoted identifiers are ignored.
//
// It returns -1 when start is out of range (including the -1 produced by a
// failed strings.Index) or when the group is never closed.
func lastIndexOfMatchingParen(s string, start int) int {
	if start < 0 || start >= len(s) {
		return -1
	}
	depth := 0
	var quote byte // active quote character; 0 when outside quotes
	for i := start; i < len(s); i++ {
		c := s[i]
		switch {
		case quote != 0:
			// A doubled quote ('' or "") closes and immediately reopens,
			// which leaves the state correct.
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"':
			quote = c
		case c == '(':
			depth++
		case c == ')':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}
|
|
|
|
// splitByCommaRespectParen splits s on commas that are outside parentheses
// and outside single-quoted literals / double-quoted identifiers, trimming
// surrounding whitespace from each piece. Thus "a INT, b TEXT DEFAULT 'x,y',
// c NUMERIC(10,2)" yields three elements.
//
// A trailing comma does not produce a final empty element, but empty
// elements between commas are kept. An empty input returns nil.
func splitByCommaRespectParen(s string) []string {
	var parts []string
	depth := 0
	start := 0
	var quote byte // active quote character; 0 when outside quotes
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case quote != 0:
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"':
			quote = c
		case c == '(':
			depth++
		case c == ')':
			depth--
		case c == ',' && depth == 0:
			parts = append(parts, strings.TrimSpace(s[start:i]))
			start = i + 1
		}
	}
	if start < len(s) {
		parts = append(parts, strings.TrimSpace(s[start:]))
	}
	return parts
}
|
|
|
|
func getDependencyOrder(schema *Schema) []string {
|
|
// Simple topological sort based on foreign key dependencies
|
|
var order []string
|
|
visited := make(map[string]bool)
|
|
|
|
var visit func(name string)
|
|
visit = func(name string) {
|
|
if visited[name] {
|
|
return
|
|
}
|
|
visited[name] = true
|
|
|
|
// Visit dependencies first
|
|
if table, exists := schema.Tables[name]; exists {
|
|
for _, fk := range table.ForeignKeys {
|
|
if _, depExists := schema.Tables[fk.RefTable]; depExists {
|
|
visit(fk.RefTable)
|
|
}
|
|
}
|
|
}
|
|
|
|
order = append(order, name)
|
|
}
|
|
|
|
for tableName := range schema.Tables {
|
|
visit(tableName)
|
|
}
|
|
|
|
return order
|
|
}
|
|
|
|
func parseInlineForeignKey(def string) *ForeignKey {
|
|
fk := &ForeignKey{}
|
|
|
|
// Extract FOREIGN KEY (col1, col2) REFERENCES table(col1, col2)
|
|
fkIdx := strings.Index(strings.ToUpper(def), "FOREIGN KEY")
|
|
if fkIdx < 0 {
|
|
return nil
|
|
}
|
|
|
|
refIdx := strings.Index(strings.ToUpper(def), "REFERENCES")
|
|
if refIdx < 0 {
|
|
return nil
|
|
}
|
|
|
|
// Extract columns
|
|
fkPart := strings.TrimSpace(def[fkIdx+11 : refIdx])
|
|
fkPart = strings.Trim(fkPart, "()")
|
|
fk.Columns = strings.Split(fkPart, ",")
|
|
for i := range fk.Columns {
|
|
fk.Columns[i] = strings.TrimSpace(fk.Columns[i])
|
|
}
|
|
|
|
// Extract referenced table and columns
|
|
refPart := strings.TrimSpace(def[refIdx+10:])
|
|
parenStart := strings.Index(refPart, "(")
|
|
if parenStart > 0 {
|
|
fk.RefTable = strings.TrimSpace(refPart[:parenStart])
|
|
// Find matching closing parenthesis
|
|
depth := 1
|
|
parenEnd := -1
|
|
for i := parenStart + 1; i < len(refPart); i++ {
|
|
if refPart[i] == '(' {
|
|
depth++
|
|
} else if refPart[i] == ')' {
|
|
depth--
|
|
if depth == 0 {
|
|
parenEnd = i
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if parenEnd > parenStart {
|
|
colsStr := refPart[parenStart+1 : parenEnd]
|
|
fk.RefColumns = strings.Split(colsStr, ",")
|
|
for i := range fk.RefColumns {
|
|
fk.RefColumns[i] = strings.TrimSpace(fk.RefColumns[i])
|
|
}
|
|
// Extract ON DELETE if present
|
|
if parenEnd+1 < len(refPart) {
|
|
onDeletePart := strings.ToUpper(strings.TrimSpace(refPart[parenEnd+1:]))
|
|
if strings.HasPrefix(onDeletePart, "ON DELETE") {
|
|
fields := strings.Fields(onDeletePart)
|
|
if len(fields) >= 3 {
|
|
fk.OnDelete = fields[2]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return fk
|
|
}
|
|
|
|
func parseCreateTable(schema *Schema, stmt string) {
|
|
// Extract table name
|
|
parts := strings.Fields(stmt)
|
|
if len(parts) < 3 {
|
|
return
|
|
}
|
|
|
|
tableName := strings.Trim(parts[2], "()")
|
|
if tableName == "" {
|
|
return
|
|
}
|
|
|
|
// Find the opening parenthesis
|
|
startIdx := strings.Index(stmt, "(")
|
|
endIdx := lastIndexOfMatchingParen(stmt, startIdx)
|
|
if startIdx < 0 || endIdx < 0 {
|
|
return
|
|
}
|
|
|
|
tableDef := stmt[startIdx+1 : endIdx]
|
|
|
|
table := &Table{
|
|
Name: tableName,
|
|
Columns: make(map[string]*Column),
|
|
PrimaryKey: []string{},
|
|
Indexes: make(map[string]*Index),
|
|
ForeignKeys: make(map[string]*ForeignKey),
|
|
}
|
|
|
|
// Parse table definition
|
|
parseTableDefinition(table, tableDef)
|
|
|
|
schema.Tables[tableName] = table
|
|
}
|
|
|
|
func parseTableDefinition(table *Table, def string) {
|
|
// Split by commas, respecting parentheses
|
|
parts := splitByCommaRespectParen(def)
|
|
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if part == "" {
|
|
continue
|
|
}
|
|
|
|
upperPart := strings.ToUpper(part)
|
|
|
|
// PRIMARY KEY constraint
|
|
if strings.HasPrefix(upperPart, "PRIMARY KEY") {
|
|
pkDef := strings.TrimPrefix(part, "PRIMARY KEY")
|
|
pkDef = strings.TrimSpace(pkDef)
|
|
// Remove outer parentheses
|
|
pkDef = strings.TrimLeft(pkDef, "(")
|
|
pkDef = strings.TrimRight(pkDef, ")")
|
|
pkDef = strings.TrimSpace(pkDef)
|
|
pkCols := strings.Split(pkDef, ",")
|
|
for _, col := range pkCols {
|
|
table.PrimaryKey = append(table.PrimaryKey, strings.TrimSpace(col))
|
|
}
|
|
continue
|
|
}
|
|
|
|
// FOREIGN KEY constraint (inline)
|
|
if strings.Contains(upperPart, "FOREIGN KEY") {
|
|
fk := parseInlineForeignKey(part)
|
|
if fk != nil {
|
|
// Generate a name for the constraint
|
|
name := fmt.Sprintf("fk_%s_%s", table.Name, strings.Join(fk.Columns, "_"))
|
|
fk.Name = name
|
|
table.ForeignKeys[name] = fk
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Check for inline PRIMARY KEY (e.g., "id BIGSERIAL PRIMARY KEY")
|
|
if strings.Contains(upperPart, "PRIMARY KEY") && !strings.HasPrefix(upperPart, "PRIMARY KEY") {
|
|
// Extract column name from inline PRIMARY KEY definition
|
|
fields := strings.Fields(part)
|
|
if len(fields) >= 3 {
|
|
colName := fields[0]
|
|
table.PrimaryKey = append(table.PrimaryKey, colName)
|
|
// Remove PRIMARY KEY tokens for column parsing
|
|
var newFields []string
|
|
for i := 0; i < len(fields); i++ {
|
|
if i < len(fields)-1 && strings.ToUpper(fields[i]) == "PRIMARY" && strings.ToUpper(fields[i+1]) == "KEY" {
|
|
i++
|
|
continue
|
|
}
|
|
newFields = append(newFields, fields[i])
|
|
}
|
|
part = strings.Join(newFields, " ")
|
|
}
|
|
}
|
|
|
|
// Column definition
|
|
col := parseColumn(part)
|
|
if col != nil {
|
|
table.Columns[col.Name] = col
|
|
}
|
|
}
|
|
|
|
// PRIMARY KEY columns must be NOT NULL
|
|
for _, pkCol := range table.PrimaryKey {
|
|
if col, exists := table.Columns[pkCol]; exists {
|
|
col.Nullable = false
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseColumn(def string) *Column {
|
|
col := &Column{Nullable: true}
|
|
|
|
// Split by space, respecting quotes and parentheses
|
|
parts := strings.Fields(def)
|
|
if len(parts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
col.Name = parts[0]
|
|
if len(parts) < 2 {
|
|
return col
|
|
}
|
|
|
|
// Collect type - can be multiple words (e.g., DOUBLE PRECISION)
|
|
// Type continues until we hit a constraint keyword
|
|
var typeParts []string
|
|
typeParts = append(typeParts, parts[1])
|
|
|
|
i := 2
|
|
for i < len(parts) {
|
|
p := strings.ToUpper(parts[i])
|
|
// Stop at constraint keywords
|
|
if p == "NOT" || p == "NULL" || p == "DEFAULT" || p == "REFERENCES" ||
|
|
p == "PRIMARY" || p == "FOREIGN" || p == "UNIQUE" || p == "CHECK" || p == "CONSTRAINT" {
|
|
break
|
|
}
|
|
typeParts = append(typeParts, parts[i])
|
|
i++
|
|
}
|
|
col.Type = strings.Join(typeParts, " ")
|
|
|
|
// Check for NOT NULL and other constraints
|
|
for ; i < len(parts); i++ {
|
|
p := strings.ToUpper(parts[i])
|
|
if p == "NOT" && i+1 < len(parts) && strings.ToUpper(parts[i+1]) == "NULL" {
|
|
col.Nullable = false
|
|
i++
|
|
continue
|
|
}
|
|
if p == "NULL" {
|
|
col.Nullable = true
|
|
continue
|
|
}
|
|
if strings.HasPrefix(p, "DEFAULT") {
|
|
if i+1 < len(parts) {
|
|
col.Default = strings.Join(parts[i+1:], " ")
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
return col
|
|
}
|
|
|
|
func parseCreateIndex(schema *Schema, stmt string) {
|
|
isUnique := strings.Contains(strings.ToUpper(stmt), "UNIQUE")
|
|
|
|
// Extract index name
|
|
parts := strings.Fields(stmt)
|
|
if len(parts) < 4 {
|
|
return
|
|
}
|
|
|
|
idxName := parts[2]
|
|
if strings.ToUpper(parts[1]) == "UNIQUE" {
|
|
idxName = parts[3]
|
|
}
|
|
|
|
// Find ON
|
|
onIdx := -1
|
|
for i, p := range parts {
|
|
if strings.ToUpper(p) == "ON" {
|
|
onIdx = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if onIdx < 0 || onIdx+1 >= len(parts) {
|
|
return
|
|
}
|
|
|
|
tableName := parts[onIdx+1]
|
|
|
|
// Find columns inside parentheses after table name
|
|
var columns []string
|
|
// Find the opening parenthesis after the table name
|
|
parenStart := strings.Index(stmt, "(")
|
|
if parenStart >= 0 {
|
|
// Find the matching closing parenthesis
|
|
depth := 0
|
|
parenEnd := -1
|
|
for i := parenStart; i < len(stmt); i++ {
|
|
if stmt[i] == '(' {
|
|
depth++
|
|
} else if stmt[i] == ')' {
|
|
depth--
|
|
if depth == 0 {
|
|
parenEnd = i
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if parenEnd > parenStart {
|
|
colsStr := stmt[parenStart+1 : parenEnd]
|
|
columns = strings.Split(colsStr, ",")
|
|
for i := range columns {
|
|
columns[i] = strings.TrimSpace(columns[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
if _, exists := schema.Tables[tableName]; !exists {
|
|
schema.Tables[tableName] = &Table{
|
|
Name: tableName,
|
|
Columns: make(map[string]*Column),
|
|
Indexes: make(map[string]*Index),
|
|
ForeignKeys: make(map[string]*ForeignKey),
|
|
}
|
|
}
|
|
|
|
schema.Tables[tableName].Indexes[idxName] = &Index{
|
|
Name: idxName,
|
|
Columns: columns,
|
|
Unique: isUnique,
|
|
}
|
|
}
|
|
|
|
func parseAlterTableAddConstraint(schema *Schema, stmt string) {
|
|
// ALTER TABLE table ADD CONSTRAINT name FOREIGN KEY (col) REFERENCES table(col)
|
|
parts := strings.Fields(stmt)
|
|
if len(parts) < 6 {
|
|
return
|
|
}
|
|
|
|
tableName := parts[2]
|
|
constraintName := parts[5]
|
|
|
|
if _, exists := schema.Tables[tableName]; !exists {
|
|
schema.Tables[tableName] = &Table{
|
|
Name: tableName,
|
|
Columns: make(map[string]*Column),
|
|
Indexes: make(map[string]*Index),
|
|
ForeignKeys: make(map[string]*ForeignKey),
|
|
}
|
|
}
|
|
|
|
// Parse the constraint definition
|
|
// Find the constraint type (FOREIGN KEY, etc.)
|
|
constraintType := ""
|
|
for i, p := range parts {
|
|
if strings.ToUpper(p) == "FOREIGN" && i+1 < len(parts) && strings.ToUpper(parts[i+1]) == "KEY" {
|
|
constraintType = "FOREIGN KEY"
|
|
break
|
|
}
|
|
}
|
|
|
|
if constraintType == "FOREIGN KEY" {
|
|
// Extract the full definition
|
|
def := strings.Join(parts[6:], " ")
|
|
fk := parseInlineForeignKey(def)
|
|
if fk != nil {
|
|
fk.Name = constraintName
|
|
schema.Tables[tableName].ForeignKeys[constraintName] = fk
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseSchemaFile(schema *Schema, sql string) {
|
|
// Split into statements
|
|
statements := strings.Split(sql, ";")
|
|
|
|
for _, stmt := range statements {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" {
|
|
continue
|
|
}
|
|
|
|
upperStmt := strings.ToUpper(stmt)
|
|
|
|
// CREATE TYPE
|
|
if strings.HasPrefix(upperStmt, "CREATE TYPE") {
|
|
parts := strings.Fields(stmt)
|
|
if len(parts) >= 3 {
|
|
typeName := parts[2]
|
|
// Store the full CREATE TYPE statement for later generation
|
|
schema.Types[typeName] = stmt
|
|
}
|
|
}
|
|
|
|
// CREATE TABLE
|
|
if strings.HasPrefix(upperStmt, "CREATE TABLE") {
|
|
parseCreateTable(schema, stmt)
|
|
}
|
|
|
|
// CREATE INDEX
|
|
if strings.HasPrefix(upperStmt, "CREATE") && strings.Contains(upperStmt, "INDEX") {
|
|
parseCreateIndex(schema, stmt)
|
|
}
|
|
|
|
// ALTER TABLE ... ADD CONSTRAINT
|
|
if strings.HasPrefix(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") {
|
|
parseAlterTableAddConstraint(schema, stmt)
|
|
}
|
|
}
|
|
}
|
|
|
|
func loadSchemaFiles(dir string) ([]SchemaFile, error) {
|
|
files, err := os.ReadDir(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var schemaFiles []SchemaFile
|
|
for _, f := range files {
|
|
if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") {
|
|
content, err := os.ReadFile(filepath.Join(dir, f.Name()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
schemaFiles = append(schemaFiles, SchemaFile{
|
|
Name: f.Name(),
|
|
SQL: string(content),
|
|
})
|
|
}
|
|
}
|
|
|
|
// Sort by filename (natural sort for numbered files)
|
|
sort.Slice(schemaFiles, func(i, j int) bool {
|
|
return schemaFiles[i].Name < schemaFiles[j].Name
|
|
})
|
|
|
|
return schemaFiles, nil
|
|
}
|
|
|
|
func loadDesiredSchema(dir string) (*Schema, error) {
|
|
files, err := loadSchemaFiles(dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
schema := &Schema{
|
|
Types: make(map[string]string),
|
|
Tables: make(map[string]*Table),
|
|
}
|
|
|
|
for _, f := range files {
|
|
parseSchemaFile(schema, f.SQL)
|
|
}
|
|
|
|
return schema, nil
|
|
}
|
|
|
|
func generateCreateTableSQL(table *Table) string {
|
|
var lines []string
|
|
|
|
// Table header
|
|
lines = append(lines, fmt.Sprintf("CREATE TABLE %s (", table.Name))
|
|
|
|
// Columns - all get trailing commas
|
|
for colName, col := range table.Columns {
|
|
colSQL := fmt.Sprintf(" %s %s", colName, col.Type)
|
|
if !col.Nullable {
|
|
colSQL += " NOT NULL"
|
|
}
|
|
if col.Default != "" {
|
|
colSQL += fmt.Sprintf(" DEFAULT %s", col.Default)
|
|
}
|
|
lines = append(lines, colSQL+",")
|
|
}
|
|
|
|
// Build list of constraints (PRIMARY KEY + UNIQUE constraints)
|
|
var constraints []string
|
|
if len(table.PrimaryKey) > 0 {
|
|
pkCols := make([]string, len(table.PrimaryKey))
|
|
for i, col := range table.PrimaryKey {
|
|
pkCols[i] = strings.TrimSpace(strings.Trim(col, "()"))
|
|
}
|
|
constraints = append(constraints, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pkCols, ", ")))
|
|
}
|
|
for _, idx := range table.Indexes {
|
|
if idx.Unique && len(idx.Columns) > 0 {
|
|
constraints = append(constraints, fmt.Sprintf("UNIQUE (%s)", strings.Join(idx.Columns, ", ")))
|
|
}
|
|
}
|
|
|
|
// Add constraints with commas between them
|
|
for i, constraint := range constraints {
|
|
if i < len(constraints)-1 {
|
|
lines = append(lines, fmt.Sprintf(" %s,", constraint))
|
|
} else {
|
|
lines = append(lines, fmt.Sprintf(" %s", constraint))
|
|
}
|
|
}
|
|
|
|
lines = append(lines, ");")
|
|
|
|
return strings.Join(lines, "\n")
|
|
}
|
|
|
|
func generateAddColumnSQL(tableName, colName string, col *Column) string {
|
|
colSQL := fmt.Sprintf("%s %s", colName, col.Type)
|
|
if !col.Nullable {
|
|
colSQL += " NOT NULL"
|
|
}
|
|
if col.Default != "" {
|
|
colSQL += fmt.Sprintf(" DEFAULT %s", col.Default)
|
|
}
|
|
return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", tableName, colSQL)
|
|
}
|
|
|
|
func generateAddForeignKeySQL(tableName string, fk *ForeignKey) string {
|
|
colStr := strings.Join(fk.Columns, ", ")
|
|
refStr := fmt.Sprintf("%s(%s)", fk.RefTable, strings.Join(fk.RefColumns, ", "))
|
|
|
|
constraint := fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s", colStr, refStr)
|
|
if fk.OnDelete != "" {
|
|
constraint += fmt.Sprintf(" ON DELETE %s", fk.OnDelete)
|
|
}
|
|
|
|
return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", tableName, fk.Name, constraint)
|
|
}
|
|
|
|
func generateCreateIndexSQL(tableName string, idx *Index) string {
|
|
unique := ""
|
|
if idx.Unique {
|
|
unique = "UNIQUE "
|
|
}
|
|
cols := strings.Join(idx.Columns, ", ")
|
|
return fmt.Sprintf("CREATE %sINDEX %s ON %s (%s);", unique, idx.Name, tableName, cols)
|
|
}
|
|
|
|
func CompareSchemas(current, desired *Schema) []string {
|
|
var migrations []string
|
|
|
|
// Create missing types
|
|
for typeName, typeStmt := range desired.Types {
|
|
if _, exists := current.Types[typeName]; !exists {
|
|
migrations = append(migrations, typeStmt+";")
|
|
}
|
|
}
|
|
|
|
// Process tables in dependency order
|
|
tableOrder := getDependencyOrder(desired)
|
|
|
|
for _, tableName := range tableOrder {
|
|
desiredTable := desired.Tables[tableName]
|
|
currentTable, exists := current.Tables[tableName]
|
|
|
|
if !exists {
|
|
// Create table with all constraints and unique indexes
|
|
migrations = append(migrations, generateCreateTableSQL(desiredTable))
|
|
// Add foreign keys as separate ALTER TABLE statements (PostgreSQL doesn't support inline FK in CREATE TABLE with existing types)
|
|
for _, desiredFK := range desiredTable.ForeignKeys {
|
|
migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredFK))
|
|
}
|
|
// Add non-unique indexes as separate CREATE INDEX statements
|
|
for _, desiredIdx := range desiredTable.Indexes {
|
|
if !desiredIdx.Unique {
|
|
migrations = append(migrations, generateCreateIndexSQL(tableName, desiredIdx))
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Compare columns
|
|
for colName, desiredCol := range desiredTable.Columns {
|
|
currentCol, colExists := currentTable.Columns[colName]
|
|
|
|
if !colExists {
|
|
migrations = append(migrations, generateAddColumnSQL(tableName, colName, desiredCol))
|
|
} else if !columnsMatch(currentCol, desiredCol) {
|
|
// Generate only the specific changes needed
|
|
if currentCol.Nullable != desiredCol.Nullable {
|
|
if desiredCol.Nullable {
|
|
migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL;", tableName, colName))
|
|
} else {
|
|
migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET NOT NULL;", tableName, colName))
|
|
}
|
|
}
|
|
// Skip default comparison for serial types (implicit defaults)
|
|
if !isSerialType(currentCol.Type) && !isSerialType(desiredCol.Type) && currentCol.Default != desiredCol.Default {
|
|
if desiredCol.Default != "" {
|
|
migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", tableName, colName, desiredCol.Default))
|
|
} else {
|
|
migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", tableName, colName))
|
|
}
|
|
}
|
|
// Type changes are skipped for simplicity
|
|
}
|
|
}
|
|
|
|
// Compare foreign keys
|
|
for fkName, desiredFK := range desiredTable.ForeignKeys {
|
|
if _, exists := currentTable.ForeignKeys[fkName]; !exists {
|
|
migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredFK))
|
|
}
|
|
}
|
|
|
|
// Compare indexes - only generate for non-unique (unique are handled in CREATE TABLE)
|
|
for idxName, desiredIdx := range desiredTable.Indexes {
|
|
if _, exists := currentTable.Indexes[idxName]; !exists {
|
|
if !desiredIdx.Unique {
|
|
migrations = append(migrations, generateCreateIndexSQL(tableName, desiredIdx))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return migrations
|
|
}
|
|
|
|
func GetCurrentSchema(ctx context.Context, db *pgxpool.Pool) (*Schema, error) {
|
|
schema := &Schema{
|
|
Types: make(map[string]string),
|
|
Tables: make(map[string]*Table),
|
|
}
|
|
|
|
// Get custom types (enums)
|
|
rows, err := db.Query(ctx, `
|
|
SELECT typname
|
|
FROM pg_type t
|
|
JOIN pg_namespace n ON n.oid = t.typnamespace
|
|
WHERE n.nspname = 'public' AND t.typtype = 'e'
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for rows.Next() {
|
|
var typeName string
|
|
if err := rows.Scan(&typeName); err != nil {
|
|
rows.Close()
|
|
return nil, err
|
|
}
|
|
schema.Types[typeName] = "enum"
|
|
}
|
|
rows.Close()
|
|
|
|
// Get tables
|
|
tableRows, err := db.Query(ctx, `
|
|
SELECT table_name
|
|
FROM information_schema.tables
|
|
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
|
|
ORDER BY table_name
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tableRows.Close()
|
|
|
|
for tableRows.Next() {
|
|
var tableName string
|
|
if err := tableRows.Scan(&tableName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
table := &Table{
|
|
Name: tableName,
|
|
Columns: make(map[string]*Column),
|
|
PrimaryKey: []string{},
|
|
Indexes: make(map[string]*Index),
|
|
ForeignKeys: make(map[string]*ForeignKey),
|
|
}
|
|
|
|
// Get columns
|
|
colRows, err := db.Query(ctx, `
|
|
SELECT column_name, data_type, udt_name, is_nullable, column_default
|
|
FROM information_schema.columns
|
|
WHERE table_schema = 'public' AND table_name = $1
|
|
ORDER BY ordinal_position
|
|
`, tableName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for colRows.Next() {
|
|
var col Column
|
|
var dataType, udtName, nullableStr string
|
|
var defaultVal *string
|
|
if err := colRows.Scan(&col.Name, &dataType, &udtName, &nullableStr, &defaultVal); err != nil {
|
|
colRows.Close()
|
|
return nil, err
|
|
}
|
|
// For user-defined types (enums), use udt_name instead of data_type
|
|
if dataType == "USER-DEFINED" && udtName != "" {
|
|
col.Type = udtName
|
|
} else {
|
|
col.Type = dataType
|
|
}
|
|
col.Nullable = nullableStr == "YES"
|
|
if defaultVal != nil {
|
|
col.Default = strings.TrimSpace(*defaultVal)
|
|
} else {
|
|
col.Default = ""
|
|
}
|
|
table.Columns[col.Name] = &col
|
|
}
|
|
colRows.Close()
|
|
|
|
// Get primary key
|
|
pkRows, err := db.Query(ctx, `
|
|
SELECT a.attname
|
|
FROM pg_index ix
|
|
JOIN pg_attribute a ON a.attrelid = ix.indrelid AND a.attnum = ANY(ix.indkey)
|
|
JOIN pg_class c ON c.oid = ix.indrelid
|
|
WHERE c.relname = $1 AND ix.indisprimary
|
|
ORDER BY a.attnum
|
|
`, tableName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for pkRows.Next() {
|
|
var colName string
|
|
if err := pkRows.Scan(&colName); err != nil {
|
|
pkRows.Close()
|
|
return nil, err
|
|
}
|
|
table.PrimaryKey = append(table.PrimaryKey, colName)
|
|
}
|
|
pkRows.Close()
|
|
|
|
// Get indexes
|
|
idxRows, err := db.Query(ctx, `
|
|
SELECT i.relname, ix.indisunique, pg_get_indexdef(ix.indexrelid)
|
|
FROM pg_index ix
|
|
JOIN pg_class i ON i.oid = ix.indexrelid
|
|
JOIN pg_class t ON t.oid = ix.indrelid
|
|
WHERE t.relname = $1 AND i.relkind = 'i' AND ix.indisprimary = false
|
|
`, tableName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for idxRows.Next() {
|
|
var idxName string
|
|
var isUnique bool
|
|
var indexDef string
|
|
if err := idxRows.Scan(&idxName, &isUnique, &indexDef); err != nil {
|
|
idxRows.Close()
|
|
return nil, err
|
|
}
|
|
// Parse columns from index definition
|
|
columns := parseIndexColumns(indexDef)
|
|
|
|
table.Indexes[idxName] = &Index{
|
|
Name: idxName,
|
|
Columns: columns,
|
|
Unique: isUnique,
|
|
}
|
|
}
|
|
idxRows.Close()
|
|
|
|
// Get foreign keys
|
|
fkRows, err := db.Query(ctx, `
|
|
SELECT conname, pg_get_constraintdef(c.oid)
|
|
FROM pg_constraint c
|
|
JOIN pg_class cl ON cl.oid = c.conrelid
|
|
WHERE cl.relname = $1 AND c.contype = 'f'
|
|
`, tableName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for fkRows.Next() {
|
|
var fkName, def string
|
|
if err := fkRows.Scan(&fkName, &def); err != nil {
|
|
fkRows.Close()
|
|
return nil, err
|
|
}
|
|
// Parse constraint definition (simplified)
|
|
fk := parseForeignKey(def)
|
|
fk.Name = fkName
|
|
table.ForeignKeys[fkName] = fk
|
|
}
|
|
fkRows.Close()
|
|
|
|
schema.Tables[tableName] = table
|
|
}
|
|
|
|
return schema, nil
|
|
}
|
|
|
|
func parseIndexColumns(def string) []string {
|
|
// Example: "CREATE UNIQUE INDEX idx_name ON table_name (col1, col2)"
|
|
// Extract the part between parentheses
|
|
startIdx := strings.Index(def, "(")
|
|
endIdx := strings.LastIndex(def, ")")
|
|
if startIdx < 0 || endIdx < 0 {
|
|
return nil
|
|
}
|
|
colsStr := def[startIdx+1 : endIdx]
|
|
cols := strings.Split(colsStr, ",")
|
|
for i := range cols {
|
|
cols[i] = strings.TrimSpace(cols[i])
|
|
}
|
|
return cols
|
|
}
|
|
|
|
func parseForeignKey(def string) *ForeignKey {
|
|
fk := &ForeignKey{}
|
|
|
|
// Extract columns
|
|
if strings.Contains(def, "FOREIGN KEY") {
|
|
parts := strings.Split(def, "FOREIGN KEY")
|
|
if len(parts) >= 2 {
|
|
colPart := strings.TrimSpace(parts[1])
|
|
// Get columns before REFERENCES
|
|
refIdx := strings.Index(colPart, "REFERENCES")
|
|
if refIdx >= 0 {
|
|
colDef := strings.TrimSpace(colPart[:refIdx])
|
|
colDef = strings.Trim(colDef, "()")
|
|
fk.Columns = strings.Split(colDef, ",")
|
|
for i := range fk.Columns {
|
|
fk.Columns[i] = strings.TrimSpace(fk.Columns[i])
|
|
}
|
|
|
|
// Get referenced table and columns
|
|
refPart := strings.TrimSpace(colPart[refIdx+10:])
|
|
refParts := strings.Fields(refPart)
|
|
if len(refParts) >= 1 {
|
|
fk.RefTable = strings.Trim(refParts[0], "()")
|
|
if len(refParts) >= 2 {
|
|
fk.RefColumns = strings.Split(strings.Trim(refParts[1], "()"), ",")
|
|
for i := range fk.RefColumns {
|
|
fk.RefColumns[i] = strings.TrimSpace(fk.RefColumns[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extract ON DELETE
|
|
if strings.Contains(def, "ON DELETE") {
|
|
onDeleteParts := strings.Split(def, "ON DELETE")
|
|
if len(onDeleteParts) >= 2 {
|
|
fk.OnDelete = strings.TrimSpace(strings.Fields(onDeleteParts[1])[0])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return fk
|
|
}
|
|
|
|
// LoadDesiredSchema loads schema definitions from SQL files in the given directory
|
|
func LoadDesiredSchema(dir string) (*Schema, error) {
|
|
return loadDesiredSchema(dir)
|
|
}
|
|
|
|
// GetMigrations compares the current database schema with the desired schema
|
|
// and returns a list of SQL migration statements
|
|
func GetMigrations(ctx context.Context, db *pgxpool.Pool, schemaDir string) ([]string, error) {
|
|
// Load current database schema
|
|
currentSchema, err := GetCurrentSchema(ctx, db)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get current schema: %w", err)
|
|
}
|
|
|
|
// Load desired schema from files
|
|
desiredSchema, err := LoadDesiredSchema(schemaDir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load desired schema: %w", err)
|
|
}
|
|
|
|
// Compare and generate migration SQL
|
|
return CompareSchemas(currentSchema, desiredSchema), nil
|
|
}
|