Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow conflict targets in ON CONFLICT DO NOTHING #422

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/inserting.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* [Insert Structs](#insert-structs)
* [Insert Map](#insert-map)
* [Insert From Query](#insert-from-query)
* [On Conflict](#onconflict)
* [Returning](#returning)
* [SetError](#seterror)
* [Executing](#executing)
Expand Down Expand Up @@ -385,6 +386,41 @@ Output:
INSERT INTO "user" ("first_name", "last_name") SELECT "fn", "ln" FROM "other_table" []
```

<a name="onconflict"></a>
**On Conflict Clause**

You can handle conflicts using the `OnConflict` clause. For example, to ignore conflicts, you can use `DoNothing()`.

```go
sql, _, _ := goqu.Insert("test").
Rows(goqu.Record{"a": "a", "b": "b"}).
OnConflict(goqu.DoNothing()).
ToSQL()
fmt.Println(sql)
```

Output:

```
INSERT INTO "test" ("a", "b") VALUES ('a', 'b') ON CONFLICT DO NOTHING
```

To specify columns to be used in the conflict handling

```go
sql, _, _ := goqu.Insert("test").
Rows(goqu.Record{"a": "a", "b": "b"}).
OnConflict(goqu.DoNothing().SetCols(exp.NewColumnListExpression("a", "b"))).
ToSQL()
fmt.Println(sql)
```

Output:

```
INSERT INTO "test" ("a", "b") VALUES ('a', 'b') ON CONFLICT ("a", "b") DO NOTHING
```

<a name="returning"></a>
**Returning Clause**

Expand Down
15 changes: 13 additions & 2 deletions exp/conflict.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package exp

type (
doNothingConflict struct{}
doNothingConflict struct {
cols ColumnListExpression
}
// ConflictUpdate is the struct that represents the UPDATE fragment of an
// INSERT ... ON CONFLICT/ON DUPLICATE KEY DO UPDATE statement
conflictUpdate struct {
Expand All @@ -14,7 +16,7 @@ type (
// Creates a conflict struct to be passed to InsertConflict to ignore constraint errors
//
// InsertConflict(DoNothing(),...) -> INSERT INTO ... ON CONFLICT DO NOTHING
func NewDoNothingConflictExpression() ConflictExpression {
func NewDoNothingConflictExpression() ConflictNothingExpression {
return &doNothingConflict{}
}

Expand All @@ -30,6 +32,15 @@ func (c doNothingConflict) Action() ConflictAction {
return DoNothingConflictAction
}

func (c *doNothingConflict) SetCols(cl ColumnListExpression) ConflictNothingExpression {
c.cols = cl
return c
}

func (c doNothingConflict) Cols() ColumnListExpression {
return c.cols
}

// Creates a ConflictUpdate struct to be passed to InsertConflict
// Represents a ON CONFLICT DO UPDATE portion of an INSERT statement (ON DUPLICATE KEY UPDATE for mysql)
//
Expand Down
5 changes: 5 additions & 0 deletions exp/exp.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ type (
Expression
Action() ConflictAction
}
ConflictNothingExpression interface {
ConflictExpression
SetCols(cl ColumnListExpression) ConflictNothingExpression
Cols() ColumnListExpression
}
ConflictUpdateExpression interface {
ConflictExpression
TargetColumn() string
Expand Down
2 changes: 1 addition & 1 deletion expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func Cast(e exp.Expression, t string) exp.CastExpression {
// Creates a conflict struct to be passed to InsertConflict to ignore constraint errors
//
// InsertConflict(DoNothing(),...) -> INSERT INTO ... ON CONFLICT DO NOTHING
func DoNothing() exp.ConflictExpression {
func DoNothing() exp.ConflictNothingExpression {
return exp.NewDoNothingConflictExpression()
}

Expand Down
4 changes: 4 additions & 0 deletions insert_dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ func (ids *insertDatasetSuite) TestOnConflict() {
ds: bd.OnConflict(goqu.DoNothing()),
clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(goqu.DoNothing()),
},
insertTestCase{
ds: bd.OnConflict(goqu.DoNothing().SetCols(exp.NewColumnListExpression("items"))),
clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(goqu.DoNothing().SetCols(exp.NewColumnListExpression("items"))),
},
insertTestCase{
ds: bd.OnConflict(du),
clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(du),
Expand Down
18 changes: 18 additions & 0 deletions sqlgen/common_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func ErrReturnNotSupported(dialect string) error {
return errors.New("dialect does not support RETURNING clause [dialect=%s]", dialect)
}

func ErrConflictTargetNotSupported(dialect string) error {
return errors.New("dialect does not support Conflict Target clause [dialect=%s]", dialect)
}

func ErrNotSupportedFragment(sqlType string, f SQLFragmentType) error {
return errors.New("unsupported %s SQL fragment %s", sqlType, f)
}
Expand All @@ -30,6 +34,7 @@ type (
DialectOptions() *SQLDialectOptions
ExpressionSQLGenerator() ExpressionSQLGenerator
ReturningSQL(b sb.SQLBuilder, returns exp.ColumnListExpression)
ConflictTargetSQL(b sb.SQLBuilder, returns exp.ColumnListExpression)
FromSQL(b sb.SQLBuilder, from exp.ColumnListExpression)
SourcesSQL(b sb.SQLBuilder, from exp.ColumnListExpression)
WhereSQL(b sb.SQLBuilder, where exp.ExpressionList)
Expand Down Expand Up @@ -72,6 +77,19 @@ func (csg *commonSQLGenerator) ReturningSQL(b sb.SQLBuilder, returns exp.ColumnL
}
}

func (csg *commonSQLGenerator) ConflictTargetSQL(b sb.SQLBuilder, targets exp.ColumnListExpression) {
if targets != nil && len(targets.Columns()) > 0 {
if csg.dialectOptions.SupportsConflictTarget {
b.WriteRunes(csg.dialectOptions.SpaceRune)
b.WriteRunes(csg.dialectOptions.LeftParenRune)
csg.esg.Generate(b, targets)
b.WriteRunes(csg.dialectOptions.RightParenRune)
} else {
b.SetError(ErrConflictTargetNotSupported(csg.dialect))
}
}
}

// Adds the FROM clause and tables to an sql statement
func (csg *commonSQLGenerator) FromSQL(b sb.SQLBuilder, from exp.ColumnListExpression) {
if from != nil && !from.IsEmpty() {
Expand Down
42 changes: 42 additions & 0 deletions sqlgen/common_sql_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,48 @@ func (csgs *commonSQLGeneratorSuite) TestReturningSQL() {
)
}

func (csgs *commonSQLGeneratorSuite) TestConflictTargetSQL() {
conflictTargetGen := func(csgs sqlgen.CommonSQLGenerator) func(sb.SQLBuilder) {
return func(sb sb.SQLBuilder) {
csgs.ConflictTargetSQL(sb, exp.NewColumnListExpression("a", "b"))
}
}

conflictTargetNoColsGen := func(csgs sqlgen.CommonSQLGenerator) func(sb.SQLBuilder) {
return func(sb sb.SQLBuilder) {
csgs.ConflictTargetSQL(sb, exp.NewColumnListExpression())
}
}

conflictTargetNilExpGen := func(csgs sqlgen.CommonSQLGenerator) func(sb.SQLBuilder) {
return func(sb sb.SQLBuilder) {
csgs.ConflictTargetSQL(sb, nil)
}
}

opts := sqlgen.DefaultDialectOptions()
opts.SupportsConflictTarget = true
csgs1 := sqlgen.NewCommonSQLGenerator("test", opts)

opts2 := sqlgen.DefaultDialectOptions()
opts2.SupportsConflictTarget = false
csgs2 := sqlgen.NewCommonSQLGenerator("test", opts2)

csgs.assertCases(
commonSQLTestCase{gen: conflictTargetGen(csgs1), sql: ` ("a", "b")`},
commonSQLTestCase{gen: conflictTargetGen(csgs1), sql: ` ("a", "b")`, isPrepared: true, args: emptyArgs},

commonSQLTestCase{gen: conflictTargetNoColsGen(csgs1), sql: ``},
commonSQLTestCase{gen: conflictTargetNoColsGen(csgs1), sql: ``, isPrepared: true, args: emptyArgs},

commonSQLTestCase{gen: conflictTargetNilExpGen(csgs1), sql: ``},
commonSQLTestCase{gen: conflictTargetNilExpGen(csgs1), sql: ``, isPrepared: true, args: emptyArgs},

commonSQLTestCase{gen: conflictTargetGen(csgs2), err: `goqu: dialect does not support Conflict Target clause [dialect=test]`},
commonSQLTestCase{gen: conflictTargetGen(csgs2), err: `goqu: dialect does not support Conflict Target clause [dialect=test]`},
)
}

func (csgs *commonSQLGeneratorSuite) TestFromSQL() {
fromGen := func(csgs sqlgen.CommonSQLGenerator) func(sb.SQLBuilder) {
return func(sb sb.SQLBuilder) {
Expand Down
6 changes: 6 additions & 0 deletions sqlgen/insert_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ func (isg *insertSQLGenerator) onConflictSQL(b sb.SQLBuilder, o exp.ConflictExpr
}
}
isg.onConflictDoUpdateSQL(b, t)
case exp.ConflictNothingExpression:
cols := t.Cols()
if isg.DialectOptions().SupportsConflictTarget && cols != nil {
isg.ConflictTargetSQL(b, cols)
}
b.Write(isg.DialectOptions().ConflictDoNothingFragment)
default:
b.Write(isg.DialectOptions().ConflictDoNothingFragment)
}
Expand Down