Add support for composite keys in `create table` statements by kvch · Pull Request #413 · xataio/pgroll · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add support for composite keys in create table statements #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/01_create_tables.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@
}
]
}
},
{
"create_table": {
"name": "sellers",
"columns": [
{
"name": "name",
"type": "varchar(255)",
"pk": true
},
{
"name": "zip",
"type": "integer",
"pk": true
},
{
"name": "description",
"type": "varchar(255)",
"nullable": true
}
]
}
}
]
}
51 changes: 50 additions & 1 deletion pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
o.Column.Check = nil

o.Column.Name = TemporaryName(o.Column.Name)
colSQL, err := ColumnToSQL(o.Column, tr)
columnWriter := ColumnSQLWriter{WithPK: true, Transformer: tr}
colSQL, err := columnWriter.Write(o.Column)
if err != nil {
return err
}
Expand Down Expand Up @@ -243,3 +244,51 @@ func NotNullConstraintName(columnName string) string {
// IsNotNullConstraintName reports whether name begins with the
// "_pgroll_check_not_null_" prefix that pgroll uses to name its NOT NULL
// check constraints.
func IsNotNullConstraintName(name string) bool {
	const notNullConstraintPrefix = "_pgroll_check_not_null_"
	return strings.HasPrefix(name, notNullConstraintPrefix)
}

// ColumnSQLWriter writes a column to SQL
// It can optionally include the primary key constraint
// When creating a table, the primary key constraint is not added to the column definition
type ColumnSQLWriter struct {
WithPK bool
Transformer SQLTransformer
}

func (w ColumnSQLWriter) Write(col Column) (string, error) {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)

if w.WithPK && col.IsPrimaryKey() {
sql += " PRIMARY KEY"
}

if col.IsUnique() {
sql += " UNIQUE"
}
if !col.IsNullable() {
sql += " NOT NULL"
}
if col.Default != nil {
d, err := w.Transformer.TransformSQL(*col.Default)
if err != nil {
return "", err
}
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
if col.References.OnDelete != "" {
>
}

sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
pq.QuoteIdentifier(col.References.Name),
pq.QuoteIdentifier(col.References.Table),
pq.QuoteIdentifier(col.References.Column),
onDelete)
}
if col.Check != nil {
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
pq.QuoteIdentifier(col.Check.Name),
col.Check.Constraint)
}
return sql, nil
}
29 changes: 29 additions & 0 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ func ColumnMustNotHaveComment(t *testing.T, db *sql.DB, schema, table, column st
}
}

// ColumnMustBePK fails the test immediately unless column belongs to the
// primary key of schema.table.
func ColumnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) {
	t.Helper()

	isPK := columnMustBePK(t, db, schema, table, column)
	if isPK {
		return
	}
	t.Fatalf("Expected column %q to be primary key", column)
}

func TableMustHaveComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) {
t.Helper()
if !tableHasComment(t, db, schema, table, expectedComment) {
Expand Down Expand Up @@ -526,6 +533,28 @@ func columnHasComment(t *testing.T, db *sql.DB, schema, table, column string, ex
return actualComment != nil && *expectedComment == *actualComment
}

// columnMustBePK reports whether column is one of the columns of the primary
// key index of schema.table. This includes any member of a composite key.
// A query error fails the test immediately.
func columnMustBePK(t *testing.T, db *sql.DB, schema, table, column string) bool {
	t.Helper()

	// Quote each part before the ::regclass cast. Otherwise Postgres
	// case-folds the text, so mixed-case names (or names containing dots)
	// would resolve to the wrong relation or fail to resolve.
	qualifiedTable := pq.QuoteIdentifier(schema) + "." + pq.QuoteIdentifier(table)

	var exists bool
	err := db.QueryRow(fmt.Sprintf(`
    SELECT EXISTS (
      SELECT a.attname
      FROM pg_index i
      JOIN pg_attribute a ON a.attrelid = i.indrelid
        AND a.attnum = ANY(i.indkey)
      WHERE i.indrelid = %[1]s::regclass AND i.indisprimary AND a.attname = %[2]s
    )`,
		pq.QuoteLiteral(qualifiedTable),
		pq.QuoteLiteral(column)),
	).Scan(&exists)
	if err != nil {
		t.Fatal(err)
	}

	return exists
}

func tableHasComment(t *testing.T, db *sql.DB, schema, table, expectedComment string) bool {
t.Helper()

Expand Down
44 changes: 7 additions & 37 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,55 +113,25 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {

func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) {
var sql string
var primaryKeys []string
columnWriter := ColumnSQLWriter{WithPK: false, Transformer: tr}
for i, col := range cols {
if i > 0 {
sql += ", "
}
colSQL, err := ColumnToSQL(col, tr)
colSQL, err := columnWriter.Write(col)
if err != nil {
return "", err
}
sql += colSQL
}
return sql, nil
}

// ColumnToSQL generates the SQL for a column definition.
func ColumnToSQL(col Column, tr SQLTransformer) (string, error) {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)

if col.IsPrimaryKey() {
sql += " PRIMARY KEY"
}
if col.IsUnique() {
sql += " UNIQUE"
}
if !col.IsNullable() {
sql += " NOT NULL"
}
if col.Default != nil {
d, err := tr.TransformSQL(*col.Default)
if err != nil {
return "", err
if col.IsPrimaryKey() {
primaryKeys = append(primaryKeys, pq.QuoteIdentifier(col.Name))
}
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
if col.References.OnDelete != "" {
>
}

sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s) ON DELETE %s",
pq.QuoteIdentifier(col.References.Name),
pq.QuoteIdentifier(col.References.Table),
pq.QuoteIdentifier(col.References.Column),
onDelete)
}
if col.Check != nil {
sql += fmt.Sprintf(" CONSTRAINT %s CHECK (%s)",
pq.QuoteIdentifier(col.Check.Name),
col.Check.Constraint)
if len(primaryKeys) > 0 {
sql += fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))
}
return sql, nil
}
83 changes: 83 additions & 0 deletions pkg/migrations/op_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,89 @@ func TestCreateTable(t *testing.T) {
}, rows)
},
},
{
name: "create table with composite key",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "rand",
Type: "varchar(255)",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
Unique: ptr(true),
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The new view exists in the new version schema.
ViewMustExist(t, db, schema, "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"rand": "123",
"name": "Alice",
})
// New record with same keys cannot be inserted.
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
"rand": "123",
"name": "Malice",
}, testutils.UniqueViolationErrorCode)

// Data can be retrieved from the new view.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "rand": "123", "name": "Alice"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The underlying table has been dropped.
TableMustNotExist(t, db, schema, "users")
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// The view still exists
ViewMustExist(t, db, schema, "01_create_table", "users")

// The columns are still primary keys.
ColumnMustBePK(t, db, schema, "users", "id")
ColumnMustBePK(t, db, schema, "users", "rand")

// Data can be inserted into the new view.
MustInsert(t, db, schema, "01_create_table", "users", map[string]string{
"rand": "123",
"name": "Alice",
})

// New record with same keys cannot be inserted.
MustNotInsert(t, db, schema, "01_create_table", "users", map[string]string{
"id": "1",
"rand": "123",
"name": "Malice",
}, testutils.UniqueViolationErrorCode)

// Data can be retrieved from the new view.
rows := MustSelect(t, db, schema, "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "rand": "123", "name": "Alice"},
}, rows)
},
},
{
name: "create table with foreign key with default ON DELETE NO ACTION",
migrations: []migrations.Migration{
Expand Down
0