Files
go_sql_schema_migrate/migrate.go
2026-05-31 16:11:06 +02:00

1007 lines
25 KiB
Go

package migrate
import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
)
// Column describes a single table column, either parsed from a schema file
// or read back from the live database.
type Column struct {
// Name is the column identifier; columns are stored in Table.Columns under
// this name.
Name string
// Type is the SQL type as written in the schema file, or the data_type
// (udt_name for user-defined types) reported by information_schema.
Type string
// Nullable is false when the column is declared NOT NULL or belongs to the
// primary key.
Nullable bool
// Default is the raw DEFAULT expression; the empty string means the column
// has no default.
Default string
}
// Table describes one table together with its columns, primary key, indexes
// and foreign keys.
type Table struct {
// Name is the table name; tables are stored in Schema.Tables under this name.
Name string
// Columns maps column name to its definition.
Columns map[string]*Column
// PrimaryKey lists the names of the primary-key columns.
PrimaryKey []string
// Indexes maps index name to its definition.
Indexes map[string]*Index
// ForeignKeys maps constraint name to its definition.
ForeignKeys map[string]*ForeignKey
}
// Index describes a (possibly unique) index on a table.
type Index struct {
// Name is the index name; it is also the key used in Table.Indexes.
Name string
// Columns holds the indexed column names or expressions, in index order.
Columns []string
// Unique reports whether the index enforces uniqueness.
Unique bool
}
// ForeignKey describes a foreign-key constraint from one table to another.
type ForeignKey struct {
// Name is the constraint name; it is also the key used in Table.ForeignKeys.
Name string
// Columns lists the referencing columns of the owning table.
Columns []string
// RefTable is the name of the referenced table.
RefTable string
// RefColumns lists the referenced columns of RefTable.
RefColumns []string
// OnDelete is the action that follows ON DELETE; the empty string means no
// ON DELETE clause is emitted.
OnDelete string
}
// Schema is an in-memory model of a database schema: its custom types and
// its tables.
type Schema struct {
// Types maps a custom type name to a description of the type. Schemas loaded
// from SQL files store the full CREATE TYPE statement; schemas read from the
// database store the literal "enum".
Types map[string]string
// Tables maps table name to its definition.
Tables map[string]*Table
}
// SchemaFile is one .sql file read from the schema directory.
type SchemaFile struct {
// Name is the file's base name (without the directory).
Name string
// SQL is the complete text of the file.
SQL string
}
// isSerialType reports whether t names one of PostgreSQL's auto-incrementing
// pseudo-types (SERIAL, BIGSERIAL, SMALLSERIAL or their SERIAL2/4/8 aliases).
// The comparison ignores letter case.
func isSerialType(t string) bool {
	switch strings.ToUpper(t) {
	case "SERIAL", "BIGSERIAL", "SMALLSERIAL", "SERIAL4", "SERIAL8", "SERIAL2":
		return true
	}
	return false
}
// normalizeType maps a SQL type name to a canonical upper-case spelling so
// that equivalent spellings compare equal: serial pseudo-types collapse to
// their underlying integer type, DOUBLE becomes DOUBLE PRECISION, and the
// timestamp aliases are unified. Any other type is returned upper-cased.
func normalizeType(t string) string {
	canonical := strings.ToUpper(t)

	if canonical == "BIGSERIAL" || canonical == "SERIAL8" {
		return "BIGINT"
	}
	if canonical == "SERIAL" || canonical == "SERIAL4" {
		return "INTEGER"
	}
	if canonical == "SMALLSERIAL" || canonical == "SERIAL2" {
		return "SMALLINT"
	}
	if canonical == "DOUBLE" {
		return "DOUBLE PRECISION"
	}
	if canonical == "TIMESTAMPTZ" || canonical == "TIMESTAMP WITH TIME ZONE" {
		return "TIMESTAMP WITH TIME ZONE"
	}
	if canonical == "TIMESTAMP WITHOUT TIME ZONE" {
		return "TIMESTAMP"
	}
	return canonical
}
// typeCastPattern matches a PostgreSQL "::type" cast suffix. Besides a plain
// type name it covers an optional schema qualifier, the multi-word types
// (character varying, double precision, timestamp/time with or without time
// zone), a length/precision modifier such as (10,2), and array brackets.
// It is compiled once at package initialisation instead of on every call.
var typeCastPattern = regexp.MustCompile(`(?i)::(?:"[^"]+"|\w+)(?:\.(?:"[^"]+"|\w+))?(?:\s+(?:varying|precision|with(?:out)?\s+time\s+zone)\b)?(?:\(\d+(?:\s*,\s*\d+)?\))?(?:\[\])*`)

// normalizeDefault reduces a column DEFAULT expression to a canonical form so
// that the spelling in a schema file can be compared with the one PostgreSQL
// reports. It removes type casts, collapses runs of whitespace to one space
// and upper-cases the result. An empty input yields an empty result.
//
// Note that the upper-casing also applies to the contents of string literals,
// so defaults that differ only in the letter case of a literal compare equal.
func normalizeDefault(d string) string {
	if d == "" {
		return ""
	}
	// Drop casts first: PostgreSQL adds them to stored defaults
	// ('abc'::character varying) while schema files usually omit them.
	d = typeCastPattern.ReplaceAllString(d, "")
	// Collapse whitespace, then fold case for keyword/function comparison.
	d = strings.Join(strings.Fields(d), " ")
	return strings.ToUpper(d)
}
// columnsMatch reports whether two column definitions are equivalent: same
// normalized type, same nullability and same normalized default. A nil
// column never matches anything.
func columnsMatch(a, b *Column) bool {
	if a == nil || b == nil {
		return false
	}
	if normalizeType(a.Type) != normalizeType(b.Type) {
		return false
	}
	if a.Nullable != b.Nullable {
		return false
	}
	return normalizeDefault(a.Default) == normalizeDefault(b.Default)
}
// lastIndexOfMatchingParen scans s from byte offset start and returns the
// index of the ')' that brings the parenthesis nesting depth back to zero,
// i.e. the partner of the first '(' at or after start.
//
// It returns -1 when start lies outside s (callers pass the result of
// strings.Index straight in, which is -1 when there is no '(') or when no
// balancing ')' exists. Parentheses inside single-quoted SQL string literals
// are ignored so that a default such as ':)' does not end the scan early.
func lastIndexOfMatchingParen(s string, start int) int {
	if start < 0 || start >= len(s) {
		return -1
	}
	depth := 0
	inQuote := false
	for i := start; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\'':
			// A doubled quote ('') toggles twice and so stays inside the literal.
			inQuote = !inQuote
		case inQuote:
			// Literal content never affects nesting.
		case c == '(':
			depth++
		case c == ')':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}
// splitByCommaRespectParen splits s on commas that are at parenthesis depth
// zero and outside single-quoted SQL string literals, trimming surrounding
// whitespace from every piece. Commas nested in parentheses (NUMERIC(10,2))
// or inside a literal (DEFAULT 'a,b') do not split.
//
// An empty input yields a nil slice; a trailing comma does not produce a
// trailing empty element.
func splitByCommaRespectParen(s string) []string {
	var parts []string
	depth, start := 0, 0
	inQuote := false
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\'':
			// A doubled quote ('') toggles twice and so stays inside the literal.
			inQuote = !inQuote
		case inQuote:
			// Literal content is never a separator.
		case c == '(':
			depth++
		case c == ')':
			depth--
		case c == ',' && depth == 0:
			parts = append(parts, strings.TrimSpace(s[start:i]))
			start = i + 1
		}
	}
	if start < len(s) {
		parts = append(parts, strings.TrimSpace(s[start:]))
	}
	return parts
}
// getDependencyOrder returns every table name in schema ordered so that a
// table appears after the tables its foreign keys reference (a depth-first
// topological sort). References to tables that are not part of schema are
// ignored, and reference cycles are broken at the first table revisited.
//
// Tables and foreign keys are walked in sorted name order, so the result is
// the same on every run instead of depending on Go's random map iteration.
func getDependencyOrder(schema *Schema) []string {
	names := make([]string, 0, len(schema.Tables))
	for name := range schema.Tables {
		names = append(names, name)
	}
	sort.Strings(names)

	var order []string
	visited := make(map[string]bool, len(names))

	var visit func(name string)
	visit = func(name string) {
		if visited[name] {
			return
		}
		// Mark before descending so a cycle cannot recurse forever.
		visited[name] = true

		if table, ok := schema.Tables[name]; ok {
			fkNames := make([]string, 0, len(table.ForeignKeys))
			for fkName := range table.ForeignKeys {
				fkNames = append(fkNames, fkName)
			}
			sort.Strings(fkNames)
			for _, fkName := range fkNames {
				ref := table.ForeignKeys[fkName].RefTable
				if _, known := schema.Tables[ref]; known {
					visit(ref)
				}
			}
		}
		order = append(order, name)
	}

	for _, name := range names {
		visit(name)
	}
	return order
}
// parseInlineForeignKey parses a constraint of the form
//
//	FOREIGN KEY (col1, col2) REFERENCES table(col1, col2) [ON DELETE action]
//
// found anywhere in def (keywords are matched case-insensitively) and returns
// the resulting ForeignKey with Name left empty. It returns nil when def has
// no FOREIGN KEY keyword or no REFERENCES keyword after it.
//
// When the referenced table is not followed by a parenthesised column list,
// only Columns is filled in. The ON DELETE action is found anywhere after the
// referenced columns, and two-word actions (SET NULL, SET DEFAULT, NO ACTION)
// are kept whole.
func parseInlineForeignKey(def string) *ForeignKey {
	const fkKeyword, refKeyword = "FOREIGN KEY", "REFERENCES"

	upperDef := strings.ToUpper(def)
	fkIdx := strings.Index(upperDef, fkKeyword)
	if fkIdx < 0 {
		return nil
	}
	// Look for REFERENCES only after FOREIGN KEY so the slice below can never
	// have its end before its start.
	colsStart := fkIdx + len(fkKeyword)
	refOffset := strings.Index(upperDef[colsStart:], refKeyword)
	if refOffset < 0 {
		return nil
	}
	refIdx := colsStart + refOffset
	if refIdx+len(refKeyword) > len(def) {
		// Upper-casing changed the byte length; the offsets cannot be trusted.
		return nil
	}

	splitList := func(list string) []string {
		items := strings.Split(list, ",")
		for i := range items {
			items[i] = strings.TrimSpace(items[i])
		}
		return items
	}

	fk := &ForeignKey{}
	fk.Columns = splitList(strings.Trim(strings.TrimSpace(def[colsStart:refIdx]), "()"))

	refPart := strings.TrimSpace(def[refIdx+len(refKeyword):])
	parenStart := strings.Index(refPart, "(")
	if parenStart <= 0 {
		return fk
	}
	fk.RefTable = strings.TrimSpace(refPart[:parenStart])

	// Find the ')' that closes the referenced column list.
	depth := 1
	parenEnd := -1
	for i := parenStart + 1; i < len(refPart) && parenEnd < 0; i++ {
		switch refPart[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				parenEnd = i
			}
		}
	}
	if parenEnd < 0 {
		return fk
	}
	fk.RefColumns = splitList(refPart[parenStart+1 : parenEnd])

	// Scan the trailing clauses for "ON DELETE <action>".
	tail := strings.Fields(strings.ToUpper(refPart[parenEnd+1:]))
	for i := 0; i+2 < len(tail); i++ {
		if tail[i] != "ON" || tail[i+1] != "DELETE" {
			continue
		}
		action := tail[i+2]
		if (action == "SET" || action == "NO") && i+3 < len(tail) {
			action += " " + tail[i+3]
		}
		fk.OnDelete = action
		break
	}
	return fk
}
// parseCreateTable parses a "CREATE TABLE [IF NOT EXISTS] name ( ... )"
// statement and stores the resulting table in schema.Tables, replacing any
// table already registered under the same name.
//
// Statements without a balanced parenthesised body or without a table name
// are ignored. The name is taken from the text before the first '(' so that
// both "name (" and "name(" are understood.
func parseCreateTable(schema *Schema, stmt string) {
	// Locate the body first; without it there is nothing to parse. Checking
	// startIdx before the matching-paren scan avoids scanning from offset -1.
	startIdx := strings.Index(stmt, "(")
	if startIdx < 0 {
		return
	}
	endIdx := lastIndexOfMatchingParen(stmt, startIdx)
	if endIdx < 0 {
		return
	}

	// header is "CREATE TABLE [IF NOT EXISTS] name".
	header := strings.Fields(stmt[:startIdx])
	if len(header) < 3 {
		return
	}
	rest := header[2:]
	if len(rest) >= 3 && strings.EqualFold(rest[0], "IF") &&
		strings.EqualFold(rest[1], "NOT") && strings.EqualFold(rest[2], "EXISTS") {
		rest = rest[3:]
	}
	if len(rest) == 0 {
		return
	}
	tableName := rest[0]

	table := &Table{
		Name:        tableName,
		Columns:     make(map[string]*Column),
		PrimaryKey:  []string{},
		Indexes:     make(map[string]*Index),
		ForeignKeys: make(map[string]*ForeignKey),
	}
	parseTableDefinition(table, stmt[startIdx+1:endIdx])
	schema.Tables[tableName] = table
}
// parseTableDefinition parses the comma-separated body of a CREATE TABLE
// statement (the text between the outer parentheses) and fills in table's
// columns, primary key, foreign keys and unique constraints.
//
// Recognised items:
//   - [CONSTRAINT name] PRIMARY KEY (cols)
//   - [CONSTRAINT name] FOREIGN KEY (cols) REFERENCES ...; the constraint is
//     stored under name, or under a generated fk_<table>_<cols> name
//   - [CONSTRAINT name] UNIQUE (cols); stored as a unique index under name,
//     or under <table>_<cols>_key
//   - column definitions, optionally carrying an inline PRIMARY KEY
//
// Other named constraints (for example CONSTRAINT x CHECK (...)) cannot be
// represented in Table and are skipped. Keywords are matched
// case-insensitively. Primary-key columns are forced to NOT NULL at the end.
//
// NOTE(review): an unnamed table-level CHECK (...) still falls through to the
// column parser, as it did before; Table has no field to hold it.
func parseTableDefinition(table *Table, def string) {
	hasPrefixFold := func(s, prefix string) bool {
		return len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix)
	}
	splitList := func(list string) []string {
		items := strings.Split(list, ",")
		for i := range items {
			items[i] = strings.TrimSpace(items[i])
		}
		return items
	}

	for _, part := range splitByCommaRespectParen(def) {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}

		// Peel off an optional "CONSTRAINT <name>" prefix so the constraint
		// itself can be recognised by its leading keyword.
		constraintName := ""
		if fields := strings.Fields(part); len(fields) >= 3 && strings.EqualFold(fields[0], "CONSTRAINT") {
			constraintName = fields[1]
			part = strings.Join(fields[2:], " ")
		}

		switch {
		case hasPrefixFold(part, "PRIMARY KEY"):
			// Slice by length rather than TrimPrefix so lower-case SQL works.
			pkDef := strings.TrimSpace(part[len("PRIMARY KEY"):])
			pkDef = strings.TrimLeft(pkDef, "(")
			pkDef = strings.TrimRight(pkDef, ")")
			table.PrimaryKey = append(table.PrimaryKey, splitList(strings.TrimSpace(pkDef))...)
			continue

		case hasPrefixFold(part, "FOREIGN KEY"):
			if fk := parseInlineForeignKey(part); fk != nil {
				fk.Name = constraintName
				if fk.Name == "" {
					fk.Name = fmt.Sprintf("fk_%s_%s", table.Name, strings.Join(fk.Columns, "_"))
				}
				table.ForeignKeys[fk.Name] = fk
			}
			continue

		case hasPrefixFold(part, "UNIQUE") && strings.HasPrefix(strings.TrimSpace(part[len("UNIQUE"):]), "("):
			// Table-level UNIQUE (cols): keep it as a unique index instead of
			// mis-reading it as a column called UNIQUE.
			open := strings.Index(part, "(")
			if end := strings.Index(part[open:], ")"); end > 0 {
				cols := splitList(part[open+1 : open+end])
				name := constraintName
				if name == "" {
					name = fmt.Sprintf("%s_%s_key", table.Name, strings.Join(cols, "_"))
				}
				if table.Indexes == nil {
					table.Indexes = make(map[string]*Index)
				}
				table.Indexes[name] = &Index{Name: name, Columns: cols, Unique: true}
			}
			continue

		case constraintName != "":
			// A named constraint of a kind Table cannot represent.
			continue
		}

		// Inline PRIMARY KEY, e.g. "id BIGSERIAL PRIMARY KEY": record the
		// column as the key and strip the two keywords before column parsing.
		if strings.Contains(strings.ToUpper(part), "PRIMARY KEY") {
			fields := strings.Fields(part)
			if len(fields) >= 3 {
				table.PrimaryKey = append(table.PrimaryKey, fields[0])
				kept := make([]string, 0, len(fields))
				for i := 0; i < len(fields); i++ {
					if i < len(fields)-1 && strings.EqualFold(fields[i], "PRIMARY") && strings.EqualFold(fields[i+1], "KEY") {
						i++
						continue
					}
					kept = append(kept, fields[i])
				}
				part = strings.Join(kept, " ")
			}
		}

		if col := parseColumn(part); col != nil {
			table.Columns[col.Name] = col
		}
	}

	// PRIMARY KEY columns are implicitly NOT NULL.
	for _, pkCol := range table.PrimaryKey {
		if col, exists := table.Columns[pkCol]; exists {
			col.Nullable = false
		}
	}
}
// parseColumn parses one column definition ("name type [constraints...]")
// and returns the column, or nil when def contains no tokens.
//
// The type is every token after the name up to the first constraint keyword,
// which allows multi-word types such as DOUBLE PRECISION. Columns are
// nullable unless NOT NULL is present. The DEFAULT expression extends up to
// the next constraint keyword that is outside parentheses and string
// literals, so "DEFAULT 'x' NOT NULL" yields the default 'x' and a
// non-nullable column. The token right after DEFAULT always belongs to the
// expression, which keeps "DEFAULT NULL" intact.
func parseColumn(def string) *Column {
	parts := strings.Fields(def)
	if len(parts) == 0 {
		return nil
	}
	col := &Column{Name: parts[0], Nullable: true}
	if len(parts) < 2 {
		return col
	}

	isConstraintKeyword := func(tok string) bool {
		switch strings.ToUpper(tok) {
		case "NOT", "NULL", "DEFAULT", "REFERENCES", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK", "CONSTRAINT":
			return true
		}
		return false
	}

	// parts[1] is always part of the type; later tokens join it until a
	// constraint keyword appears.
	i := 2
	for i < len(parts) && !isConstraintKeyword(parts[i]) {
		i++
	}
	col.Type = strings.Join(parts[1:i], " ")

	for i < len(parts) {
		p := strings.ToUpper(parts[i])
		switch {
		case p == "NOT" && i+1 < len(parts) && strings.EqualFold(parts[i+1], "NULL"):
			col.Nullable = false
			i += 2
		case p == "NULL":
			col.Nullable = true
			i++
		case strings.HasPrefix(p, "DEFAULT"):
			// Collect the expression, tracking nesting and quoting across
			// tokens so keywords inside (...) or '...' do not end it.
			end := i + 1
			depth, inQuote := 0, false
			for end < len(parts) {
				if end > i+1 && depth == 0 && !inQuote && isConstraintKeyword(parts[end]) {
					break
				}
				for _, c := range parts[end] {
					switch {
					case c == '\'':
						inQuote = !inQuote
					case inQuote:
					case c == '(':
						depth++
					case c == ')':
						depth--
					}
				}
				end++
			}
			col.Default = strings.Join(parts[i+1:end], " ")
			i = end
		default:
			i++
		}
	}
	return col
}
// parseCreateIndex parses
//
//	CREATE [UNIQUE] INDEX [CONCURRENTLY] [IF NOT EXISTS] name ON [ONLY] table (cols)
//
// and records the index on the named table in schema, creating a placeholder
// table entry when the table has not been seen yet. Statements that are not
// CREATE INDEX statements, or that lack an index name followed by ON and a
// table, are ignored.
//
// The index is unique only when the UNIQUE keyword follows CREATE; the word
// appearing elsewhere (for example in the index name) does not count.
func parseCreateIndex(schema *Schema, stmt string) {
	parts := strings.Fields(stmt)
	if len(parts) < 4 {
		return
	}

	pos := 1
	isUnique := strings.EqualFold(parts[pos], "UNIQUE")
	if isUnique {
		pos++
	}
	if !strings.EqualFold(parts[pos], "INDEX") {
		return
	}
	pos++
	if pos < len(parts) && strings.EqualFold(parts[pos], "CONCURRENTLY") {
		pos++
	}
	if pos+2 < len(parts) && strings.EqualFold(parts[pos], "IF") &&
		strings.EqualFold(parts[pos+1], "NOT") && strings.EqualFold(parts[pos+2], "EXISTS") {
		pos += 3
	}
	if pos >= len(parts) {
		return
	}
	idxName := parts[pos]

	// The ON keyword must follow the index name.
	onIdx := -1
	for i := pos + 1; i < len(parts); i++ {
		if strings.EqualFold(parts[i], "ON") {
			onIdx = i
			break
		}
	}
	if onIdx < 0 {
		return
	}
	tablePos := onIdx + 1
	if tablePos < len(parts) && strings.EqualFold(parts[tablePos], "ONLY") {
		tablePos++
	}
	if tablePos >= len(parts) {
		return
	}
	// "ON users(email)" arrives as a single token; keep only the table part.
	tableName := parts[tablePos]
	if paren := strings.Index(tableName, "("); paren >= 0 {
		tableName = tableName[:paren]
	}
	if tableName == "" {
		return
	}

	// The column list is the first parenthesised group in the statement.
	var columns []string
	if parenStart := strings.Index(stmt, "("); parenStart >= 0 {
		if parenEnd := lastIndexOfMatchingParen(stmt, parenStart); parenEnd > parenStart {
			columns = splitByCommaRespectParen(stmt[parenStart+1 : parenEnd])
		}
	}

	table, exists := schema.Tables[tableName]
	if !exists {
		table = &Table{
			Name:        tableName,
			Columns:     make(map[string]*Column),
			Indexes:     make(map[string]*Index),
			ForeignKeys: make(map[string]*ForeignKey),
		}
		schema.Tables[tableName] = table
	}
	table.Indexes[idxName] = &Index{
		Name:    idxName,
		Columns: columns,
		Unique:  isUnique,
	}
}
// parseAlterTableAddConstraint parses
//
//	ALTER TABLE [IF EXISTS] [ONLY] table ADD CONSTRAINT name FOREIGN KEY (cols) REFERENCES ref(cols) ...
//
// and records the foreign key on the table in schema under the constraint
// name. A placeholder table entry is created when the table has not been seen
// yet, even if the constraint turns out not to be a foreign key. Statements
// whose table name is not directly followed by ADD CONSTRAINT <name> are
// ignored, as are constraints other than FOREIGN KEY.
func parseAlterTableAddConstraint(schema *Schema, stmt string) {
	parts := strings.Fields(stmt)
	if len(parts) < 6 {
		return
	}

	// Skip the optional modifiers so the table name is found wherever it is.
	pos := 2
	if pos+1 < len(parts) && strings.EqualFold(parts[pos], "IF") && strings.EqualFold(parts[pos+1], "EXISTS") {
		pos += 2
	}
	if pos < len(parts) && strings.EqualFold(parts[pos], "ONLY") {
		pos++
	}
	if pos+3 >= len(parts) {
		return
	}
	if !strings.EqualFold(parts[pos+1], "ADD") || !strings.EqualFold(parts[pos+2], "CONSTRAINT") {
		return
	}
	tableName := parts[pos]
	constraintName := parts[pos+3]

	table, exists := schema.Tables[tableName]
	if !exists {
		table = &Table{
			Name:        tableName,
			Columns:     make(map[string]*Column),
			Indexes:     make(map[string]*Index),
			ForeignKeys: make(map[string]*ForeignKey),
		}
		schema.Tables[tableName] = table
	}

	// parseInlineForeignKey returns nil unless the remainder contains
	// FOREIGN KEY ... REFERENCES, so other constraint kinds are skipped. It
	// searches the text, so "FOREIGN KEY(col)" without a space also works.
	if fk := parseInlineForeignKey(strings.Join(parts[pos+4:], " ")); fk != nil {
		fk.Name = constraintName
		table.ForeignKeys[constraintName] = fk
	}
}
// parseSchemaFile parses the SQL text of one schema file and merges what it
// understands into schema:
//
//   - CREATE TYPE name ...                  -> schema.Types[name] = statement
//   - CREATE TABLE ...                      -> parseCreateTable
//   - CREATE [UNIQUE] INDEX ...             -> parseCreateIndex
//   - ALTER TABLE ... ADD CONSTRAINT ...    -> parseAlterTableAddConstraint
//
// Every other statement is ignored. Each statement is dispatched to exactly
// one handler, chosen by its leading keywords, so a CREATE TABLE that merely
// mentions the word INDEX is not also fed to the index parser.
func parseSchemaFile(schema *Schema, sql string) {
	for _, stmt := range splitSQLStatements(sql) {
		fields := strings.Fields(stmt)
		if len(fields) < 2 {
			continue
		}
		first, second := strings.ToUpper(fields[0]), strings.ToUpper(fields[1])

		switch {
		case first == "CREATE" && second == "TYPE":
			if len(fields) >= 3 {
				// Keep the full statement so it can be replayed verbatim.
				schema.Types[fields[2]] = stmt
			}
		case first == "CREATE" && second == "TABLE":
			parseCreateTable(schema, stmt)
		case first == "CREATE" && (second == "INDEX" ||
			(second == "UNIQUE" && len(fields) >= 3 && strings.EqualFold(fields[2], "INDEX"))):
			parseCreateIndex(schema, stmt)
		case first == "ALTER" && second == "TABLE" &&
			strings.Contains(strings.ToUpper(strings.Join(fields, " ")), "ADD CONSTRAINT"):
			parseAlterTableAddConstraint(schema, stmt)
		}
	}
}

// splitSQLStatements splits sql into trimmed, non-empty statements at
// semicolons. "--" line comments and "/* */" block comments are removed, and
// semicolons inside single-quoted literals or double-quoted identifiers do
// not end a statement.
func splitSQLStatements(sql string) []string {
	var stmts []string
	var cur strings.Builder
	flush := func() {
		if s := strings.TrimSpace(cur.String()); s != "" {
			stmts = append(stmts, s)
		}
		cur.Reset()
	}

	for i := 0; i < len(sql); i++ {
		c := sql[i]
		switch {
		case c == '\'' || c == '"':
			// Copy the quoted run verbatim. A doubled quote simply starts a
			// new run immediately, which gives the same result.
			end := strings.IndexByte(sql[i+1:], c)
			if end < 0 {
				cur.WriteString(sql[i:])
				i = len(sql)
			} else {
				cur.WriteString(sql[i : i+end+2])
				i += end + 1
			}
		case c == '-' && i+1 < len(sql) && sql[i+1] == '-':
			// Drop up to, but not including, the newline.
			nl := strings.IndexByte(sql[i:], '\n')
			if nl < 0 {
				i = len(sql)
			} else {
				i += nl - 1
			}
		case c == '/' && i+1 < len(sql) && sql[i+1] == '*':
			end := strings.Index(sql[i+2:], "*/")
			if end < 0 {
				i = len(sql)
			} else {
				i += end + 3
			}
			// Keep tokens on either side of the comment separated.
			cur.WriteByte(' ')
		case c == ';':
			flush()
		default:
			cur.WriteByte(c)
		}
	}
	flush()
	return stmts
}
// loadSchemaFiles reads every regular file in dir whose name ends in ".sql"
// and returns the files ordered by name using plain byte-wise string
// comparison (so "10_x.sql" sorts before "2_y.sql"). Subdirectories are not
// descended into. The first read error aborts the load and is returned as is.
func loadSchemaFiles(dir string) ([]SchemaFile, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	var schemaFiles []SchemaFile
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
			continue
		}
		data, readErr := os.ReadFile(filepath.Join(dir, entry.Name()))
		if readErr != nil {
			return nil, readErr
		}
		schemaFiles = append(schemaFiles, SchemaFile{Name: entry.Name(), SQL: string(data)})
	}

	sort.Slice(schemaFiles, func(a, b int) bool {
		return schemaFiles[a].Name < schemaFiles[b].Name
	})
	return schemaFiles, nil
}
// loadDesiredSchema builds the desired schema by parsing every .sql file in
// dir, in the order returned by loadSchemaFiles, into a single Schema. An
// error from reading the directory or a file is returned unchanged.
func loadDesiredSchema(dir string) (*Schema, error) {
	schemaFiles, loadErr := loadSchemaFiles(dir)
	if loadErr != nil {
		return nil, loadErr
	}

	desired := &Schema{
		Tables: make(map[string]*Table),
		Types:  make(map[string]string),
	}
	for i := range schemaFiles {
		parseSchemaFile(desired, schemaFiles[i].SQL)
	}
	return desired, nil
}
// generateCreateTableSQL renders a CREATE TABLE statement for table. The
// body lists the columns, then the PRIMARY KEY constraint (if any), then one
// UNIQUE constraint per unique index that has columns. Foreign keys and
// non-unique indexes are not included.
//
// Items are joined with commas, so the last item never carries a trailing
// comma even when the table has no constraints. Columns are emitted in
// sorted name order and unique constraints in sorted index-name order, which
// makes the output the same on every run.
func generateCreateTableSQL(table *Table) string {
	colNames := make([]string, 0, len(table.Columns))
	for name := range table.Columns {
		colNames = append(colNames, name)
	}
	sort.Strings(colNames)

	var items []string
	for _, name := range colNames {
		col := table.Columns[name]
		colSQL := fmt.Sprintf(" %s %s", name, col.Type)
		if !col.Nullable {
			colSQL += " NOT NULL"
		}
		if col.Default != "" {
			colSQL += fmt.Sprintf(" DEFAULT %s", col.Default)
		}
		items = append(items, colSQL)
	}

	if len(table.PrimaryKey) > 0 {
		pkCols := make([]string, len(table.PrimaryKey))
		for i, col := range table.PrimaryKey {
			// Tolerate stray parentheses left over from parsing.
			pkCols[i] = strings.TrimSpace(strings.Trim(col, "()"))
		}
		items = append(items, fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(pkCols, ", ")))
	}

	idxNames := make([]string, 0, len(table.Indexes))
	for name := range table.Indexes {
		idxNames = append(idxNames, name)
	}
	sort.Strings(idxNames)
	for _, name := range idxNames {
		idx := table.Indexes[name]
		if idx.Unique && len(idx.Columns) > 0 {
			items = append(items, fmt.Sprintf(" UNIQUE (%s)", strings.Join(idx.Columns, ", ")))
		}
	}

	if len(items) == 0 {
		return fmt.Sprintf("CREATE TABLE %s (\n);", table.Name)
	}
	return fmt.Sprintf("CREATE TABLE %s (\n%s\n);", table.Name, strings.Join(items, ",\n"))
}
// generateAddColumnSQL renders an ALTER TABLE ... ADD COLUMN statement that
// adds col to tableName under the name colName, including NOT NULL and
// DEFAULT clauses when the column has them.
func generateAddColumnSQL(tableName, colName string, col *Column) string {
	var b strings.Builder
	b.WriteString("ALTER TABLE ")
	b.WriteString(tableName)
	b.WriteString(" ADD COLUMN ")
	b.WriteString(colName)
	b.WriteString(" ")
	b.WriteString(col.Type)
	if !col.Nullable {
		b.WriteString(" NOT NULL")
	}
	if col.Default != "" {
		b.WriteString(" DEFAULT ")
		b.WriteString(col.Default)
	}
	b.WriteString(";")
	return b.String()
}
// generateAddForeignKeySQL renders an ALTER TABLE ... ADD CONSTRAINT
// statement that adds the foreign key fk to tableName. An ON DELETE clause is
// appended only when fk.OnDelete is non-empty.
func generateAddForeignKeySQL(tableName string, fk *ForeignKey) string {
	localCols := strings.Join(fk.Columns, ", ")
	remoteCols := strings.Join(fk.RefColumns, ", ")

	clause := "FOREIGN KEY (" + localCols + ") REFERENCES " + fk.RefTable + "(" + remoteCols + ")"
	if fk.OnDelete != "" {
		clause += " ON DELETE " + fk.OnDelete
	}
	return "ALTER TABLE " + tableName + " ADD CONSTRAINT " + fk.Name + " " + clause + ";"
}
// generateCreateIndexSQL renders a CREATE INDEX statement (CREATE UNIQUE
// INDEX when idx.Unique is set) for idx on tableName.
func generateCreateIndexSQL(tableName string, idx *Index) string {
	keyword := "CREATE INDEX"
	if idx.Unique {
		keyword = "CREATE UNIQUE INDEX"
	}
	columnList := strings.Join(idx.Columns, ", ")
	return keyword + " " + idx.Name + " ON " + tableName + " (" + columnList + ");"
}
// CompareSchemas returns the SQL statements needed to bring current up to
// desired. It only ever adds or relaxes things; nothing is dropped except
// NOT NULL and DEFAULT on existing columns.
//
//   - custom types missing from current are created
//   - missing tables are created, followed by their foreign keys and
//     non-unique indexes (unique ones are part of CREATE TABLE)
//   - for existing tables: missing columns are added, NOT NULL and DEFAULT
//     differences are altered, and missing foreign keys and indexes are
//     added; column type changes are not generated
//
// Defaults are compared in normalized form, matching columnsMatch, so a
// stored "'x'::text" does not differ from a declared "'x'". A unique index
// missing from an existing table is created unless the table already has a
// unique index, under any name, on the same columns. Types, columns, foreign
// keys and indexes are visited in sorted name order so the statements for a
// table come out in a stable order.
func CompareSchemas(current, desired *Schema) []string {
	var migrations []string

	sameColumns := func(a, b []string) bool {
		if len(a) != len(b) {
			return false
		}
		for i := range a {
			if !strings.EqualFold(strings.TrimSpace(a[i]), strings.TrimSpace(b[i])) {
				return false
			}
		}
		return true
	}

	// Create missing types.
	typeNames := make([]string, 0, len(desired.Types))
	for name := range desired.Types {
		typeNames = append(typeNames, name)
	}
	sort.Strings(typeNames)
	for _, name := range typeNames {
		if _, exists := current.Types[name]; !exists {
			migrations = append(migrations, desired.Types[name]+";")
		}
	}

	// Process tables so that referenced tables come first.
	for _, tableName := range getDependencyOrder(desired) {
		desiredTable := desired.Tables[tableName]

		colNames := make([]string, 0, len(desiredTable.Columns))
		for name := range desiredTable.Columns {
			colNames = append(colNames, name)
		}
		sort.Strings(colNames)

		fkNames := make([]string, 0, len(desiredTable.ForeignKeys))
		for name := range desiredTable.ForeignKeys {
			fkNames = append(fkNames, name)
		}
		sort.Strings(fkNames)

		idxNames := make([]string, 0, len(desiredTable.Indexes))
		for name := range desiredTable.Indexes {
			idxNames = append(idxNames, name)
		}
		sort.Strings(idxNames)

		currentTable, exists := current.Tables[tableName]
		if !exists {
			// New table: CREATE TABLE carries the columns, the primary key
			// and unique constraints; the rest follows as separate statements.
			migrations = append(migrations, generateCreateTableSQL(desiredTable))
			for _, name := range fkNames {
				migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredTable.ForeignKeys[name]))
			}
			for _, name := range idxNames {
				if idx := desiredTable.Indexes[name]; !idx.Unique {
					migrations = append(migrations, generateCreateIndexSQL(tableName, idx))
				}
			}
			continue
		}

		// Compare columns.
		for _, colName := range colNames {
			desiredCol := desiredTable.Columns[colName]
			currentCol, colExists := currentTable.Columns[colName]
			if !colExists {
				migrations = append(migrations, generateAddColumnSQL(tableName, colName, desiredCol))
				continue
			}
			if columnsMatch(currentCol, desiredCol) {
				continue
			}
			if currentCol.Nullable != desiredCol.Nullable {
				action := "SET NOT NULL"
				if desiredCol.Nullable {
					action = "DROP NOT NULL"
				}
				migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s;", tableName, colName, action))
			}
			// Serial columns have an implicit sequence default; leave it alone.
			if !isSerialType(currentCol.Type) && !isSerialType(desiredCol.Type) &&
				normalizeDefault(currentCol.Default) != normalizeDefault(desiredCol.Default) {
				if desiredCol.Default != "" {
					migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", tableName, colName, desiredCol.Default))
				} else {
					migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", tableName, colName))
				}
			}
			// Type changes are intentionally not generated.
		}

		// Add missing foreign keys (matched by constraint name).
		for _, name := range fkNames {
			if _, fkExists := currentTable.ForeignKeys[name]; !fkExists {
				migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredTable.ForeignKeys[name]))
			}
		}

		// Add missing indexes (matched by index name).
		for _, name := range idxNames {
			desiredIdx := desiredTable.Indexes[name]
			if _, idxExists := currentTable.Indexes[name]; idxExists {
				continue
			}
			if desiredIdx.Unique {
				if len(desiredIdx.Columns) == 0 {
					continue
				}
				// A UNIQUE constraint created with the table gets a
				// database-chosen index name, so also match by columns.
				covered := false
				for _, currentIdx := range currentTable.Indexes {
					if currentIdx.Unique && sameColumns(currentIdx.Columns, desiredIdx.Columns) {
						covered = true
						break
					}
				}
				if covered {
					continue
				}
			}
			migrations = append(migrations, generateCreateIndexSQL(tableName, desiredIdx))
		}
	}
	return migrations
}
// GetCurrentSchema introspects the "public" schema of the database behind db
// and returns its enum types and base tables with their columns, primary
// keys, non-primary indexes and foreign keys.
//
// Enum types are recorded in Schema.Types with the placeholder value "enum".
// The list of tables is read completely before any per-table query is
// issued, so the function never holds one pooled connection open while
// asking the pool for another. Any query, scan or row-iteration error aborts
// the introspection and is returned wrapped with the step that failed.
func GetCurrentSchema(ctx context.Context, db *pgxpool.Pool) (*Schema, error) {
	schema := &Schema{
		Types:  make(map[string]string),
		Tables: make(map[string]*Table),
	}

	enumNames, err := queryStringColumn(ctx, db, `
SELECT typname
FROM pg_type t
JOIN pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = 'public' AND t.typtype = 'e'
`)
	if err != nil {
		return nil, fmt.Errorf("listing enum types: %w", err)
	}
	for _, name := range enumNames {
		schema.Types[name] = "enum"
	}

	tableNames, err := queryStringColumn(ctx, db, `
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
ORDER BY table_name
`)
	if err != nil {
		return nil, fmt.Errorf("listing tables: %w", err)
	}

	for _, tableName := range tableNames {
		table, err := loadCurrentTable(ctx, db, tableName)
		if err != nil {
			return nil, fmt.Errorf("introspecting table %q: %w", tableName, err)
		}
		schema.Tables[tableName] = table
	}
	return schema, nil
}

// queryStringColumn runs a query whose result has a single text column and
// returns all values in result order.
func queryStringColumn(ctx context.Context, db *pgxpool.Pool, query string, args ...interface{}) ([]string, error) {
	rows, err := db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var values []string
	for rows.Next() {
		var v string
		if err := rows.Scan(&v); err != nil {
			return nil, err
		}
		values = append(values, v)
	}
	return values, rows.Err()
}

// loadCurrentTable reads columns, primary key, indexes and foreign keys of
// one table in the public schema.
func loadCurrentTable(ctx context.Context, db *pgxpool.Pool, tableName string) (*Table, error) {
	table := &Table{
		Name:        tableName,
		Columns:     make(map[string]*Column),
		PrimaryKey:  []string{},
		Indexes:     make(map[string]*Index),
		ForeignKeys: make(map[string]*ForeignKey),
	}

	if err := loadCurrentColumns(ctx, db, table); err != nil {
		return nil, fmt.Errorf("columns: %w", err)
	}

	pkCols, err := queryStringColumn(ctx, db, `
SELECT a.attname
FROM pg_index ix
JOIN pg_attribute a ON a.attrelid = ix.indrelid AND a.attnum = ANY(ix.indkey)
JOIN pg_class c ON c.oid = ix.indrelid
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = 'public' AND c.relname = $1 AND ix.indisprimary
ORDER BY a.attnum
`, tableName)
	if err != nil {
		return nil, fmt.Errorf("primary key: %w", err)
	}
	table.PrimaryKey = append(table.PrimaryKey, pkCols...)

	if err := loadCurrentIndexes(ctx, db, table); err != nil {
		return nil, fmt.Errorf("indexes: %w", err)
	}
	if err := loadCurrentForeignKeys(ctx, db, table); err != nil {
		return nil, fmt.Errorf("foreign keys: %w", err)
	}
	return table, nil
}

// loadCurrentColumns fills table.Columns from information_schema.columns.
func loadCurrentColumns(ctx context.Context, db *pgxpool.Pool, table *Table) error {
	rows, err := db.Query(ctx, `
SELECT column_name, data_type, udt_name, is_nullable, column_default
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position
`, table.Name)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		col := &Column{}
		var dataType, udtName, nullable string
		var defaultVal *string // NULL when the column has no default
		if err := rows.Scan(&col.Name, &dataType, &udtName, &nullable, &defaultVal); err != nil {
			return err
		}
		// For user-defined types (enums) data_type is only "USER-DEFINED";
		// the real name is in udt_name.
		col.Type = dataType
		if dataType == "USER-DEFINED" && udtName != "" {
			col.Type = udtName
		}
		col.Nullable = nullable == "YES"
		if defaultVal != nil {
			col.Default = strings.TrimSpace(*defaultVal)
		}
		table.Columns[col.Name] = col
	}
	return rows.Err()
}

// loadCurrentIndexes fills table.Indexes with every index on the table
// except the primary-key index.
func loadCurrentIndexes(ctx context.Context, db *pgxpool.Pool, table *Table) error {
	rows, err := db.Query(ctx, `
SELECT i.relname, ix.indisunique, pg_get_indexdef(ix.indexrelid)
FROM pg_index ix
JOIN pg_class i ON i.oid = ix.indexrelid
JOIN pg_class t ON t.oid = ix.indrelid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE n.nspname = 'public' AND t.relname = $1 AND i.relkind = 'i' AND ix.indisprimary = false
`, table.Name)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var name, indexDef string
		var unique bool
		if err := rows.Scan(&name, &unique, &indexDef); err != nil {
			return err
		}
		table.Indexes[name] = &Index{
			Name:    name,
			Columns: parseIndexColumns(indexDef),
			Unique:  unique,
		}
	}
	return rows.Err()
}

// loadCurrentForeignKeys fills table.ForeignKeys from pg_constraint.
func loadCurrentForeignKeys(ctx context.Context, db *pgxpool.Pool, table *Table) error {
	rows, err := db.Query(ctx, `
SELECT conname, pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class cl ON cl.oid = c.conrelid
JOIN pg_namespace n ON n.oid = cl.relnamespace
WHERE n.nspname = 'public' AND cl.relname = $1 AND c.contype = 'f'
`, table.Name)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var name, def string
		if err := rows.Scan(&name, &def); err != nil {
			return err
		}
		fk := parseForeignKey(def)
		fk.Name = name
		table.ForeignKeys[name] = fk
	}
	return rows.Err()
}
// parseIndexColumns extracts the indexed columns or expressions from an
// index definition such as
//
//	CREATE UNIQUE INDEX idx_name ON table_name USING btree (col1, col2)
//
// It reads the first parenthesised group only, so a trailing WHERE (...) or
// INCLUDE (...) clause is not mistaken for part of the column list, and it
// splits on top-level commas only, so an expression like coalesce(a, b)
// stays in one piece. Each entry is trimmed of surrounding whitespace.
// It returns nil when def has no balanced parenthesised group.
func parseIndexColumns(def string) []string {
	startIdx := strings.Index(def, "(")
	if startIdx < 0 {
		return nil
	}

	// Find the ')' that closes the first '('.
	endIdx := -1
	depth := 0
	for i := startIdx; i < len(def) && endIdx < 0; i++ {
		switch def[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				endIdx = i
			}
		}
	}
	if endIdx < 0 {
		return nil
	}

	inner := def[startIdx+1 : endIdx]
	var cols []string
	depth = 0
	segStart := 0
	for i := 0; i < len(inner); i++ {
		switch inner[i] {
		case '(':
			depth++
		case ')':
			depth--
		case ',':
			if depth == 0 {
				cols = append(cols, strings.TrimSpace(inner[segStart:i]))
				segStart = i + 1
			}
		}
	}
	return append(cols, strings.TrimSpace(inner[segStart:]))
}
// parseForeignKey parses a constraint definition in the form produced by
// pg_get_constraintdef, for example
//
//	FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE SET NULL
//
// and returns the foreign key with Name left empty. Keywords are matched
// case-sensitively in upper case. The result is never nil: fields that
// cannot be found in def are simply left at their zero value.
//
// The referenced table is the text between REFERENCES and the following
// '(' and the referenced columns are the contents of that parenthesised
// list. Two-word ON DELETE actions (SET NULL, SET DEFAULT, NO ACTION) are
// kept whole.
func parseForeignKey(def string) *ForeignKey {
	const fkKeyword, refKeyword = "FOREIGN KEY", "REFERENCES"
	fk := &ForeignKey{}

	fkIdx := strings.Index(def, fkKeyword)
	if fkIdx < 0 {
		return fk
	}
	rest := def[fkIdx+len(fkKeyword):]
	refIdx := strings.Index(rest, refKeyword)
	if refIdx < 0 {
		return fk
	}

	splitList := func(list string) []string {
		items := strings.Split(list, ",")
		for i := range items {
			items[i] = strings.TrimSpace(items[i])
		}
		return items
	}

	fk.Columns = splitList(strings.Trim(strings.TrimSpace(rest[:refIdx]), "()"))

	refPart := strings.TrimSpace(rest[refIdx+len(refKeyword):])
	tail := refPart
	if open := strings.Index(refPart, "("); open >= 0 {
		fk.RefTable = strings.TrimSpace(refPart[:open])
		if end := strings.Index(refPart[open:], ")"); end >= 0 {
			fk.RefColumns = splitList(refPart[open+1 : open+end])
			tail = refPart[open+end+1:]
		} else {
			tail = ""
		}
	} else if fields := strings.Fields(refPart); len(fields) > 0 {
		// No column list: the first word is the referenced table.
		fk.RefTable = fields[0]
		tail = strings.Join(fields[1:], " ")
	}

	// Scan the trailing clauses for "ON DELETE <action>".
	words := strings.Fields(tail)
	for i := 0; i+2 < len(words); i++ {
		if words[i] != "ON" || words[i+1] != "DELETE" {
			continue
		}
		action := words[i+2]
		if (action == "SET" || action == "NO") && i+3 < len(words) {
			action += " " + words[i+3]
		}
		fk.OnDelete = action
		break
	}
	return fk
}
// LoadDesiredSchema parses every .sql file in dir and returns the combined
// schema those files describe. It is the exported entry point for
// loadDesiredSchema and returns that function's result and error unchanged.
func LoadDesiredSchema(dir string) (*Schema, error) {
	desired, err := loadDesiredSchema(dir)
	return desired, err
}
// GetMigrations introspects the database behind db, loads the desired schema
// from the .sql files in schemaDir, and returns the SQL statements that
// CompareSchemas produces for the difference. The database is read before
// the files; a failure of either step is returned wrapped with a message
// naming the step.
func GetMigrations(ctx context.Context, db *pgxpool.Pool, schemaDir string) ([]string, error) {
	current, currentErr := GetCurrentSchema(ctx, db)
	if currentErr != nil {
		return nil, fmt.Errorf("failed to get current schema: %w", currentErr)
	}

	desired, desiredErr := LoadDesiredSchema(schemaDir)
	if desiredErr != nil {
		return nil, fmt.Errorf("failed to load desired schema: %w", desiredErr)
	}

	migrations := CompareSchemas(current, desired)
	return migrations, nil
}