From 496df2cad0366d6cb8660d4fd59af00ca2ef592e Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sat, 25 Dec 2021 14:49:53 +0200 Subject: [PATCH] dialect/sql: support passing selectors to basic predicates Closed https://github.com/ent/ent/issues/2236 --- dialect/sql/builder.go | 24 ++++++++++++++++++------ dialect/sql/builder_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/dialect/sql/builder.go b/dialect/sql/builder.go index 68667e7ddf..d1e659d07e 100644 --- a/dialect/sql/builder.go +++ b/dialect/sql/builder.go @@ -1326,7 +1326,7 @@ func (p *Predicate) EQ(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpEQ) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1350,7 +1350,7 @@ func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpNEQ) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1374,7 +1374,7 @@ func (p *Predicate) LT(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLT) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1398,7 +1398,7 @@ func (p *Predicate) LTE(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLTE) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1422,7 +1422,7 @@ func (p *Predicate) GT(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGT) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1446,7 +1446,7 @@ func (p *Predicate) GTE(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGTE) - b.Arg(arg) + p.arg(b, arg) }) } @@ -1767,6 +1767,18 @@ func (p *Predicate) Query() (string, []interface{}) { return p.String(), p.args } +// arg calls Builder.Arg, but wraps `a` with parens in case of a Selector. 
+func (*Predicate) arg(b *Builder, a interface{}) { + if _, isSelector := a.(*Selector); !isSelector { + b.Arg(a) + return + } + // Sub-selects used as operands must be wrapped in parens to form valid SQL. + b.Nested(func(nb *Builder) { + nb.Arg(a) + }) +} + // clone returns a shallow clone of p. func (p *Predicate) clone() *Predicate { if p == nil { diff --git a/dialect/sql/builder_test.go b/dialect/sql/builder_test.go index fd328c792c..37cb6322ff 100644 --- a/dialect/sql/builder_test.go +++ b/dialect/sql/builder_test.go @@ -411,6 +411,31 @@ func TestBuilder(t *testing.T) { wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ?", wantArgs: []interface{}{"foo", 10, "foo"}, }, + { + input: Dialect(dialect.Postgres). + Update("users"). + Add("rank", 10). + Where( + Or( + EQ("rank", Select("rank").From(Table("ranks")).Where(EQ("name", "foo"))), + GT("score", Select("score").From(Table("scores")).Where(GT("count", 0))), + ), + ), + wantQuery: `UPDATE "users" SET "rank" = COALESCE("users"."rank", 0) + $1 WHERE "rank" = (SELECT "rank" FROM "ranks" WHERE "name" = $2) OR "score" > (SELECT "score" FROM "scores" WHERE "count" > $3)`, + wantArgs: []interface{}{10, "foo", 0}, + }, + { + input: Update("users"). + Add("rank", 10). + Where( + Or( + EQ("rank", Select("rank").From(Table("ranks")).Where(EQ("name", "foo"))), + GT("score", Select("score").From(Table("scores")).Where(GT("count", 0))), + ), + ), + wantQuery: "UPDATE `users` SET `rank` = COALESCE(`users`.`rank`, 0) + ? WHERE `rank` = (SELECT `rank` FROM `ranks` WHERE `name` = ?) OR `score` > (SELECT `score` FROM `scores` WHERE `count` > ?)", + wantArgs: []interface{}{10, "foo", 0}, + }, { input: Dialect(dialect.Postgres). Update("users").