dialect/sql: use identifier qualifiers for WHERE clause on upsert by a8m · Pull Request #2131 · ent/ent · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

dialect/sql: use identifier qualifiers for WHERE clause on upsert #2131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions dialect/sql/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ func (i *InsertBuilder) writeConflict() {
}
u.update.writeSetter(&i.Builder)
if p := i.conflict.action.where; p != nil {
p.qualifier = i.table
i.WriteString(" WHERE ").Join(p)
}
}
Expand Down Expand Up @@ -2825,11 +2826,12 @@ func (n Queries) Query() (string, []interface{}) {

// Builder is the base query builder for the sql dsl.
type Builder struct {
sb *strings.Builder // underlying builder.
dialect string // configured dialect.
args []interface{} // query parameters.
total int // total number of parameters in query tree.
errs []error // errors that added during the query construction.
sb *strings.Builder // underlying builder.
dialect string // configured dialect.
args []interface{} // query parameters.
total int // total number of parameters in query tree.
errs []error // errors that added during the query construction.
qualifier string // qualifier to prefix identifiers (e.g. table name).
}

// Quote quotes the given identifier with the characters based
Expand All @@ -2856,6 +2858,9 @@ func (b *Builder) Ident(s string) *Builder {
switch {
case len(s) == 0:
case s != "*" && !b.isIdent(s) && !isFunc(s) && !isModifier(s):
if b.qualifier != "" {
b.WriteString(b.Quote(b.qualifier)).WriteByte('.')
}
b.WriteString(b.Quote(s))
case (isFunc(s) || isModifier(s)) && b.postgres():
// Modifiers and aggregation functions that
Expand Down
2 changes: 1 addition & 1 deletion dialect/sql/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1735,7 +1735,7 @@ func TestInsert_OnConflict(t *testing.T) {
UpdateWhere(NEQ("updated_at", 0)),
).
Query()
require.Equal(t, `INSERT INTO "users" ("id", "email", "creation_time") VALUES ($1, $2, $3) ON CONFLICT ("email") WHERE "name" = $4 DO UPDATE SET "id" = "users"."id", "email" = "excluded"."email", "creation_time" = "users"."creation_time", "version" = COALESCE("users"."version", 0) + $5 WHERE "updated_at" <> $6`, query)
require.Equal(t, `INSERT INTO "users" ("id", "email", "creation_time") VALUES ($1, $2, $3) ON CONFLICT ("email") WHERE "name" = $4 DO UPDATE SET "id" = "users"."id", "email" = "excluded"."email", "creation_time" = "users"."creation_time", "version" = COALESCE("users"."version", 0) + $5 WHERE "users"."updated_at" <> $6`, query)
require.Equal(t, []interface{}{"1", "user@example.com", 1633279231, "Ariel", 1, 0}, args)

query, args = Dialect(dialect.Postgres).
Expand Down
46 changes: 38 additions & 8 deletions entc/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package integration

import (
"context"
stdsql "database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -303,7 +304,9 @@ func Upsert(t *testing.T, client *ent.Client) {
SetName("Mashraki").
SetAge(30).
SetPhone("0000").
OnConflict(sql.ConflictColumns(user.FieldPhone)).
OnConflict(
sql.ConflictColumns(user.FieldPhone),
).
// Update "name" to the value that was set on create ("Mashraki").
UpdateName().
ExecX(ctx)
Expand Down Expand Up @@ -397,16 +400,43 @@ func Upsert(t *testing.T, client *ent.Client) {
require.Equal(t, bid, client.Item.Query().OnlyIDX(ctx))
}

ts := time.Unix(1623279251, 0)
c1 := client.Card.Create().
SetNumber("102030").
SetCreateTime(time.Unix(1623279251, 0)).
SetUpdateTime(time.Unix(1623279251, 0)).
SetCreateTime(ts).
SetUpdateTime(ts).
SaveX(ctx)
id = client.Card.Create().
SetNumber(c1.Number).
OnConflictColumns(card.FieldNumber).
UpdateNewValues().
IDX(ctx)

// "DO UPDATE SET ... WHERE ..." does not support by MySQL.
if strings.Contains(t.Name(), "Postgres") || strings.Contains(t.Name(), "SQLite") {
err = client.Card.Create().
SetNumber(c1.Number).
OnConflict(
sql.ConflictColumns(card.FieldNumber),
sql.UpdateWhere(sql.NEQ(card.FieldCreateTime, ts)),
).
UpdateNewValues().
Exec(ctx)
// Only rows for which the "UpdateWhere" expression
// returns true will be updated. That is, none.
require.True(t, errors.Is(err, stdsql.ErrNoRows))

id = client.Card.Create().
SetNumber(c1.Number).
OnConflict(
sql.ConflictColumns(card.FieldNumber),
sql.UpdateWhere(sql.EQ(card.FieldCreateTime, ts)),
).
UpdateNewValues().
IDX(ctx)
} else {
id = client.Card.Create().
SetNumber(c1.Number).
OnConflictColumns(card.FieldNumber).
UpdateNewValues().
IDX(ctx)
}

c2 := client.Card.GetX(ctx, id)
require.Equal(t, c1.CreateTime.Unix(), c2.CreateTime.Unix())
require.NotEqual(t, c1.UpdateTime.Unix(), c2.UpdateTime.Unix())
Expand Down
0