From 4898b6ab27bc651ae584567a8fdc05a487bba7ed Mon Sep 17 00:00:00 2001 From: Nikola Petrov Date: Sun, 31 May 2026 16:03:55 +0200 Subject: [PATCH] init --- go.mod | 17 + go.sum | 30 ++ lib/lib.go | 1003 ++++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 121 +++++++ 4 files changed, 1171 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 lib/lib.go create mode 100644 main.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..01a2802 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module git.petrovv.com/go-migrate + +go 1.26.2 + +require ( + github.com/jackc/pgx/v5 v5.6.0 + github.com/joho/godotenv v1.5.1 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4aefed3 --- /dev/null +++ b/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/lib.go b/lib/lib.go new file mode 100644 index 0000000..3b61d23 --- /dev/null +++ b/lib/lib.go @@ -0,0 +1,1003 @@ +package lib + +import ( + "context" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" +) + +// Column represents a database column +type Column struct { + Name string + Type string + Nullable bool + Default string +} + +// Table represents a database table +type Table struct { + Name string + Columns map[string]*Column + PrimaryKey []string + Indexes map[string]*Index + ForeignKeys map[string]*ForeignKey +} + +// Index represents a database index +type Index struct { + Name string + Columns []string + Unique bool +} + +// ForeignKey represents a foreign key constraint +type ForeignKey struct { + Name string + Columns []string + RefTable string + RefColumns []string + OnDelete string +} + +// Schema represents the database schema +type Schema struct { + Types map[string]string + Tables map[string]*Table +} + +// SchemaFile represents a loaded schema file +type SchemaFile struct { + Name string + SQL string +} + +func isSerialType(t string) bool { + upper := strings.ToUpper(t) + return upper == "SERIAL" || upper == "BIGSERIAL" || upper == "SMALLSERIAL" || + upper == "SERIAL4" || upper == "SERIAL8" || upper == "SERIAL2" +} + +func normalizeType(t string) string { + upper := strings.ToUpper(t) + switch upper { + case "BIGSERIAL", "SERIAL8": + return "BIGINT" + case "SERIAL", "SERIAL4": + return "INTEGER" + case "SMALLSERIAL", "SERIAL2": + return "SMALLINT" + case "DOUBLE": + return "DOUBLE PRECISION" + case "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ": + return "TIMESTAMP WITH TIME ZONE" + case "TIMESTAMP WITHOUT TIME ZONE": + return "TIMESTAMP" + default: + return upper + } +} + +func normalizeDefault(d string) string { + if d == "" { + return "" + } + // Remove PostgreSQL type casts like ::TEXT, ::VARCHAR, etc. + d = regexp.MustCompile(`::\w+`).ReplaceAllString(d, "") + // Normalize whitespace + d = strings.Join(strings.Fields(d), " ") + // Convert to uppercase for case-insensitive comparison of SQL functions/keywords + d = strings.ToUpper(d) + return d +} + +func columnsMatch(a, b *Column) bool { + if a == nil || b == nil { + return false + } + return normalizeType(a.Type) == normalizeType(b.Type) && a.Nullable == b.Nullable && normalizeDefault(a.Default) == normalizeDefault(b.Default) +} + +func lastIndexOfMatchingParen(s string, start int) int { + depth := 0 + for i := start; i < len(s); i++ { + if s[i] == '(' { + depth++ + } else if s[i] == ')' { + depth-- + if depth == 0 { + return i + } + } + } + return -1 +} + +func splitByCommaRespectParen(s string) []string { + var parts []string + depth := 0 + start := 0 + for i, c := range s { + if c == '(' { + depth++ + } else if c == ')' { + depth-- + } else if c == ',' && depth == 0 { + parts = append(parts, strings.TrimSpace(s[start:i])) + start = i + 1 + } + } + if start < len(s) { + parts = append(parts, strings.TrimSpace(s[start:])) + } + return parts +} + +func getDependencyOrder(schema *Schema) []string { + // Simple topological sort based on foreign key dependencies + var order []string + visited := make(map[string]bool) + + var visit func(name string) + visit = func(name string) { + if visited[name] { + return + } + visited[name] = true + + // Visit dependencies first + if table, exists := schema.Tables[name]; exists { + for _, fk := range table.ForeignKeys { + if _, depExists := schema.Tables[fk.RefTable]; depExists { + visit(fk.RefTable) + } + } + } + + order = append(order, name) + } + + for tableName := range schema.Tables { + visit(tableName) + } + + return order +} + +func parseInlineForeignKey(def string) *ForeignKey { + fk := &ForeignKey{} + + // Extract FOREIGN KEY (col1, col2) REFERENCES table(col1, col2) + fkIdx := strings.Index(strings.ToUpper(def), "FOREIGN KEY") + if fkIdx < 0 { + return nil + } + + refIdx := strings.Index(strings.ToUpper(def), "REFERENCES") + if refIdx < 0 { + return nil + } + + // Extract columns + fkPart := strings.TrimSpace(def[fkIdx+11 : refIdx]) + fkPart = strings.Trim(fkPart, "()") + fk.Columns = strings.Split(fkPart, ",") + for i := range fk.Columns { + fk.Columns[i] = strings.TrimSpace(fk.Columns[i]) + } + + // Extract referenced table and columns + refPart := strings.TrimSpace(def[refIdx+10:]) + parenStart := strings.Index(refPart, "(") + if parenStart > 0 { + fk.RefTable = strings.TrimSpace(refPart[:parenStart]) + // Find matching closing parenthesis + depth := 1 + parenEnd := -1 + for i := parenStart + 1; i < len(refPart); i++ { + if refPart[i] == '(' { + depth++ + } else if refPart[i] == ')' { + depth-- + if depth == 0 { + parenEnd = i + break + } + } + } + if parenEnd > parenStart { + colsStr := refPart[parenStart+1 : parenEnd] + fk.RefColumns = strings.Split(colsStr, ",") + for i := range fk.RefColumns { + fk.RefColumns[i] = strings.TrimSpace(fk.RefColumns[i]) + } + // Extract ON DELETE if present + if parenEnd+1 < len(refPart) { + onDeletePart := strings.ToUpper(strings.TrimSpace(refPart[parenEnd+1:])) + if strings.HasPrefix(onDeletePart, "ON DELETE") { + fields := strings.Fields(onDeletePart) + if len(fields) >= 3 { + fk.OnDelete = fields[2] + } + } + } + } + } + + return fk +} + +func parseCreateTable(schema *Schema, stmt string) { + // Extract table name + parts := strings.Fields(stmt) + if len(parts) < 3 { + return + } + + tableName := strings.Trim(parts[2], "()") + if tableName == "" { + return + } + + // Find the opening parenthesis + startIdx := strings.Index(stmt, "(") + endIdx := lastIndexOfMatchingParen(stmt, startIdx) + if startIdx < 0 || endIdx < 0 { + return + } + + tableDef := stmt[startIdx+1 : endIdx] + + table := &Table{ + Name: tableName, + Columns: make(map[string]*Column), + PrimaryKey: []string{}, + Indexes: make(map[string]*Index), + ForeignKeys: make(map[string]*ForeignKey), + } + + // Parse table definition + parseTableDefinition(table, tableDef) + + schema.Tables[tableName] = table +} + +func parseTableDefinition(table *Table, def string) { + // Split by commas, respecting parentheses + parts := splitByCommaRespectParen(def) + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + upperPart := strings.ToUpper(part) + + // PRIMARY KEY constraint + if strings.HasPrefix(upperPart, "PRIMARY KEY") { + pkDef := strings.TrimPrefix(part, "PRIMARY KEY") + pkDef = strings.TrimSpace(pkDef) + // Remove outer parentheses + pkDef = strings.TrimLeft(pkDef, "(") + pkDef = strings.TrimRight(pkDef, ")") + pkDef = strings.TrimSpace(pkDef) + pkCols := strings.Split(pkDef, ",") + for _, col := range pkCols { + table.PrimaryKey = append(table.PrimaryKey, strings.TrimSpace(col)) + } + continue + } + + // FOREIGN KEY constraint (inline) + if strings.Contains(upperPart, "FOREIGN KEY") { + fk := parseInlineForeignKey(part) + if fk != nil { + // Generate a name for the constraint + name := fmt.Sprintf("fk_%s_%s", table.Name, strings.Join(fk.Columns, "_")) + fk.Name = name + table.ForeignKeys[name] = fk + } + continue + } + + // Check for inline PRIMARY KEY (e.g., "id BIGSERIAL PRIMARY KEY") + if strings.Contains(upperPart, "PRIMARY KEY") && !strings.HasPrefix(upperPart, "PRIMARY KEY") { + // Extract column name from inline PRIMARY KEY definition + fields := strings.Fields(part) + if len(fields) >= 3 { + colName := fields[0] + table.PrimaryKey = append(table.PrimaryKey, colName) + // Remove PRIMARY KEY tokens for column parsing + var newFields []string + for i := 0; i < len(fields); i++ { + if i < len(fields)-1 && strings.ToUpper(fields[i]) == "PRIMARY" && strings.ToUpper(fields[i+1]) == "KEY" { + i++ + continue + } + newFields = append(newFields, fields[i]) + } + part = strings.Join(newFields, " ") + } + } + + // Column definition + col := parseColumn(part) + if col != nil { + table.Columns[col.Name] = col + } + } + + // PRIMARY KEY columns must be NOT NULL + for _, pkCol := range table.PrimaryKey { + if col, exists := table.Columns[pkCol]; exists { + col.Nullable = false + } + } +} + +func parseColumn(def string) *Column { + col := &Column{Nullable: true} + + // Split by space, respecting quotes and parentheses + parts := strings.Fields(def) + if len(parts) == 0 { + return nil + } + + col.Name = parts[0] + if len(parts) < 2 { + return col + } + + // Collect type - can be multiple words (e.g., DOUBLE PRECISION) + // Type continues until we hit a constraint keyword + var typeParts []string + typeParts = append(typeParts, parts[1]) + + i := 2 + for i < len(parts) { + p := strings.ToUpper(parts[i]) + // Stop at constraint keywords + if p == "NOT" || p == "NULL" || p == "DEFAULT" || p == "REFERENCES" || + p == "PRIMARY" || p == "FOREIGN" || p == "UNIQUE" || p == "CHECK" || p == "CONSTRAINT" { + break + } + typeParts = append(typeParts, parts[i]) + i++ + } + col.Type = strings.Join(typeParts, " ") + + // Check for NOT NULL and other constraints + for ; i < len(parts); i++ { + p := strings.ToUpper(parts[i]) + if p == "NOT" && i+1 < len(parts) && strings.ToUpper(parts[i+1]) == "NULL" { + col.Nullable = false + i++ + continue + } + if p == "NULL" { + col.Nullable = true + continue + } + if strings.HasPrefix(p, "DEFAULT") { + if i+1 < len(parts) { + col.Default = strings.Join(parts[i+1:], " ") + } + break + } + } + + return col +} + +func parseCreateIndex(schema *Schema, stmt string) { + isUnique := strings.Contains(strings.ToUpper(stmt), "UNIQUE") + + // Extract index name + parts := strings.Fields(stmt) + if len(parts) < 4 { + return + } + + idxName := parts[2] + if strings.ToUpper(parts[1]) == "UNIQUE" { + idxName = parts[3] + } + + // Find ON + onIdx := -1 + for i, p := range parts { + if strings.ToUpper(p) == "ON" { + onIdx = i + break + } + } + + if onIdx < 0 || onIdx+1 >= len(parts) { + return + } + + tableName := parts[onIdx+1] + + // Find columns inside parentheses after table name + var columns []string + // Find the opening parenthesis after the table name + parenStart := strings.Index(stmt, "(") + if parenStart >= 0 { + // Find the matching closing parenthesis + depth := 0 + parenEnd := -1 + for i := parenStart; i < len(stmt); i++ { + if stmt[i] == '(' { + depth++ + } else if stmt[i] == ')' { + depth-- + if depth == 0 { + parenEnd = i + break + } + } + } + if parenEnd > parenStart { + colsStr := stmt[parenStart+1 : parenEnd] + columns = strings.Split(colsStr, ",") + for i := range columns { + columns[i] = strings.TrimSpace(columns[i]) + } + } + } + + if _, exists := schema.Tables[tableName]; !exists { + schema.Tables[tableName] = &Table{ + Name: tableName, + Columns: make(map[string]*Column), + Indexes: make(map[string]*Index), + ForeignKeys: make(map[string]*ForeignKey), + } + } + + schema.Tables[tableName].Indexes[idxName] = &Index{ + Name: idxName, + Columns: columns, + Unique: isUnique, + } +} + +func parseAlterTableAddConstraint(schema *Schema, stmt string) { + // ALTER TABLE table ADD CONSTRAINT name FOREIGN KEY (col) REFERENCES table(col) + parts := strings.Fields(stmt) + if len(parts) < 6 { + return + } + + tableName := parts[2] + constraintName := parts[5] + + if _, exists := schema.Tables[tableName]; !exists { + schema.Tables[tableName] = &Table{ + Name: tableName, + Columns: make(map[string]*Column), + Indexes: make(map[string]*Index), + ForeignKeys: make(map[string]*ForeignKey), + } + } + + // Parse the constraint definition + // Find the constraint type (FOREIGN KEY, etc.) + constraintType := "" + for i, p := range parts { + if strings.ToUpper(p) == "FOREIGN" && i+1 < len(parts) && strings.ToUpper(parts[i+1]) == "KEY" { + constraintType = "FOREIGN KEY" + break + } + } + + if constraintType == "FOREIGN KEY" { + // Extract the full definition + def := strings.Join(parts[6:], " ") + fk := parseInlineForeignKey(def) + if fk != nil { + fk.Name = constraintName + schema.Tables[tableName].ForeignKeys[constraintName] = fk + } + } +} + +func parseSchemaFile(schema *Schema, sql string) { + // Split into statements + statements := strings.Split(sql, ";") + + for _, stmt := range statements { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + + upperStmt := strings.ToUpper(stmt) + + // CREATE TYPE + if strings.HasPrefix(upperStmt, "CREATE TYPE") { + parts := strings.Fields(stmt) + if len(parts) >= 3 { + typeName := parts[2] + // Store the full CREATE TYPE statement for later generation + schema.Types[typeName] = stmt + } + } + + // CREATE TABLE + if strings.HasPrefix(upperStmt, "CREATE TABLE") { + parseCreateTable(schema, stmt) + } + + // CREATE INDEX + if strings.HasPrefix(upperStmt, "CREATE") && strings.Contains(upperStmt, "INDEX") { + parseCreateIndex(schema, stmt) + } + + // ALTER TABLE ... ADD CONSTRAINT + if strings.HasPrefix(upperStmt, "ALTER TABLE") && strings.Contains(upperStmt, "ADD CONSTRAINT") { + parseAlterTableAddConstraint(schema, stmt) + } + } +} + +func loadSchemaFiles(dir string) ([]SchemaFile, error) { + files, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + var schemaFiles []SchemaFile + for _, f := range files { + if !f.IsDir() && strings.HasSuffix(f.Name(), ".sql") { + content, err := os.ReadFile(filepath.Join(dir, f.Name())) + if err != nil { + return nil, err + } + schemaFiles = append(schemaFiles, SchemaFile{ + Name: f.Name(), + SQL: string(content), + }) + } + } + + // Sort by filename (natural sort for numbered files) + sort.Slice(schemaFiles, func(i, j int) bool { + return schemaFiles[i].Name < schemaFiles[j].Name + }) + + return schemaFiles, nil +} + +func loadDesiredSchema(dir string) (*Schema, error) { + files, err := loadSchemaFiles(dir) + if err != nil { + return nil, err + } + + schema := &Schema{ + Types: make(map[string]string), + Tables: make(map[string]*Table), + } + + for _, f := range files { + parseSchemaFile(schema, f.SQL) + } + + return schema, nil +} + +func generateCreateTableSQL(table *Table) string { + var lines []string + + // Table header + lines = append(lines, fmt.Sprintf("CREATE TABLE %s (", table.Name)) + + // Columns - all get trailing commas + for colName, col := range table.Columns { + colSQL := fmt.Sprintf(" %s %s", colName, col.Type) + if !col.Nullable { + colSQL += " NOT NULL" + } + if col.Default != "" { + colSQL += fmt.Sprintf(" DEFAULT %s", col.Default) + } + lines = append(lines, colSQL+",") + } + + // Build list of constraints (PRIMARY KEY + UNIQUE constraints) + var constraints []string + if len(table.PrimaryKey) > 0 { + pkCols := make([]string, len(table.PrimaryKey)) + for i, col := range table.PrimaryKey { + pkCols[i] = strings.TrimSpace(strings.Trim(col, "()")) + } + constraints = append(constraints, fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(pkCols, ", "))) + } + for _, idx := range table.Indexes { + if idx.Unique && len(idx.Columns) > 0 { + constraints = append(constraints, fmt.Sprintf("UNIQUE (%s)", strings.Join(idx.Columns, ", "))) + } + } + + // Add constraints with commas between them + for i, constraint := range constraints { + if i < len(constraints)-1 { + lines = append(lines, fmt.Sprintf(" %s,", constraint)) + } else { + lines = append(lines, fmt.Sprintf(" %s", constraint)) + } + } + + lines = append(lines, ");") + + return strings.Join(lines, "\n") +} + +func generateAddColumnSQL(tableName, colName string, col *Column) string { + colSQL := fmt.Sprintf("%s %s", colName, col.Type) + if !col.Nullable { + colSQL += " NOT NULL" + } + if col.Default != "" { + colSQL += fmt.Sprintf(" DEFAULT %s", col.Default) + } + return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s;", tableName, colSQL) +} + +func generateAddForeignKeySQL(tableName string, fk *ForeignKey) string { + colStr := strings.Join(fk.Columns, ", ") + refStr := fmt.Sprintf("%s(%s)", fk.RefTable, strings.Join(fk.RefColumns, ", ")) + + constraint := fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s", colStr, refStr) + if fk.OnDelete != "" { + constraint += fmt.Sprintf(" ON DELETE %s", fk.OnDelete) + } + + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", tableName, fk.Name, constraint) +} + +func generateCreateIndexSQL(tableName string, idx *Index) string { + unique := "" + if idx.Unique { + unique = "UNIQUE " + } + cols := strings.Join(idx.Columns, ", ") + return fmt.Sprintf("CREATE %sINDEX %s ON %s (%s);", unique, idx.Name, tableName, cols) +} + +func CompareSchemas(current, desired *Schema) []string { + var migrations []string + + // Create missing types + for typeName, typeStmt := range desired.Types { + if _, exists := current.Types[typeName]; !exists { + migrations = append(migrations, typeStmt+";") + } + } + + // Process tables in dependency order + tableOrder := getDependencyOrder(desired) + + for _, tableName := range tableOrder { + desiredTable := desired.Tables[tableName] + currentTable, exists := current.Tables[tableName] + + if !exists { + // Create table with all constraints and unique indexes + migrations = append(migrations, generateCreateTableSQL(desiredTable)) + // Add foreign keys as separate ALTER TABLE statements (PostgreSQL doesn't support inline FK in CREATE TABLE with existing types) + for _, desiredFK := range desiredTable.ForeignKeys { + migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredFK)) + } + // Add non-unique indexes as separate CREATE INDEX statements + for _, desiredIdx := range desiredTable.Indexes { + if !desiredIdx.Unique { + migrations = append(migrations, generateCreateIndexSQL(tableName, desiredIdx)) + } + } + continue + } + + // Compare columns + for colName, desiredCol := range desiredTable.Columns { + currentCol, colExists := currentTable.Columns[colName] + + if !colExists { + migrations = append(migrations, generateAddColumnSQL(tableName, colName, desiredCol)) + } else if !columnsMatch(currentCol, desiredCol) { + // Generate only the specific changes needed + if currentCol.Nullable != desiredCol.Nullable { + if desiredCol.Nullable { + migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP NOT NULL;", tableName, colName)) + } else { + migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET NOT NULL;", tableName, colName)) + } + } + // Skip default comparison for serial types (implicit defaults) + if !isSerialType(currentCol.Type) && !isSerialType(desiredCol.Type) && currentCol.Default != desiredCol.Default { + if desiredCol.Default != "" { + migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", tableName, colName, desiredCol.Default)) + } else { + migrations = append(migrations, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", tableName, colName)) + } + } + // Type changes are skipped for simplicity + } + } + + // Compare foreign keys + for fkName, desiredFK := range desiredTable.ForeignKeys { + if _, exists := currentTable.ForeignKeys[fkName]; !exists { + migrations = append(migrations, generateAddForeignKeySQL(tableName, desiredFK)) + } + } + + // Compare indexes - only generate for non-unique (unique are handled in CREATE TABLE) + for idxName, desiredIdx := range desiredTable.Indexes { + if _, exists := currentTable.Indexes[idxName]; !exists { + if !desiredIdx.Unique { + migrations = append(migrations, generateCreateIndexSQL(tableName, desiredIdx)) + } + } + } + } + + return migrations +} + +func GetCurrentSchema(ctx context.Context, db *pgxpool.Pool) (*Schema, error) { + schema := &Schema{ + Types: make(map[string]string), + Tables: make(map[string]*Table), + } + + // Get custom types (enums) + rows, err := db.Query(ctx, ` + SELECT typname + FROM pg_type t + JOIN pg_namespace n ON n.oid = t.typnamespace + WHERE n.nspname = 'public' AND t.typtype = 'e' + `) + if err != nil { + return nil, err + } + for rows.Next() { + var typeName string + if err := rows.Scan(&typeName); err != nil { + rows.Close() + return nil, err + } + schema.Types[typeName] = "enum" + } + rows.Close() + + // Get tables + tableRows, err := db.Query(ctx, ` + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' AND table_type = 'BASE TABLE' + ORDER BY table_name + `) + if err != nil { + return nil, err + } + defer tableRows.Close() + + for tableRows.Next() { + var tableName string + if err := tableRows.Scan(&tableName); err != nil { + return nil, err + } + + table := &Table{ + Name: tableName, + Columns: make(map[string]*Column), + PrimaryKey: []string{}, + Indexes: make(map[string]*Index), + ForeignKeys: make(map[string]*ForeignKey), + } + + // Get columns + colRows, err := db.Query(ctx, ` + SELECT column_name, data_type, udt_name, is_nullable, column_default + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 + ORDER BY ordinal_position + `, tableName) + if err != nil { + return nil, err + } + for colRows.Next() { + var col Column + var dataType, udtName, nullableStr string + var defaultVal *string + if err := colRows.Scan(&col.Name, &dataType, &udtName, &nullableStr, &defaultVal); err != nil { + colRows.Close() + return nil, err + } + // For user-defined types (enums), use udt_name instead of data_type + if dataType == "USER-DEFINED" && udtName != "" { + col.Type = udtName + } else { + col.Type = dataType + } + col.Nullable = nullableStr == "YES" + if defaultVal != nil { + col.Default = strings.TrimSpace(*defaultVal) + } else { + col.Default = "" + } + table.Columns[col.Name] = &col + } + colRows.Close() + + // Get primary key + pkRows, err := db.Query(ctx, ` + SELECT a.attname + FROM pg_index ix + JOIN pg_attribute a ON a.attrelid = ix.indrelid AND a.attnum = ANY(ix.indkey) + JOIN pg_class c ON c.oid = ix.indrelid + WHERE c.relname = $1 AND ix.indisprimary + ORDER BY a.attnum + `, tableName) + if err != nil { + return nil, err + } + for pkRows.Next() { + var colName string + if err := pkRows.Scan(&colName); err != nil { + pkRows.Close() + return nil, err + } + table.PrimaryKey = append(table.PrimaryKey, colName) + } + pkRows.Close() + + // Get indexes + idxRows, err := db.Query(ctx, ` + SELECT i.relname, ix.indisunique, pg_get_indexdef(ix.indexrelid) + FROM pg_index ix + JOIN pg_class i ON i.oid = ix.indexrelid + JOIN pg_class t ON t.oid = ix.indrelid + WHERE t.relname = $1 AND i.relkind = 'i' AND ix.indisprimary = false + `, tableName) + if err != nil { + return nil, err + } + for idxRows.Next() { + var idxName string + var isUnique bool + var indexDef string + if err := idxRows.Scan(&idxName, &isUnique, &indexDef); err != nil { + idxRows.Close() + return nil, err + } + // Parse columns from index definition + columns := parseIndexColumns(indexDef) + + table.Indexes[idxName] = &Index{ + Name: idxName, + Columns: columns, + Unique: isUnique, + } + } + idxRows.Close() + + // Get foreign keys + fkRows, err := db.Query(ctx, ` + SELECT conname, pg_get_constraintdef(c.oid) + FROM pg_constraint c + JOIN pg_class cl ON cl.oid = c.conrelid + WHERE cl.relname = $1 AND c.contype = 'f' + `, tableName) + if err != nil { + return nil, err + } + for fkRows.Next() { + var fkName, def string + if err := fkRows.Scan(&fkName, &def); err != nil { + fkRows.Close() + return nil, err + } + // Parse constraint definition (simplified) + fk := parseForeignKey(def) + fk.Name = fkName + table.ForeignKeys[fkName] = fk + } + fkRows.Close() + + schema.Tables[tableName] = table + } + + return schema, nil +} + +func parseIndexColumns(def string) []string { + // Example: "CREATE UNIQUE INDEX idx_name ON table_name (col1, col2)" + // Extract the part between parentheses + startIdx := strings.Index(def, "(") + endIdx := strings.LastIndex(def, ")") + if startIdx < 0 || endIdx < 0 { + return nil + } + colsStr := def[startIdx+1 : endIdx] + cols := strings.Split(colsStr, ",") + for i := range cols { + cols[i] = strings.TrimSpace(cols[i]) + } + return cols +} + +func parseForeignKey(def string) *ForeignKey { + fk := &ForeignKey{} + + // Extract columns + if strings.Contains(def, "FOREIGN KEY") { + parts := strings.Split(def, "FOREIGN KEY") + if len(parts) >= 2 { + colPart := strings.TrimSpace(parts[1]) + // Get columns before REFERENCES + refIdx := strings.Index(colPart, "REFERENCES") + if refIdx >= 0 { + colDef := strings.TrimSpace(colPart[:refIdx]) + colDef = strings.Trim(colDef, "()") + fk.Columns = strings.Split(colDef, ",") + for i := range fk.Columns { + fk.Columns[i] = strings.TrimSpace(fk.Columns[i]) + } + + // Get referenced table and columns + refPart := strings.TrimSpace(colPart[refIdx+10:]) + refParts := strings.Fields(refPart) + if len(refParts) >= 1 { + fk.RefTable = strings.Trim(refParts[0], "()") + if len(refParts) >= 2 { + fk.RefColumns = strings.Split(strings.Trim(refParts[1], "()"), ",") + for i := range fk.RefColumns { + fk.RefColumns[i] = strings.TrimSpace(fk.RefColumns[i]) + } + } + } + + // Extract ON DELETE + if strings.Contains(def, "ON DELETE") { + onDeleteParts := strings.Split(def, "ON DELETE") + if len(onDeleteParts) >= 2 { + fk.OnDelete = strings.TrimSpace(strings.Fields(onDeleteParts[1])[0]) + } + } + } + } + } + + return fk +} + +func LoadDesiredSchema(dir string) (*Schema, error) { + return loadDesiredSchema(dir) +} + +func GetMigrations(ctx context.Context, db *pgxpool.Pool, schemaDir string) ([]string, error) { + // Load current database schema + currentSchema, err := GetCurrentSchema(ctx, db) + if err != nil { + return nil, fmt.Errorf("failed to get current schema: %w", err) + } + + // Load desired schema from files + desiredSchema, err := LoadDesiredSchema(schemaDir) + if err != nil { + return nil, fmt.Errorf("failed to load desired schema: %w", err) + } + + // Compare and generate migration SQL + return CompareSchemas(currentSchema, desiredSchema), nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..274c0c3 --- /dev/null +++ b/main.go @@ -0,0 +1,121 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + + "git.petrovv.com/go-migrate/lib" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/joho/godotenv" +) + +func main() { + // Parse flags + schemaDir := flag.String("dir", "schema", "Directory containing desired schema SQL files") + cmd := flag.String("cmd", "", "Command: generate or apply") + flag.Parse() + + if *cmd == "" { + fmt.Println("Usage:") + fmt.Println(" go run main.go -cmd=generate -dir=schema - Compare schema and generate migrations") + fmt.Println(" go run main.go -cmd=apply -dir=schema - Apply pending migrations") + fmt.Println("\nFlags:") + flag.PrintDefaults() + return + } + + switch *cmd { + case "generate": + generateMigrations(*schemaDir) + case "apply": + applyMigrations(*schemaDir) + default: + fmt.Println("Unknown command:", *cmd) + fmt.Println("Use 'generate' or 'apply'") + } +} + +func generateMigrations(schemaDir string) { + _ = godotenv.Load() + db := getDBConnection() + defer db.Close() + ctx := context.Background() + + migrations, err := lib.GetMigrations(ctx, db, schemaDir) + if err != nil { + log.Fatal(err) + } + + if len(migrations) == 0 { + fmt.Println("✓ Database schema is up to date") + return + } + + fmt.Println("-- Migration SQL") + for _, m := range migrations { + fmt.Println(m) + } +} + +func applyMigrations(schemaDir string) { + _ = godotenv.Load() + db := getDBConnection() + defer db.Close() + ctx := context.Background() + + migrations, err := lib.GetMigrations(ctx, db, schemaDir) + if err != nil { + log.Fatal(err) + } + + if len(migrations) == 0 { + fmt.Println("✓ Database schema is up to date") + return + } + + fmt.Println("Applying migrations...") + for _, m := range migrations { + fmt.Printf("Executing: %s\n", m) + _, err := db.Exec(ctx, m) + if err != nil { + log.Fatalf("failed to execute %s: %v", m, err) + } + } + + fmt.Println("✓ All migrations applied successfully") +} + +func getDBConnection() *pgxpool.Pool { + dbUser := os.Getenv("DB_USER") + dbPass := os.Getenv("DB_PASS") + dbHost := os.Getenv("DB_HOST") + dbPort := os.Getenv("DB_PORT") + dbName := os.Getenv("DB_NAME") + + if dbUser == "" { + dbUser = "hengspot_user" + } + if dbPass == "" { + dbPass = "securepassword" + } + if dbHost == "" { + dbHost = "localhost" + } + if dbPort == "" { + dbPort = "5432" + } + if dbName == "" { + dbName = "hengspot_db" + } + + dsn := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dbUser, dbPass, dbHost, dbPort, dbName) + + pool, err := pgxpool.New(context.Background(), dsn) + if err != nil { + log.Fatalf("Unable to connect to database: %v", err) + } + return pool +}