Skip to content

Commit

Permalink
Allow adding a column to a table created in the same migration (xatai…
Browse files Browse the repository at this point in the history
…o#449)

Allow the `add_column` operation to add a column to a table that was
created by an operation earlier in the same migration.

The following migration would previously have failed to start:

```json
{
  "name": "43_multiple_ops",
  "operations": [
    {
      "create_table": {
        "name": "players",
        "columns": [
          {
            "name": "id",
            "type": "serial",
            "pk": true
          },
          {
            "name": "name",
            "type": "varchar(255)",
            "check": {
              "name": "name_length_check",
              "constraint": "length(name) > 2"
            }
          }
        ]
      }
    },
    {
      "add_column": {
        "table": "players",
        "column": {
          "name": "rating",
          "type": "integer",
          "comment": "hello world",
          "check": {
            "name": "rating_check",
            "constraint": "rating > 0 AND rating < 100"
          },
          "nullable": false
        }
      }
    }
  ]
}
```

As of this PR, the migration can be started.

The above migration does not validate yet, but it can be started
successfully with the `--skip-validation` flag to the `start` command.

Part of xataio#239
  • Loading branch information
andrew-farries authored and kvch committed Nov 11, 2024
1 parent 1d7c900 commit 8fee548
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string
}

if o.Column.Comment != nil {
if err := addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column.Name), o.Column.Comment); err != nil {
if err := addCommentToColumn(ctx, conn, table.Name, TemporaryName(o.Column.Name), o.Column.Comment); err != nil {
return nil, fmt.Errorf("failed to add comment to column: %w", err)
}
}

if !o.Column.IsNullable() && o.Column.Default == nil {
if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil {
if err := addNotNullConstraint(ctx, conn, table.Name, o.Column.Name, TemporaryName(o.Column.Name)); err != nil {
return nil, fmt.Errorf("failed to add not null constraint: %w", err)
}
}

if o.Column.Check != nil {
if err := o.addCheckConstraint(ctx, conn); err != nil {
if err := o.addCheckConstraint(ctx, table.Name, conn); err != nil {
return nil, fmt.Errorf("failed to add check constraint: %w", err)
}
}
Expand Down Expand Up @@ -231,9 +231,9 @@ func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physic
return err
}

func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error {
func (o *OpAddColumn) addCheckConstraint(ctx context.Context, tableName string, conn db.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(tableName),
pq.QuoteIdentifier(o.Column.Check.Name),
rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name),
))
Expand Down
68 changes: 68 additions & 0 deletions pkg/migrations/op_add_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,74 @@ func TestAddColumnDefaultTransformation(t *testing.T) {
}, roll.WithSQLTransformer(sqlTransformer))
}

func TestAddColumnToATableCreatedInTheSameMigration(t *testing.T) {
t.Parallel()

ExecuteTests(t, TestCases{
{
name: "add column to newly created table",
migrations: []migrations.Migration{
{
Name: "01_add_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "varchar(255)",
},
},
},
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "age",
Type: "integer",
Nullable: ptr(false),
Check: &migrations.CheckConstraint{
Name: "age_check",
Constraint: "age >= 18",
},
Comment: ptr("the age of the user"),
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Inserting into the new column on the new table works.
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "Alice", "age": "30",
})

// Inserting a value that doesn't meet the check constraint fails.
MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "Bob", "age": "8",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Inserting into the new column on the new table works.
MustInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "Bob", "age": "31",
})

// Inserting a value that doesn't meet the check constraint fails.
MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{
"name": "Carl", "age": "8",
}, testutils.CheckViolationErrorCode)
},
},
}, roll.WithSkipValidation(true)) // TODO: remove once this migration can be validated
}

func TestAddColumnInvalidNameLength(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 8fee548

Please sign in to comment.