From 93a1ae2a3b11a5feba348f0275abbf8bf65efd04 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Sat, 23 Jul 2022 21:29:59 +0300 Subject: [PATCH] entc/gen: move eager-loading to method This is preparatory work for the 'WithNamed' API. --- entc/gen/template/dialect/sql/query.tmpl | 268 ++--- .../cascadelete/ent/comment_query.go | 52 +- .../integration/cascadelete/ent/post_query.go | 99 +- .../integration/cascadelete/ent/user_query.go | 53 +- .../integration/customid/ent/account_query.go | 61 +- entc/integration/customid/ent/blob_query.go | 208 ++-- .../customid/ent/bloblink_query.go | 98 +- entc/integration/customid/ent/car_query.go | 58 +- entc/integration/customid/ent/device_query.go | 113 ++- entc/integration/customid/ent/doc_query.go | 214 ++-- entc/integration/customid/ent/group_query.go | 105 +- entc/integration/customid/ent/intsid_query.go | 113 ++- entc/integration/customid/ent/note_query.go | 113 ++- entc/integration/customid/ent/pet_query.go | 272 ++--- .../integration/customid/ent/session_query.go | 58 +- entc/integration/customid/ent/token_query.go | 58 +- entc/integration/customid/ent/user_query.go | 277 ++--- entc/integration/edgefield/ent/car_query.go | 53 +- entc/integration/edgefield/ent/card_query.go | 52 +- entc/integration/edgefield/ent/info_query.go | 52 +- .../edgefield/ent/metadata_query.go | 147 +-- entc/integration/edgefield/ent/node_query.go | 96 +- entc/integration/edgefield/ent/pet_query.go | 52 +- entc/integration/edgefield/ent/post_query.go | 58 +- .../integration/edgefield/ent/rental_query.go | 98 +- entc/integration/edgefield/ent/user_query.go | 388 ++++--- .../edgeschema/ent/friendship_query.go | 98 +- .../integration/edgeschema/ent/group_query.go | 154 +-- .../edgeschema/ent/relationship_query.go | 146 +-- entc/integration/edgeschema/ent/role_query.go | 154 +-- .../edgeschema/ent/roleuser_query.go | 98 +- entc/integration/edgeschema/ent/tag_query.go | 154 +-- .../integration/edgeschema/ent/tweet_query.go | 472 +++++----
.../edgeschema/ent/tweetlike_query.go | 98 +- .../edgeschema/ent/tweettag_query.go | 98 +- entc/integration/edgeschema/ent/user_query.go | 944 ++++++++++-------- .../edgeschema/ent/usergroup_query.go | 98 +- .../edgeschema/ent/usertweet_query.go | 98 +- entc/integration/ent/card_query.go | 159 +-- entc/integration/ent/file_query.go | 165 +-- entc/integration/ent/filetype_query.go | 61 +- entc/integration/ent/group_query.go | 277 ++--- entc/integration/ent/groupinfo_query.go | 61 +- entc/integration/ent/node_query.go | 108 +- entc/integration/ent/pet_query.go | 110 +- entc/integration/ent/spec_query.go | 105 +- entc/integration/ent/user_query.go | 811 ++++++++------- entc/integration/hooks/ent/card_query.go | 58 +- entc/integration/hooks/ent/user_query.go | 216 ++-- entc/integration/idtype/ent/user_query.go | 264 ++--- entc/integration/migrate/entv1/car_query.go | 58 +- entc/integration/migrate/entv1/user_query.go | 223 +++-- entc/integration/migrate/entv2/car_query.go | 58 +- entc/integration/migrate/entv2/pet_query.go | 58 +- entc/integration/migrate/entv2/user_query.go | 214 ++-- .../multischema/ent/group_query.go | 105 +- entc/integration/multischema/ent/pet_query.go | 52 +- .../integration/multischema/ent/user_query.go | 154 +-- entc/integration/privacy/ent/task_query.go | 159 +-- entc/integration/privacy/ent/team_query.go | 206 ++-- entc/integration/privacy/ent/user_query.go | 162 +-- entc/integration/template/ent/pet_query.go | 58 +- entc/integration/template/ent/user_query.go | 162 +-- examples/edgeindex/ent/city_query.go | 61 +- examples/edgeindex/ent/street_query.go | 58 +- examples/fs/ent/file_query.go | 99 +- examples/m2m2types/ent/group_query.go | 105 +- examples/m2m2types/ent/user_query.go | 105 +- examples/m2mbidi/ent/user_query.go | 105 +- examples/m2mrecur/ent/user_query.go | 206 ++-- examples/o2m2types/ent/pet_query.go | 58 +- examples/o2m2types/ent/user_query.go | 61 +- examples/o2mrecur/ent/node_query.go | 113 ++- 
examples/o2o2types/ent/card_query.go | 58 +- examples/o2o2types/ent/user_query.go | 56 +- examples/o2obidi/ent/user_query.go | 58 +- examples/o2orecur/ent/node_query.go | 108 +- examples/privacytenant/ent/group_query.go | 153 +-- examples/privacytenant/ent/user_query.go | 153 +-- examples/start/ent/car_query.go | 58 +- examples/start/ent/group_query.go | 105 +- examples/start/ent/user_query.go | 162 +-- examples/traversal/ent/group_query.go | 159 +-- examples/traversal/ent/pet_query.go | 159 +-- examples/traversal/ent/user_query.go | 328 +++--- 85 files changed, 6904 insertions(+), 5516 deletions(-) diff --git a/entc/gen/template/dialect/sql/query.tmpl b/entc/gen/template/dialect/sql/query.tmpl index 1f52ecf939..fa7c3cfa81 100644 --- a/entc/gen/template/dialect/sql/query.tmpl +++ b/entc/gen/template/dialect/sql/query.tmpl @@ -75,9 +75,13 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer return nodes, nil } {{- range $e := $.Edges }} - {{- with extend $ "Rec" $receiver "Edge" $e }} - {{ template "dialect/sql/query/eagerloading" . }} - {{- end }} + if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil { + if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if and (not $e.M2M) (not $e.O2M) }}nil{{ else }} + func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }}, + func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if or $e.OwnFK $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil { + return nil, err + } + } {{- end }} {{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}} {{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }} @@ -88,6 +92,137 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer return nodes, nil } +{{/* Generate a method to eager-load each edge. 
*/}} +{{- range $e := $.Edges }} + func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error { + {{- if $e.M2M }} + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[{{ $.ID.Type }}]*{{ $.Name }}) + nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }}) + {{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }} + {{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }} + s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}])) + s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + {{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }} + {{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }} + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]interface{}{new({{ $out }})}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . 
}}{{ end }} + inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} + if nids[inValue] == nil { + nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) + } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + {{- else if $e.OwnFK }} + ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes)) + nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }}) + for i := range nodes { + {{- $fk := $e.ForeignKey }} + {{- if $fk.Field.Nillable }} + if nodes[i].{{ $fk.StructField }} == nil { + continue + } + {{- end }} + fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }} + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where({{ $e.Type.Package }}.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + {{- else }} + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }}) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + {{- if $e.O2M }} + if init != nil { + init(nodes[i]) + } + {{- end }} + } + {{- with $e.Type.UnexportedForeignKeys }} + query.withFKs = true + {{- end }} + query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) { + s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + 
return err + } + for _, n := range neighbors { + {{- $fk := $e.ForeignKey }} + fk := n.{{ $fk.StructField }} + {{- if $fk.Field.Nillable }} + if fk == nil { + return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID) + } + {{- end }} + node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }}) + } + assign(node, n) + } + {{- end }} + return nil + } +{{- end }} + func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) { _spec := {{ $receiver }}.querySpec() {{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}} @@ -286,133 +421,6 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select {{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step) {{ end }} -{{ define "dialect/sql/query/eagerloading" }} - {{- $e := $.Scope.Edge }} - {{- $receiver := $.Scope.Rec }} - if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil { - {{- if $e.M2M }} - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[{{ $.ID.Type }}]*{{ $.Name }}) - nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }}) - {{- $edgeid := print $e.Type.Package "." 
$e.Type.ID.Constant }} - {{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }} - s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}])) - s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - {{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }} - {{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }} - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new({{ $out }})}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }} - inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . 
}}{{ end }} - if nids[inValue] == nil { - nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err - } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n) - } - } - {{- else if $e.OwnFK }} - ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes)) - nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }}) - for i := range nodes { - {{- $fk := $e.ForeignKey }} - {{- if $fk.Field.Nillable }} - if nodes[i].{{ $fk.StructField }} == nil { - continue - } - {{- end }} - fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }} - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where({{ $e.Type.Package }}.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err - } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.{{ $e.StructField }} = n - } - } - {{- else }} - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }}) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - {{- if $e.O2M }} - nodes[i].Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} - {{- end }} - } - {{- with $e.Type.UnexportedForeignKeys }} - query.withFKs = true - {{- end }} - query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) { - s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err - } - for _, n := 
range neighbors { - {{- $fk := $e.ForeignKey }} - fk := n.{{ $fk.StructField }} - {{- if $fk.Field.Nillable }} - if fk == nil { - return nil, fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID) - } - {{- end }} - node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }}) - } - node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }} - } - {{- end }} - } -{{ end }} - {{ define "dialect/sql/query/eagerloading/m2massign" }} {{- $arg := $.Scope.Arg }} {{- $field := $.Scope.Field }} diff --git a/entc/integration/cascadelete/ent/comment_query.go b/entc/integration/cascadelete/ent/comment_query.go index 579f8a030e..40bd7d0678 100644 --- a/entc/integration/cascadelete/ent/comment_query.go +++ b/entc/integration/cascadelete/ent/comment_query.go @@ -380,36 +380,42 @@ func (cq *CommentQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Comm if len(nodes) == 0 { return nodes, nil } - if query := cq.withPost; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Comment) - for i := range nodes { - fk := nodes[i].PostID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(post.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadPost(ctx, query, nodes, nil, + func(n *Comment, e *Post) { n.Edges.Post = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "post_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Post = n - } - } } - return nodes, nil } +func (cq *CommentQuery) loadPost(ctx context.Context, query *PostQuery, nodes []*Comment, init 
func(*Comment), assign func(*Comment, *Post)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Comment) + for i := range nodes { + fk := nodes[i].PostID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(post.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "post_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CommentQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/cascadelete/ent/post_query.go b/entc/integration/cascadelete/ent/post_query.go index 0c849ad51c..6e17425b8c 100644 --- a/entc/integration/cascadelete/ent/post_query.go +++ b/entc/integration/cascadelete/ent/post_query.go @@ -418,59 +418,74 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e if len(nodes) == 0 { return nodes, nil } - if query := pq.withAuthor; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Post) - for i := range nodes { - fk := nodes[i].AuthorID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadAuthor(ctx, query, nodes, nil, + func(n *Post, e *User) { n.Edges.Author = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "author_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Author = n - } + } + if query := pq.withComments; query != nil { + if err := pq.loadComments(ctx, query, nodes, + func(n *Post) { n.Edges.Comments = []*Comment{} }, + func(n 
*Post, e *Comment) { n.Edges.Comments = append(n.Edges.Comments, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := pq.withComments; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Post) +func (pq *PostQuery) loadAuthor(ctx context.Context, query *UserQuery, nodes []*Post, init func(*Post), assign func(*Post, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Post) + for i := range nodes { + fk := nodes[i].AuthorID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "author_id" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Comments = []*Comment{} + assign(nodes[i], n) } - query.Where(predicate.Comment(func(s *sql.Selector) { - s.Where(sql.InValues(post.CommentsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (pq *PostQuery) loadComments(ctx context.Context, query *CommentQuery, nodes []*Post, init func(*Post), assign func(*Post, *Comment)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Post) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.PostID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "post_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Comments = append(node.Edges.Comments, n) + } + query.Where(predicate.Comment(func(s *sql.Selector) { + s.Where(sql.InValues(post.CommentsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if 
err != nil { + return err + } + for _, n := range neighbors { + fk := n.PostID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "post_id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - return nodes, nil + return nil } func (pq *PostQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/cascadelete/ent/user_query.go b/entc/integration/cascadelete/ent/user_query.go index 2c4799070d..932955e3ac 100644 --- a/entc/integration/cascadelete/ent/user_query.go +++ b/entc/integration/cascadelete/ent/user_query.go @@ -381,35 +381,44 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPosts; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Posts = []*Post{} - } - query.Where(predicate.Post(func(s *sql.Selector) { - s.Where(sql.InValues(user.PostsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPosts(ctx, query, nodes, + func(n *User) { n.Edges.Posts = []*Post{} }, + func(n *User, e *Post) { n.Edges.Posts = append(n.Edges.Posts, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.AuthorID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "author_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Posts = append(node.Edges.Posts, n) - } } - return nodes, nil } +func (uq *UserQuery) loadPosts(ctx context.Context, query *PostQuery, nodes []*User, init func(*User), assign func(*User, *Post)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.Post(func(s 
*sql.Selector) { + s.Where(sql.InValues(user.PostsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.AuthorID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "author_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() _spec.Node.Columns = uq.fields diff --git a/entc/integration/customid/ent/account_query.go b/entc/integration/customid/ent/account_query.go index 259cc4c727..f77b76f0f1 100644 --- a/entc/integration/customid/ent/account_query.go +++ b/entc/integration/customid/ent/account_query.go @@ -382,39 +382,48 @@ func (aq *AccountQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Acco if len(nodes) == 0 { return nodes, nil } - if query := aq.withToken; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[sid.ID]*Account) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Token = []*Token{} - } - query.withFKs = true - query.Where(predicate.Token(func(s *sql.Selector) { - s.Where(sql.InValues(account.TokenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := aq.loadToken(ctx, query, nodes, + func(n *Account) { n.Edges.Token = []*Token{} }, + func(n *Account, e *Token) { n.Edges.Token = append(n.Edges.Token, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.account_token - if fk == nil { - return nil, fmt.Errorf(`foreign-key "account_token" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "account_token" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Token = append(node.Edges.Token, n) - } } - return nodes, nil } +func (aq *AccountQuery) loadToken(ctx context.Context, query *TokenQuery, nodes 
[]*Account, init func(*Account), assign func(*Account, *Token)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[sid.ID]*Account) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Token(func(s *sql.Selector) { + s.Where(sql.InValues(account.TokenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.account_token + if fk == nil { + return fmt.Errorf(`foreign-key "account_token" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "account_token" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (aq *AccountQuery) sqlCount(ctx context.Context) (int, error) { _spec := aq.querySpec() _spec.Node.Columns = aq.fields diff --git a/entc/integration/customid/ent/blob_query.go b/entc/integration/customid/ent/blob_query.go index 3709258fd8..4fbf980afb 100644 --- a/entc/integration/customid/ent/blob_query.go +++ b/entc/integration/customid/ent/blob_query.go @@ -462,115 +462,139 @@ func (bq *BlobQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Blob, e if len(nodes) == 0 { return nodes, nil } - if query := bq.withParent; query != nil { - ids := make([]uuid.UUID, 0, len(nodes)) - nodeids := make(map[uuid.UUID][]*Blob) - for i := range nodes { - if nodes[i].blob_parent == nil { - continue - } - fk := *nodes[i].blob_parent - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := bq.loadParent(ctx, query, nodes, nil, + func(n *Blob, e *Blob) { n.Edges.Parent = e }); err != nil { + return nil, err } - query.Where(blob.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := bq.withLinks; query != nil { + if err := bq.loadLinks(ctx, query, nodes, + func(n 
*Blob) { n.Edges.Links = []*Blob{} }, + func(n *Blob, e *Blob) { n.Edges.Links = append(n.Edges.Links, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "blob_parent" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := bq.withBlobLinks; query != nil { + if err := bq.loadBlobLinks(ctx, query, nodes, + func(n *Blob) { n.Edges.BlobLinks = []*BlobLink{} }, + func(n *Blob, e *BlobLink) { n.Edges.BlobLinks = append(n.Edges.BlobLinks, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := bq.withLinks; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[uuid.UUID]*Blob) - nids := make(map[uuid.UUID]map[*Blob]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Links = []*Blob{} +func (bq *BlobQuery) loadParent(ctx context.Context, query *BlobQuery, nodes []*Blob, init func(*Blob), assign func(*Blob, *Blob)) error { + ids := make([]uuid.UUID, 0, len(nodes)) + nodeids := make(map[uuid.UUID][]*Blob) + for i := range nodes { + if nodes[i].blob_parent == nil { + continue } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(blob.LinksTable) - s.Join(joinT).On(s.C(blob.FieldID), joinT.C(blob.LinksPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(blob.LinksPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(blob.LinksPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(uuid.UUID)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := *values[0].(*uuid.UUID) - inValue := *values[1].(*uuid.UUID) - if nids[inValue] == nil { - nids[inValue] = map[*Blob]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + fk := *nodes[i].blob_parent + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(blob.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "blob_parent" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "links" node returned %v`, n.ID) + } + return nil +} +func (bq *BlobQuery) loadLinks(ctx context.Context, query *BlobQuery, nodes []*Blob, init func(*Blob), assign func(*Blob, *Blob)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[uuid.UUID]*Blob) + nids := make(map[uuid.UUID]map[*Blob]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(blob.LinksTable) + s.Join(joinT).On(s.C(blob.FieldID), joinT.C(blob.LinksPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(blob.LinksPrimaryKey[0]), edgeIDs...)) + columns := 
s.SelectedColumns() + s.Select(joinT.C(blob.LinksPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Links = append(kn.Edges.Links, n) + return append([]interface{}{new(uuid.UUID)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := *values[0].(*uuid.UUID) + inValue := *values[1].(*uuid.UUID) + if nids[inValue] == nil { + nids[inValue] = map[*Blob]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := bq.withBlobLinks; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[uuid.UUID]*Blob) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.BlobLinks = []*BlobLink{} + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "links" node returned %v`, n.ID) } - query.Where(predicate.BlobLink(func(s *sql.Selector) { - s.Where(sql.InValues(blob.BlobLinksColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - fk := n.BlobID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "blob_id" returned %v for node %v`, fk, n) - } - node.Edges.BlobLinks = append(node.Edges.BlobLinks, n) + } + return nil +} +func (bq *BlobQuery) loadBlobLinks(ctx context.Context, query *BlobLinkQuery, nodes []*Blob, init func(*Blob), assign func(*Blob, *BlobLink)) error { + fks := make([]driver.Value, 
0, len(nodes)) + nodeids := make(map[uuid.UUID]*Blob) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.Where(predicate.BlobLink(func(s *sql.Selector) { + s.Where(sql.InValues(blob.BlobLinksColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.BlobID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "blob_id" returned %v for node %v`, fk, n) + } + assign(node, n) + } + return nil } func (bq *BlobQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/bloblink_query.go b/entc/integration/customid/ent/bloblink_query.go index ba8ad9ad58..cc4c43b28d 100644 --- a/entc/integration/customid/ent/bloblink_query.go +++ b/entc/integration/customid/ent/bloblink_query.go @@ -347,60 +347,72 @@ func (blq *BlobLinkQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Bl if len(nodes) == 0 { return nodes, nil } - if query := blq.withBlob; query != nil { - ids := make([]uuid.UUID, 0, len(nodes)) - nodeids := make(map[uuid.UUID][]*BlobLink) - for i := range nodes { - fk := nodes[i].BlobID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(blob.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := blq.loadBlob(ctx, query, nodes, nil, + func(n *BlobLink, e *Blob) { n.Edges.Blob = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "blob_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Blob = n - } + } + if query := blq.withLink; query != nil { + if err := blq.loadLink(ctx, query, nodes, nil, + func(n *BlobLink, e *Blob) { n.Edges.Link = e }); err != nil { + return nil, err } } + return nodes, nil +} - 
if query := blq.withLink; query != nil { - ids := make([]uuid.UUID, 0, len(nodes)) - nodeids := make(map[uuid.UUID][]*BlobLink) +func (blq *BlobLinkQuery) loadBlob(ctx context.Context, query *BlobQuery, nodes []*BlobLink, init func(*BlobLink), assign func(*BlobLink, *Blob)) error { + ids := make([]uuid.UUID, 0, len(nodes)) + nodeids := make(map[uuid.UUID][]*BlobLink) + for i := range nodes { + fk := nodes[i].BlobID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(blob.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "blob_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].LinkID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(blob.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (blq *BlobLinkQuery) loadLink(ctx context.Context, query *BlobQuery, nodes []*BlobLink, init func(*BlobLink), assign func(*BlobLink, *Blob)) error { + ids := make([]uuid.UUID, 0, len(nodes)) + nodeids := make(map[uuid.UUID][]*BlobLink) + for i := range nodes { + fk := nodes[i].LinkID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "link_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Link = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(blob.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "link_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } 
- - return nodes, nil + return nil } func (blq *BlobLinkQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/car_query.go b/entc/integration/customid/ent/car_query.go index 9e3cb5f8b4..725cd9432e 100644 --- a/entc/integration/customid/ent/car_query.go +++ b/entc/integration/customid/ent/car_query.go @@ -388,39 +388,45 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]string, 0, len(nodes)) - nodeids := make(map[string][]*Car) - for i := range nodes { - if nodes[i].pet_cars == nil { - continue - } - fk := *nodes[i].pet_cars - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(pet.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Car, e *Pet) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "pet_cars" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CarQuery) loadOwner(ctx context.Context, query *PetQuery, nodes []*Car, init func(*Car), assign func(*Car, *Pet)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*Car) + for i := range nodes { + if nodes[i].pet_cars == nil { + continue + } + fk := *nodes[i].pet_cars + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(pet.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pet_cars" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + 
return nil +} + func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/customid/ent/device_query.go b/entc/integration/customid/ent/device_query.go index 0ea47442e9..38cc5a7891 100644 --- a/entc/integration/customid/ent/device_query.go +++ b/entc/integration/customid/ent/device_query.go @@ -402,66 +402,81 @@ func (dq *DeviceQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Devic if len(nodes) == 0 { return nodes, nil } - if query := dq.withActiveSession; query != nil { - ids := make([]schema.ID, 0, len(nodes)) - nodeids := make(map[schema.ID][]*Device) - for i := range nodes { - if nodes[i].device_active_session == nil { - continue - } - fk := *nodes[i].device_active_session - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(session.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := dq.loadActiveSession(ctx, query, nodes, nil, + func(n *Device, e *Session) { n.Edges.ActiveSession = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "device_active_session" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.ActiveSession = n - } + } + if query := dq.withSessions; query != nil { + if err := dq.loadSessions(ctx, query, nodes, + func(n *Device) { n.Edges.Sessions = []*Session{} }, + func(n *Device, e *Session) { n.Edges.Sessions = append(n.Edges.Sessions, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := dq.withSessions; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[schema.ID]*Device) +func (dq *DeviceQuery) loadActiveSession(ctx context.Context, query *SessionQuery, nodes []*Device, init func(*Device), assign func(*Device, *Session)) error { + ids := make([]schema.ID, 0, len(nodes)) 
+ nodeids := make(map[schema.ID][]*Device) + for i := range nodes { + if nodes[i].device_active_session == nil { + continue + } + fk := *nodes[i].device_active_session + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(session.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "device_active_session" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Sessions = []*Session{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Session(func(s *sql.Selector) { - s.Where(sql.InValues(device.SessionsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (dq *DeviceQuery) loadSessions(ctx context.Context, query *SessionQuery, nodes []*Device, init func(*Device), assign func(*Device, *Session)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[schema.ID]*Device) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.device_sessions - if fk == nil { - return nil, fmt.Errorf(`foreign-key "device_sessions" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "device_sessions" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Sessions = append(node.Edges.Sessions, n) + } + query.withFKs = true + query.Where(predicate.Session(func(s *sql.Selector) { + s.Where(sql.InValues(device.SessionsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.device_sessions + if fk == nil { + return fmt.Errorf(`foreign-key 
"device_sessions" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "device_sessions" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - return nodes, nil + return nil } func (dq *DeviceQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/doc_query.go b/entc/integration/customid/ent/doc_query.go index 24b6f343bc..41534ea6a9 100644 --- a/entc/integration/customid/ent/doc_query.go +++ b/entc/integration/customid/ent/doc_query.go @@ -461,119 +461,143 @@ func (dq *DocQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Doc, err if len(nodes) == 0 { return nodes, nil } - if query := dq.withParent; query != nil { - ids := make([]schema.DocID, 0, len(nodes)) - nodeids := make(map[schema.DocID][]*Doc) - for i := range nodes { - if nodes[i].doc_children == nil { - continue - } - fk := *nodes[i].doc_children - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := dq.loadParent(ctx, query, nodes, nil, + func(n *Doc, e *Doc) { n.Edges.Parent = e }); err != nil { + return nil, err } - query.Where(doc.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := dq.withChildren; query != nil { + if err := dq.loadChildren(ctx, query, nodes, + func(n *Doc) { n.Edges.Children = []*Doc{} }, + func(n *Doc, e *Doc) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "doc_children" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := dq.withRelated; query != nil { + if err := dq.loadRelated(ctx, query, nodes, + func(n *Doc) { n.Edges.Related = []*Doc{} }, + func(n *Doc, e *Doc) { n.Edges.Related = append(n.Edges.Related, e) }); err != nil { + return nil, err } } + return nodes, nil +} 
- if query := dq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[schema.DocID]*Doc) +func (dq *DocQuery) loadParent(ctx context.Context, query *DocQuery, nodes []*Doc, init func(*Doc), assign func(*Doc, *Doc)) error { + ids := make([]schema.DocID, 0, len(nodes)) + nodeids := make(map[schema.DocID][]*Doc) + for i := range nodes { + if nodes[i].doc_children == nil { + continue + } + fk := *nodes[i].doc_children + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(doc.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "doc_children" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*Doc{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Doc(func(s *sql.Selector) { - s.Where(sql.InValues(doc.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (dq *DocQuery) loadChildren(ctx context.Context, query *DocQuery, nodes []*Doc, init func(*Doc), assign func(*Doc, *Doc)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[schema.DocID]*Doc) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.doc_children - if fk == nil { - return nil, fmt.Errorf(`foreign-key "doc_children" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "doc_children" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.withFKs = true + query.Where(predicate.Doc(func(s *sql.Selector) { + 
s.Where(sql.InValues(doc.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.doc_children + if fk == nil { + return fmt.Errorf(`foreign-key "doc_children" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "doc_children" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - if query := dq.withRelated; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[schema.DocID]*Doc) - nids := make(map[schema.DocID]map[*Doc]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Related = []*Doc{} + return nil +} +func (dq *DocQuery) loadRelated(ctx context.Context, query *DocQuery, nodes []*Doc, init func(*Doc), assign func(*Doc, *Doc)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[schema.DocID]*Doc) + nids := make(map[schema.DocID]map[*Doc]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(doc.RelatedTable) - s.Join(joinT).On(s.C(doc.FieldID), joinT.C(doc.RelatedPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(doc.RelatedPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(doc.RelatedPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(schema.DocID)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := *values[0].(*schema.DocID) - inValue := *values[1].(*schema.DocID) - if nids[inValue] == nil { - nids[inValue] = map[*Doc]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(doc.RelatedTable) + s.Join(joinT).On(s.C(doc.FieldID), joinT.C(doc.RelatedPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(doc.RelatedPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(doc.RelatedPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - }) - if err != nil { - return nil, err + return append([]interface{}{new(schema.DocID)}, values...), nil } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "related" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Related = append(kn.Edges.Related, n) + spec.Assign = func(columns []string, values []interface{}) error { + outValue := *values[0].(*schema.DocID) + inValue := *values[1].(*schema.DocID) + if nids[inValue] == nil { + nids[inValue] = map[*Doc]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "related" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (dq *DocQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/group_query.go b/entc/integration/customid/ent/group_query.go index c927c7dc8d..6f13142114 100644 --- a/entc/integration/customid/ent/group_query.go +++ b/entc/integration/customid/ent/group_query.go @@ -357,61 +357,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - 
query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), 
joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/intsid_query.go b/entc/integration/customid/ent/intsid_query.go index 8c015296ad..a5806d06f6 100644 --- a/entc/integration/customid/ent/intsid_query.go +++ b/entc/integration/customid/ent/intsid_query.go @@ -401,66 +401,81 @@ func (isq *IntSIDQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IntS if len(nodes) == 0 { return nodes, nil } - if query := isq.withParent; query != nil { - ids := make([]sid.ID, 0, len(nodes)) - nodeids := make(map[sid.ID][]*IntSID) - for i := range nodes { - if nodes[i].int_sid_parent == nil { - continue - } - fk := *nodes[i].int_sid_parent - if 
_, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(intsid.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := isq.loadParent(ctx, query, nodes, nil, + func(n *IntSID, e *IntSID) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "int_sid_parent" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := isq.withChildren; query != nil { + if err := isq.loadChildren(ctx, query, nodes, + func(n *IntSID) { n.Edges.Children = []*IntSID{} }, + func(n *IntSID, e *IntSID) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := isq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[sid.ID]*IntSID) +func (isq *IntSIDQuery) loadParent(ctx context.Context, query *IntSIDQuery, nodes []*IntSID, init func(*IntSID), assign func(*IntSID, *IntSID)) error { + ids := make([]sid.ID, 0, len(nodes)) + nodeids := make(map[sid.ID][]*IntSID) + for i := range nodes { + if nodes[i].int_sid_parent == nil { + continue + } + fk := *nodes[i].int_sid_parent + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(intsid.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "int_sid_parent" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*IntSID{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.IntSID(func(s *sql.Selector) { - s.Where(sql.InValues(intsid.ChildrenColumn, fks...)) - })) - neighbors, err 
:= query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (isq *IntSIDQuery) loadChildren(ctx context.Context, query *IntSIDQuery, nodes []*IntSID, init func(*IntSID), assign func(*IntSID, *IntSID)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[sid.ID]*IntSID) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.int_sid_parent - if fk == nil { - return nil, fmt.Errorf(`foreign-key "int_sid_parent" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "int_sid_parent" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.withFKs = true + query.Where(predicate.IntSID(func(s *sql.Selector) { + s.Where(sql.InValues(intsid.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.int_sid_parent + if fk == nil { + return fmt.Errorf(`foreign-key "int_sid_parent" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "int_sid_parent" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - return nodes, nil + return nil } func (isq *IntSIDQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/note_query.go b/entc/integration/customid/ent/note_query.go index ee35f0774e..fd85cb1f79 100644 --- a/entc/integration/customid/ent/note_query.go +++ b/entc/integration/customid/ent/note_query.go @@ -425,66 +425,81 @@ func (nq *NoteQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Note, e if len(nodes) == 0 { return nodes, nil } - if query := nq.withParent; query != nil { - ids := make([]schema.NoteID, 0, len(nodes)) - nodeids := make(map[schema.NoteID][]*Note) - for i := range nodes { - if 
nodes[i].note_children == nil { - continue - } - fk := *nodes[i].note_children - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(note.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := nq.loadParent(ctx, query, nodes, nil, + func(n *Note, e *Note) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "note_children" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := nq.withChildren; query != nil { + if err := nq.loadChildren(ctx, query, nodes, + func(n *Note) { n.Edges.Children = []*Note{} }, + func(n *Note, e *Note) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := nq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[schema.NoteID]*Note) +func (nq *NoteQuery) loadParent(ctx context.Context, query *NoteQuery, nodes []*Note, init func(*Note), assign func(*Note, *Note)) error { + ids := make([]schema.NoteID, 0, len(nodes)) + nodeids := make(map[schema.NoteID][]*Note) + for i := range nodes { + if nodes[i].note_children == nil { + continue + } + fk := *nodes[i].note_children + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(note.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "note_children" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*Note{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Note(func(s *sql.Selector) { - 
s.Where(sql.InValues(note.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (nq *NoteQuery) loadChildren(ctx context.Context, query *NoteQuery, nodes []*Note, init func(*Note), assign func(*Note, *Note)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[schema.NoteID]*Note) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.note_children - if fk == nil { - return nil, fmt.Errorf(`foreign-key "note_children" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "note_children" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.withFKs = true + query.Where(predicate.Note(func(s *sql.Selector) { + s.Where(sql.InValues(note.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.note_children + if fk == nil { + return fmt.Errorf(`foreign-key "note_children" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "note_children" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - return nodes, nil + return nil } func (nq *NoteQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/pet_query.go b/entc/integration/customid/ent/pet_query.go index 684a3d8287..cd9284d6c6 100644 --- a/entc/integration/customid/ent/pet_query.go +++ b/entc/integration/customid/ent/pet_query.go @@ -474,148 +474,178 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range 
nodes { - if nodes[i].user_pets == nil { - continue - } - fk := *nodes[i].user_pets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - if query := pq.withCars; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[string]*Pet) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Cars = []*Car{} - } - query.withFKs = true - query.Where(predicate.Car(func(s *sql.Selector) { - s.Where(sql.InValues(pet.CarsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadCars(ctx, query, nodes, + func(n *Pet) { n.Edges.Cars = []*Car{} }, + func(n *Pet, e *Car) { n.Edges.Cars = append(n.Edges.Cars, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.pet_cars - if fk == nil { - return nil, fmt.Errorf(`foreign-key "pet_cars" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "pet_cars" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Cars = append(node.Edges.Cars, n) - } } - if query := pq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[string]*Pet) - nids := make(map[string]map[*Pet]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*Pet{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(pet.FriendsTable) - s.Join(joinT).On(s.C(pet.FieldID), 
joinT.C(pet.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(pet.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullString)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := values[0].(*sql.NullString).String - inValue := values[1].(*sql.NullString).String - if nids[inValue] == nil { - nids[inValue] = map[*Pet]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := pq.loadFriends(ctx, query, nodes, + func(n *Pet) { n.Edges.Friends = []*Pet{} }, + func(n *Pet, e *Pet) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } + } + if query := pq.withBestFriend; query != nil { + if err := pq.loadBestFriend(ctx, query, nodes, nil, + func(n *Pet, e *Pet) { n.Edges.BestFriend = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := pq.withBestFriend; query != nil { - ids := make([]string, 0, len(nodes)) - nodeids := make(map[string][]*Pet) +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if 
nodes[i].user_pets == nil { + continue + } + fk := *nodes[i].user_pets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + } for i := range nodes { - if nodes[i].pet_best_friend == nil { - continue - } - fk := *nodes[i].pet_best_friend - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(pet.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (pq *PetQuery) loadCars(ctx context.Context, query *CarQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *Car)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[string]*Pet) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Car(func(s *sql.Selector) { + s.Where(sql.InValues(pet.CarsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.pet_cars + if fk == nil { + return fmt.Errorf(`foreign-key "pet_cars" is nil for node %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "pet_best_friend" returned %v`, n.ID) + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pet_cars" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (pq *PetQuery) loadFriends(ctx context.Context, query *PetQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *Pet)) error { + edgeIDs := make([]driver.Value, 
len(nodes)) + byID := make(map[string]*Pet) + nids := make(map[string]map[*Pet]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(pet.FriendsTable) + s.Join(joinT).On(s.C(pet.FieldID), joinT.C(pet.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(pet.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for i := range nodes { - nodes[i].Edges.BestFriend = n + return append([]interface{}{new(sql.NullString)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := values[0].(*sql.NullString).String + inValue := values[1].(*sql.NullString).String + if nids[inValue] == nil { + nids[inValue] = map[*Pet]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (pq *PetQuery) loadBestFriend(ctx context.Context, query *PetQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *Pet)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*Pet) + for i := range nodes { + if nodes[i].pet_best_friend == nil { + continue + } + fk := *nodes[i].pet_best_friend + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + 
nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(pet.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pet_best_friend" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/customid/ent/session_query.go b/entc/integration/customid/ent/session_query.go index 69dfe0ebdb..0085512cc2 100644 --- a/entc/integration/customid/ent/session_query.go +++ b/entc/integration/customid/ent/session_query.go @@ -365,39 +365,45 @@ func (sq *SessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Sess if len(nodes) == 0 { return nodes, nil } - if query := sq.withDevice; query != nil { - ids := make([]schema.ID, 0, len(nodes)) - nodeids := make(map[schema.ID][]*Session) - for i := range nodes { - if nodes[i].device_sessions == nil { - continue - } - fk := *nodes[i].device_sessions - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(device.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := sq.loadDevice(ctx, query, nodes, nil, + func(n *Session, e *Device) { n.Edges.Device = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "device_sessions" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Device = n - } - } } - return nodes, nil } +func (sq *SessionQuery) loadDevice(ctx context.Context, query *DeviceQuery, nodes []*Session, init func(*Session), assign func(*Session, *Device)) error { + ids := make([]schema.ID, 0, len(nodes)) + nodeids := make(map[schema.ID][]*Session) + for i := range nodes { + if nodes[i].device_sessions == nil { + continue + } + fk := 
*nodes[i].device_sessions + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(device.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "device_sessions" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (sq *SessionQuery) sqlCount(ctx context.Context) (int, error) { _spec := sq.querySpec() _spec.Node.Columns = sq.fields diff --git a/entc/integration/customid/ent/token_query.go b/entc/integration/customid/ent/token_query.go index 4ca5ae59c3..9509bbb450 100644 --- a/entc/integration/customid/ent/token_query.go +++ b/entc/integration/customid/ent/token_query.go @@ -389,39 +389,45 @@ func (tq *TokenQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Token, if len(nodes) == 0 { return nodes, nil } - if query := tq.withAccount; query != nil { - ids := make([]sid.ID, 0, len(nodes)) - nodeids := make(map[sid.ID][]*Token) - for i := range nodes { - if nodes[i].account_token == nil { - continue - } - fk := *nodes[i].account_token - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(account.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := tq.loadAccount(ctx, query, nodes, nil, + func(n *Token, e *Account) { n.Edges.Account = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "account_token" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Account = n - } - } } - return nodes, nil } +func (tq *TokenQuery) loadAccount(ctx context.Context, query *AccountQuery, nodes []*Token, init func(*Token), assign func(*Token, *Account)) error { + ids := make([]sid.ID, 0, len(nodes)) + nodeids 
:= make(map[sid.ID][]*Token) + for i := range nodes { + if nodes[i].account_token == nil { + continue + } + fk := *nodes[i].account_token + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(account.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "account_token" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (tq *TokenQuery) sqlCount(ctx context.Context) (int, error) { _spec := tq.querySpec() _spec.Node.Columns = tq.fields diff --git a/entc/integration/customid/ent/user_query.go b/entc/integration/customid/ent/user_query.go index 1563719d16..771f19a4e6 100644 --- a/entc/integration/customid/ent/user_query.go +++ b/entc/integration/customid/ent/user_query.go @@ -474,148 +474,181 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) - } - } } - if query := uq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_children == nil { - continue - } - fk := *nodes[i].user_children - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := uq.loadParent(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Parent = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withChildren; query != nil { + if err := uq.loadChildren(ctx, query, nodes, + func(n *User) { n.Edges.Children = []*User{} }, + func(n *User, e *User) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { 
return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_children" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := uq.withPets; query != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*User{} - } - query.withFKs = true - query.Where(predicate.User(func(s *sql.Selector) { - s.Where(sql.InValues(user.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - for _, n := range neighbors { - fk := n.user_children - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_children" is nil for node %v`, n.ID) + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_children" returned %v for node %v`, *fk, n.ID) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.Children = append(node.Edges.Children, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadParent(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_children == nil { + continue + } + fk := *nodes[i].user_children + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_children" returned %v`, n.ID) + 
} for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + assign(nodes[i], n) } - for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) + } + return nil +} +func (uq *UserQuery) loadChildren(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(user.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_children + if fk == nil { + return fmt.Errorf(`foreign-key "user_children" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_children" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + 
query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_pets + if fk == nil { + return fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgefield/ent/car_query.go b/entc/integration/edgefield/ent/car_query.go index 6587d2cbca..6f0937498f 100644 --- a/entc/integration/edgefield/ent/car_query.go +++ b/entc/integration/edgefield/ent/car_query.go @@ -382,35 +382,44 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err if len(nodes) == 0 { return nodes, nil } - if query := cq.withRentals; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[uuid.UUID]*Car) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Rentals = []*Rental{} - } - query.Where(predicate.Rental(func(s *sql.Selector) { - s.Where(sql.InValues(car.RentalsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadRentals(ctx, query, nodes, + func(n *Car) { n.Edges.Rentals = []*Rental{} }, + func(n *Car, e *Rental) { n.Edges.Rentals = append(n.Edges.Rentals, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.CarID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "car_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Rentals = append(node.Edges.Rentals, n) - } } - return nodes, nil } +func (cq *CarQuery) loadRentals(ctx context.Context, query *RentalQuery, nodes []*Car, init func(*Car), assign func(*Car, *Rental)) error { + fks := 
make([]driver.Value, 0, len(nodes)) + nodeids := make(map[uuid.UUID]*Car) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.Rental(func(s *sql.Selector) { + s.Where(sql.InValues(car.RentalsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.CarID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "car_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/edgefield/ent/card_query.go b/entc/integration/edgefield/ent/card_query.go index b75b07dc06..cb7367710f 100644 --- a/entc/integration/edgefield/ent/card_query.go +++ b/entc/integration/edgefield/ent/card_query.go @@ -380,36 +380,42 @@ func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, e if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Card) - for i := range nodes { - fk := nodes[i].OwnerID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Card, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Card, init func(*Card), assign func(*Card, *User)) error { + ids := 
make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Card) + for i := range nodes { + fk := nodes[i].OwnerID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/edgefield/ent/info_query.go b/entc/integration/edgefield/ent/info_query.go index b6a28c6a9e..1f787503ad 100644 --- a/entc/integration/edgefield/ent/info_query.go +++ b/entc/integration/edgefield/ent/info_query.go @@ -380,36 +380,42 @@ func (iq *InfoQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Info, e if len(nodes) == 0 { return nodes, nil } - if query := iq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Info) - for i := range nodes { - fk := nodes[i].ID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := iq.loadUser(ctx, query, nodes, nil, + func(n *Info, e *User) { n.Edges.User = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } - } } - return nodes, nil } +func (iq *InfoQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Info, init func(*Info), assign func(*Info, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Info) + for i := range 
nodes { + fk := nodes[i].ID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (iq *InfoQuery) sqlCount(ctx context.Context) (int, error) { _spec := iq.querySpec() _spec.Node.Columns = iq.fields diff --git a/entc/integration/edgefield/ent/metadata_query.go b/entc/integration/edgefield/ent/metadata_query.go index bd534b8219..d0eeedc963 100644 --- a/entc/integration/edgefield/ent/metadata_query.go +++ b/entc/integration/edgefield/ent/metadata_query.go @@ -453,85 +453,106 @@ func (mq *MetadataQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Met if len(nodes) == 0 { return nodes, nil } - if query := mq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Metadata) - for i := range nodes { - fk := nodes[i].ID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := mq.loadUser(ctx, query, nodes, nil, + func(n *Metadata, e *User) { n.Edges.User = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := mq.withChildren; query != nil { + if err := mq.loadChildren(ctx, query, nodes, + func(n *Metadata) { n.Edges.Children = []*Metadata{} }, + func(n *Metadata, e *Metadata) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := mq.withParent; query != nil { + 
if err := mq.loadParent(ctx, query, nodes, nil, + func(n *Metadata, e *Metadata) { n.Edges.Parent = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := mq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Metadata) +func (mq *MetadataQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Metadata, init func(*Metadata), assign func(*Metadata, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Metadata) + for i := range nodes { + fk := nodes[i].ID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "id" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*Metadata{} + assign(nodes[i], n) } - query.Where(predicate.Metadata(func(s *sql.Selector) { - s.Where(sql.InValues(metadata.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (mq *MetadataQuery) loadChildren(ctx context.Context, query *MetadataQuery, nodes []*Metadata, init func(*Metadata), assign func(*Metadata, *Metadata)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Metadata) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.ParentID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.Where(predicate.Metadata(func(s *sql.Selector) { + 
s.Where(sql.InValues(metadata.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ParentID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - if query := mq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Metadata) - for i := range nodes { - fk := nodes[i].ParentID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + return nil +} +func (mq *MetadataQuery) loadParent(ctx context.Context, query *MetadataQuery, nodes []*Metadata, init func(*Metadata), assign func(*Metadata, *Metadata)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Metadata) + for i := range nodes { + fk := nodes[i].ParentID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - query.Where(metadata.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(metadata.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (mq *MetadataQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgefield/ent/node_query.go b/entc/integration/edgefield/ent/node_query.go index f600c99dc4..5d35622ddd 100644 --- a/entc/integration/edgefield/ent/node_query.go +++ 
b/entc/integration/edgefield/ent/node_query.go @@ -416,58 +416,70 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e if len(nodes) == 0 { return nodes, nil } - if query := nq.withPrev; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Node) - for i := range nodes { - fk := nodes[i].PrevID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(node.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := nq.loadPrev(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Prev = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "prev_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Prev = n - } + } + if query := nq.withNext; query != nil { + if err := nq.loadNext(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Next = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := nq.withNext; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Node) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] +func (nq *NodeQuery) loadPrev(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Node) + for i := range nodes { + fk := nodes[i].PrevID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - query.Where(predicate.Node(func(s *sql.Selector) { - s.Where(sql.InValues(node.NextColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(node.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := 
nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "prev_id" returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.PrevID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "prev_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Next = n + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil +} +func (nq *NodeQuery) loadNext(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.Where(predicate.Node(func(s *sql.Selector) { + s.Where(sql.InValues(node.NextColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PrevID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "prev_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgefield/ent/pet_query.go b/entc/integration/edgefield/ent/pet_query.go index 706375538a..f2fee95803 100644 --- a/entc/integration/edgefield/ent/pet_query.go +++ b/entc/integration/edgefield/ent/pet_query.go @@ -380,36 +380,42 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - fk := nodes[i].OwnerID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { 
n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + fk := nodes[i].OwnerID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Columns = pq.fields diff --git a/entc/integration/edgefield/ent/post_query.go b/entc/integration/edgefield/ent/post_query.go index 0109475b59..becf6dd0cb 100644 --- a/entc/integration/edgefield/ent/post_query.go +++ b/entc/integration/edgefield/ent/post_query.go @@ -380,39 +380,45 @@ func (pq *PostQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Post, e if len(nodes) == 0 { return nodes, nil } - if query := pq.withAuthor; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Post) - for i := range nodes { - if nodes[i].AuthorID == nil { - continue - } - fk := *nodes[i].AuthorID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadAuthor(ctx, query, nodes, nil, + func(n *Post, e *User) { n.Edges.Author = e 
}); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "author_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Author = n - } - } } - return nodes, nil } +func (pq *PostQuery) loadAuthor(ctx context.Context, query *UserQuery, nodes []*Post, init func(*Post), assign func(*Post, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Post) + for i := range nodes { + if nodes[i].AuthorID == nil { + continue + } + fk := *nodes[i].AuthorID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "author_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PostQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Columns = pq.fields diff --git a/entc/integration/edgefield/ent/rental_query.go b/entc/integration/edgefield/ent/rental_query.go index 9f647eb3d1..dd19294925 100644 --- a/entc/integration/edgefield/ent/rental_query.go +++ b/entc/integration/edgefield/ent/rental_query.go @@ -418,60 +418,72 @@ func (rq *RentalQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Renta if len(nodes) == 0 { return nodes, nil } - if query := rq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Rental) - for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := rq.loadUser(ctx, query, nodes, nil, + func(n *Rental, e *User) { n.Edges.User = e 
}); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := rq.withCar; query != nil { + if err := rq.loadCar(ctx, query, nodes, nil, + func(n *Rental, e *Car) { n.Edges.Car = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := rq.withCar; query != nil { - ids := make([]uuid.UUID, 0, len(nodes)) - nodeids := make(map[uuid.UUID][]*Rental) +func (rq *RentalQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Rental, init func(*Rental), assign func(*Rental, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Rental) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].CarID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(car.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (rq *RentalQuery) loadCar(ctx context.Context, query *CarQuery, nodes []*Rental, init func(*Rental), assign func(*Rental, *Car)) error { + ids := make([]uuid.UUID, 0, len(nodes)) + nodeids := make(map[uuid.UUID][]*Rental) + for i := range nodes { + fk := nodes[i].CarID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "car_id" returned %v`, n.ID) - } - for i 
:= range nodes { - nodes[i].Edges.Car = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(car.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "car_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (rq *RentalQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgefield/ent/user_query.go b/entc/integration/edgefield/ent/user_query.go index 9c2943ebc4..0302605a03 100644 --- a/entc/integration/edgefield/ent/user_query.go +++ b/entc/integration/edgefield/ent/user_query.go @@ -637,208 +637,268 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.OwnerID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) - } } - if query := uq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - fk := nodes[i].ParentID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], 
nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadParent(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } - } } - if query := uq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*User{} + if err := uq.loadChildren(ctx, query, nodes, + func(n *User) { n.Edges.Children = []*User{} }, + func(n *User, e *User) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err } - query.Where(predicate.User(func(s *sql.Selector) { - s.Where(sql.InValues(user.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withSpouse; query != nil { + if err := uq.loadSpouse(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Spouse = e }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.ParentID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + if query := uq.withCard; query != nil { + if err := uq.loadCard(ctx, query, nodes, nil, + func(n *User, e *Card) { n.Edges.Card = e }); err != nil { + return nil, err } } - - if query := uq.withSpouse; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - fk := nodes[i].SpouseID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if query := uq.withMetadata; query != nil { 
+ if err := uq.loadMetadata(ctx, query, nodes, nil, + func(n *User, e *Metadata) { n.Edges.Metadata = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withInfo; query != nil { + if err := uq.loadInfo(ctx, query, nodes, + func(n *User) { n.Edges.Info = []*Info{} }, + func(n *User, e *Info) { n.Edges.Info = append(n.Edges.Info, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "spouse_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Spouse = n - } + } + if query := uq.withRentals; query != nil { + if err := uq.loadRentals(ctx, query, nodes, + func(n *User) { n.Edges.Rentals = []*Rental{} }, + func(n *User, e *Rental) { n.Edges.Rentals = append(n.Edges.Rentals, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withCard; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(predicate.Card(func(s *sql.Selector) { - s.Where(sql.InValues(user.CardColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.OwnerID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected 
foreign-key "owner_id" returned %v for node %v`, fk, n.ID) } - for _, n := range neighbors { - fk := n.OwnerID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Card = n + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadParent(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + fk := nodes[i].ParentID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } + nodeids[fk] = append(nodeids[fk], nodes[i]) } - - if query := uq.withMetadata; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] + assign(nodes[i], n) } - query.Where(predicate.Metadata(func(s *sql.Selector) { - s.Where(sql.InValues(user.MetadataColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (uq *UserQuery) loadChildren(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.ID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Metadata = n + } + query.Where(predicate.User(func(s *sql.Selector) { + 
s.Where(sql.InValues(user.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ParentID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - if query := uq.withInfo; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) + return nil +} +func (uq *UserQuery) loadSpouse(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + fk := nodes[i].SpouseID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "spouse_id" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Info = []*Info{} + assign(nodes[i], n) } - query.Where(predicate.Info(func(s *sql.Selector) { - s.Where(sql.InValues(user.InfoColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (uq *UserQuery) loadCard(ctx context.Context, query *CardQuery, nodes []*User, init func(*User), assign func(*User, *Card)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.Where(predicate.Card(func(s *sql.Selector) { + s.Where(sql.InValues(user.CardColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.OwnerID + node, ok := 
nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) } - for _, n := range neighbors { - fk := n.ID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Info = append(node.Edges.Info, n) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadMetadata(ctx context.Context, query *MetadataQuery, nodes []*User, init func(*User), assign func(*User, *Metadata)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.Where(predicate.Metadata(func(s *sql.Selector) { + s.Where(sql.InValues(user.MetadataColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - if query := uq.withRentals; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Rentals = []*Rental{} + return nil +} +func (uq *UserQuery) loadInfo(ctx context.Context, query *InfoQuery, nodes []*User, init func(*User), assign func(*User, *Info)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(predicate.Rental(func(s *sql.Selector) { - s.Where(sql.InValues(user.RentalsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + query.Where(predicate.Info(func(s *sql.Selector) { + s.Where(sql.InValues(user.InfoColumn, fks...)) + })) + 
neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "id" returned %v for node %v`, fk, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Rentals = append(node.Edges.Rentals, n) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadRentals(ctx context.Context, query *RentalQuery, nodes []*User, init func(*User), assign func(*User, *Rental)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.Where(predicate.Rental(func(s *sql.Selector) { + s.Where(sql.InValues(user.RentalsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/friendship_query.go b/entc/integration/edgeschema/ent/friendship_query.go index 802775b683..ba6a6fc7ca 100644 --- a/entc/integration/edgeschema/ent/friendship_query.go +++ b/entc/integration/edgeschema/ent/friendship_query.go @@ -416,60 +416,72 @@ func (fq *FriendshipQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*F if len(nodes) == 0 { return nodes, nil } - if query := fq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Friendship) - for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } 
- nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := fq.loadUser(ctx, query, nodes, nil, + func(n *Friendship, e *User) { n.Edges.User = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := fq.withFriend; query != nil { + if err := fq.loadFriend(ctx, query, nodes, nil, + func(n *Friendship, e *User) { n.Edges.Friend = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := fq.withFriend; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Friendship) +func (fq *FriendshipQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Friendship, init func(*Friendship), assign func(*Friendship, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Friendship) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].FriendID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (fq *FriendshipQuery) loadFriend(ctx context.Context, query *UserQuery, nodes []*Friendship, init func(*Friendship), assign func(*Friendship, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := 
make(map[int][]*Friendship) + for i := range nodes { + fk := nodes[i].FriendID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "friend_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Friend = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "friend_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (fq *FriendshipQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/group_query.go b/entc/integration/edgeschema/ent/group_query.go index 98804d9c6e..f77f52acba 100644 --- a/entc/integration/edgeschema/ent/group_query.go +++ b/entc/integration/edgeschema/ent/group_query.go @@ -418,86 +418,104 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { + return nil, err } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[1])) - 
s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + } + if query := gq.withJoinedUsers; query != nil { + if err := gq.loadJoinedUsers(ctx, query, nodes, + func(n *Group) { n.Edges.JoinedUsers = []*UserGroup{} }, + func(n *Group, e *UserGroup) { n.Edges.JoinedUsers = append(n.Edges.JoinedUsers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := gq.withJoinedUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Group) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.JoinedUsers = []*UserGroup{} + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - query.Where(predicate.UserGroup(func(s *sql.Selector) { - s.Where(sql.InValues(group.JoinedUsersColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - fk := n.GroupID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.JoinedUsers = append(node.Edges.JoinedUsers, n) + } + return nil +} +func (gq *GroupQuery) loadJoinedUsers(ctx context.Context, query *UserGroupQuery, nodes []*Group, init func(*Group), assign func(*Group, *UserGroup)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := 
make(map[int]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.Where(predicate.UserGroup(func(s *sql.Selector) { + s.Where(sql.InValues(group.JoinedUsersColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.GroupID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/relationship_query.go b/entc/integration/edgeschema/ent/relationship_query.go index d48dad23b2..1d6b085a48 100644 --- a/entc/integration/edgeschema/ent/relationship_query.go +++ b/entc/integration/edgeschema/ent/relationship_query.go @@ -383,86 +383,104 @@ func (rq *RelationshipQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([] if len(nodes) == 0 { return nodes, nil } - if query := rq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Relationship) - for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := rq.loadUser(ctx, query, nodes, nil, + func(n *Relationship, e *User) { n.Edges.User = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := rq.withRelative; query != nil { + if err := rq.loadRelative(ctx, query, nodes, nil, + func(n *Relationship, e *User) { n.Edges.Relative = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := 
rq.withInfo; query != nil { + if err := rq.loadInfo(ctx, query, nodes, nil, + func(n *Relationship, e *RelationshipInfo) { n.Edges.Info = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := rq.withRelative; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Relationship) - for i := range nodes { - fk := nodes[i].RelativeID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) +func (rq *RelationshipQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Relationship, init func(*Relationship), assign func(*Relationship, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Relationship) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "relative_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Relative = n - } + for i := range nodes { + assign(nodes[i], n) } } - - if query := rq.withInfo; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Relationship) + return nil +} +func (rq *RelationshipQuery) loadRelative(ctx context.Context, query *UserQuery, nodes []*Relationship, init func(*Relationship), assign func(*Relationship, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Relationship) + for i := range nodes { + fk := nodes[i].RelativeID + if _, ok := nodeids[fk]; !ok { + 
ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "relative_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].InfoID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(relationshipinfo.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (rq *RelationshipQuery) loadInfo(ctx context.Context, query *RelationshipInfoQuery, nodes []*Relationship, init func(*Relationship), assign func(*Relationship, *RelationshipInfo)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Relationship) + for i := range nodes { + fk := nodes[i].InfoID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "info_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Info = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(relationshipinfo.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "info_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (rq *RelationshipQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/role_query.go b/entc/integration/edgeschema/ent/role_query.go index 5ec1b8f3c7..1f9395454a 100644 --- a/entc/integration/edgeschema/ent/role_query.go +++ b/entc/integration/edgeschema/ent/role_query.go @@ -418,86 +418,104 @@ func (rq 
*RoleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Role, e if len(nodes) == 0 { return nodes, nil } - if query := rq.withUser; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Role) - nids := make(map[int]map[*Role]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.User = []*User{} + if err := rq.loadUser(ctx, query, nodes, + func(n *Role) { n.Edges.User = []*User{} }, + func(n *Role, e *User) { n.Edges.User = append(n.Edges.User, e) }); err != nil { + return nil, err } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(role.UserTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(role.UserPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(role.UserPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(role.UserPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Role]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + } + if query := rq.withRolesUsers; query != nil { + if err := rq.loadRolesUsers(ctx, query, nodes, + func(n *Role) { n.Edges.RolesUsers = []*RoleUser{} }, + func(n *Role, e *RoleUser) { n.Edges.RolesUsers = append(n.Edges.RolesUsers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := 
nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "user" node returned %v`, n.ID) + } + return nodes, nil +} + +func (rq *RoleQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Role, init func(*Role), assign func(*Role, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Role) + nids := make(map[int]map[*Role]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(role.UserTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(role.UserPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(role.UserPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(role.UserPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.User = append(kn.Edges.User, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Role]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := rq.withRolesUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Role) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.RolesUsers = []*RoleUser{} + for _, n := range neighbors { + nodes, ok := nids[n.ID] + 
if !ok { + return fmt.Errorf(`unexpected "user" node returned %v`, n.ID) } - query.Where(predicate.RoleUser(func(s *sql.Selector) { - s.Where(sql.InValues(role.RolesUsersColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - fk := n.RoleID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "role_id" returned %v for node %v`, fk, n) - } - node.Edges.RolesUsers = append(node.Edges.RolesUsers, n) + } + return nil +} +func (rq *RoleQuery) loadRolesUsers(ctx context.Context, query *RoleUserQuery, nodes []*Role, init func(*Role), assign func(*Role, *RoleUser)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Role) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.Where(predicate.RoleUser(func(s *sql.Selector) { + s.Where(sql.InValues(role.RolesUsersColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.RoleID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "role_id" returned %v for node %v`, fk, n) + } + assign(node, n) + } + return nil } func (rq *RoleQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/roleuser_query.go b/entc/integration/edgeschema/ent/roleuser_query.go index 23400faf56..ca6c00890e 100644 --- a/entc/integration/edgeschema/ent/roleuser_query.go +++ b/entc/integration/edgeschema/ent/roleuser_query.go @@ -347,60 +347,72 @@ func (ruq *RoleUserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Ro if len(nodes) == 0 { return nodes, nil } - if query := ruq.withRole; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*RoleUser) - for i := range nodes { - fk := nodes[i].RoleID - if 
_, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(role.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := ruq.loadRole(ctx, query, nodes, nil, + func(n *RoleUser, e *Role) { n.Edges.Role = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "role_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Role = n - } + } + if query := ruq.withUser; query != nil { + if err := ruq.loadUser(ctx, query, nodes, nil, + func(n *RoleUser, e *User) { n.Edges.User = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := ruq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*RoleUser) +func (ruq *RoleUserQuery) loadRole(ctx context.Context, query *RoleQuery, nodes []*RoleUser, init func(*RoleUser), assign func(*RoleUser, *Role)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*RoleUser) + for i := range nodes { + fk := nodes[i].RoleID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(role.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "role_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (ruq *RoleUserQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*RoleUser, init func(*RoleUser), assign func(*RoleUser, *User)) error { + ids := make([]int, 0, len(nodes)) + 
nodeids := make(map[int][]*RoleUser) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (ruq *RoleUserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/tag_query.go b/entc/integration/edgeschema/ent/tag_query.go index 5fcbbbfb6e..386df6b549 100644 --- a/entc/integration/edgeschema/ent/tag_query.go +++ b/entc/integration/edgeschema/ent/tag_query.go @@ -418,86 +418,104 @@ func (tq *TagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Tag, err if len(nodes) == 0 { return nodes, nil } - if query := tq.withTweets; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Tag) - nids := make(map[int]map[*Tag]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Tweets = []*Tweet{} + if err := tq.loadTweets(ctx, query, nodes, + func(n *Tag) { n.Edges.Tweets = []*Tweet{} }, + func(n *Tag, e *Tweet) { n.Edges.Tweets = append(n.Edges.Tweets, e) }); err != nil { + return nil, err } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(tag.TweetsTable) - s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(tag.TweetsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(tag.TweetsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(tag.TweetsPrimaryKey[0])) - 
s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Tag]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + } + if query := tq.withTweetTags; query != nil { + if err := tq.loadTweetTags(ctx, query, nodes, + func(n *Tag) { n.Edges.TweetTags = []*TweetTag{} }, + func(n *Tag, e *TweetTag) { n.Edges.TweetTags = append(n.Edges.TweetTags, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "tweets" node returned %v`, n.ID) + } + return nodes, nil +} + +func (tq *TagQuery) loadTweets(ctx context.Context, query *TweetQuery, nodes []*Tag, init func(*Tag), assign func(*Tag, *Tweet)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Tag) + nids := make(map[int]map[*Tag]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(tag.TweetsTable) + s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(tag.TweetsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(tag.TweetsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(tag.TweetsPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Tweets = append(kn.Edges.Tweets, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Tag]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := tq.withTweetTags; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Tag) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.TweetTags = []*TweetTag{} + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "tweets" node returned %v`, n.ID) } - query.Where(predicate.TweetTag(func(s *sql.Selector) { - s.Where(sql.InValues(tag.TweetTagsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - fk := n.TagID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tag_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.TweetTags = append(node.Edges.TweetTags, n) + } + return nil +} +func (tq *TagQuery) loadTweetTags(ctx context.Context, query *TweetTagQuery, nodes []*Tag, init func(*Tag), assign func(*Tag, *TweetTag)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Tag) + for i := range nodes { 
+ fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.Where(predicate.TweetTag(func(s *sql.Selector) { + s.Where(sql.InValues(tag.TweetTagsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TagID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tag_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil } func (tq *TagQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/tweet_query.go b/entc/integration/edgeschema/ent/tweet_query.go index 0c790e0f7e..b64f55c20a 100644 --- a/entc/integration/edgeschema/ent/tweet_query.go +++ b/entc/integration/edgeschema/ent/tweet_query.go @@ -565,242 +565,296 @@ func (tq *TweetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Tweet, if len(nodes) == 0 { return nodes, nil } - if query := tq.withLikedUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Tweet) - nids := make(map[int]map[*Tweet]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.LikedUsers = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(tweet.LikedUsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(tweet.LikedUsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(tweet.LikedUsersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(tweet.LikedUsersPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Tweet]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := tq.loadLikedUsers(ctx, query, nodes, + func(n *Tweet) { n.Edges.LikedUsers = []*User{} }, + func(n *Tweet, e *User) { n.Edges.LikedUsers = append(n.Edges.LikedUsers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "liked_users" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.LikedUsers = append(kn.Edges.LikedUsers, n) - } - } } - if query := tq.withUser; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Tweet) - nids := make(map[int]map[*Tweet]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.User = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(tweet.UserTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(tweet.UserPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(tweet.UserPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(tweet.UserPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Tweet]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := tq.loadUser(ctx, query, nodes, + func(n *Tweet) { n.Edges.User = []*User{} }, + func(n *Tweet, e *User) { n.Edges.User = append(n.Edges.User, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "user" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.User = append(kn.Edges.User, n) - } + } + if query := tq.withTags; query != nil { + if err := tq.loadTags(ctx, query, nodes, + func(n *Tweet) { n.Edges.Tags = []*Tag{} }, + func(n *Tweet, e *Tag) { n.Edges.Tags = append(n.Edges.Tags, e) }); err != nil { + return nil, err + } + } + if query := tq.withLikes; query != nil { + if err := tq.loadLikes(ctx, query, nodes, + func(n *Tweet) { n.Edges.Likes = []*TweetLike{} }, + func(n *Tweet, e *TweetLike) { n.Edges.Likes = append(n.Edges.Likes, e) }); err != nil { + return nil, err + } + } + if query := tq.withTweetUser; query != nil { + if err := tq.loadTweetUser(ctx, query, nodes, + func(n *Tweet) { n.Edges.TweetUser = []*UserTweet{} }, + func(n *Tweet, e *UserTweet) { n.Edges.TweetUser = append(n.Edges.TweetUser, e) }); err != nil { + return nil, err + } + } + if query := tq.withTweetTags; 
query != nil { + if err := tq.loadTweetTags(ctx, query, nodes, + func(n *Tweet) { n.Edges.TweetTags = []*TweetTag{} }, + func(n *Tweet, e *TweetTag) { n.Edges.TweetTags = append(n.Edges.TweetTags, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := tq.withTags; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Tweet) - nids := make(map[int]map[*Tweet]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Tags = []*Tag{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(tweet.TagsTable) - s.Join(joinT).On(s.C(tag.FieldID), joinT.C(tweet.TagsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(tweet.TagsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(tweet.TagsPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil +func (tq *TweetQuery) loadLikedUsers(ctx context.Context, query *UserQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Tweet) + nids := make(map[int]map[*Tweet]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(tweet.LikedUsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(tweet.LikedUsersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(tweet.LikedUsersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(tweet.LikedUsersPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Tweet]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Tweet]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "tags" node returned %v`, n.ID) + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "liked_users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (tq *TweetQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Tweet) + nids := make(map[int]map[*Tweet]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := 
sql.Table(tweet.UserTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(tweet.UserPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(tweet.UserPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(tweet.UserPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Tags = append(kn.Edges.Tags, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Tweet]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := tq.withLikes; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Tweet) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Likes = []*TweetLike{} - } - query.Where(predicate.TweetLike(func(s *sql.Selector) { - s.Where(sql.InValues(tweet.LikesColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "user" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.TweetID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n) - } - node.Edges.Likes = append(node.Edges.Likes, n) + for kn := range nodes { + assign(kn, n) } } - - 
if query := tq.withTweetUser; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Tweet) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.TweetUser = []*UserTweet{} - } - query.Where(predicate.UserTweet(func(s *sql.Selector) { - s.Where(sql.InValues(tweet.TweetUserColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + return nil +} +func (tq *TweetQuery) loadTags(ctx context.Context, query *TagQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *Tag)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Tweet) + nids := make(map[int]map[*Tweet]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(tweet.TagsTable) + s.Join(joinT).On(s.C(tag.FieldID), joinT.C(tweet.TagsPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(tweet.TagsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(tweet.TagsPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]interface{}{new(sql.NullInt64)}, values...), nil } - for _, n := range neighbors { - fk := n.TweetID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n.ID) + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Tweet]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.TweetUser = append(node.Edges.TweetUser, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := tq.withTweetTags; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Tweet) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.TweetTags = []*TweetTag{} - } - query.Where(predicate.TweetTag(func(s *sql.Selector) { - s.Where(sql.InValues(tweet.TweetTagsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "tags" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.TweetID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.TweetTags = append(node.Edges.TweetTags, n) + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (tq *TweetQuery) loadLikes(ctx 
context.Context, query *TweetLikeQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *TweetLike)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Tweet) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.TweetLike(func(s *sql.Selector) { + s.Where(sql.InValues(tweet.LikesColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TweetID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n) + } + assign(node, n) + } + return nil +} +func (tq *TweetQuery) loadTweetUser(ctx context.Context, query *UserTweetQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *UserTweet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Tweet) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.UserTweet(func(s *sql.Selector) { + s.Where(sql.InValues(tweet.TweetUserColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TweetID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (tq *TweetQuery) loadTweetTags(ctx context.Context, query *TweetTagQuery, nodes []*Tweet, init func(*Tweet), assign func(*Tweet, *TweetTag)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Tweet) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.TweetTag(func(s *sql.Selector) { + s.Where(sql.InValues(tweet.TweetTagsColumn, fks...)) + 
})) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TweetID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil } func (tq *TweetQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/tweetlike_query.go b/entc/integration/edgeschema/ent/tweetlike_query.go index 387cfe3751..2f7b8c5c14 100644 --- a/entc/integration/edgeschema/ent/tweetlike_query.go +++ b/entc/integration/edgeschema/ent/tweetlike_query.go @@ -354,60 +354,72 @@ func (tlq *TweetLikeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*T if len(nodes) == 0 { return nodes, nil } - if query := tlq.withTweet; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*TweetLike) - for i := range nodes { - fk := nodes[i].TweetID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(tweet.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := tlq.loadTweet(ctx, query, nodes, nil, + func(n *TweetLike, e *Tweet) { n.Edges.Tweet = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tweet = n - } + } + if query := tlq.withUser; query != nil { + if err := tlq.loadUser(ctx, query, nodes, nil, + func(n *TweetLike, e *User) { n.Edges.User = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := tlq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*TweetLike) +func (tlq *TweetLikeQuery) loadTweet(ctx context.Context, query *TweetQuery, nodes []*TweetLike, init func(*TweetLike), assign func(*TweetLike, *Tweet)) error { + ids := make([]int, 0, 
len(nodes)) + nodeids := make(map[int][]*TweetLike) + for i := range nodes { + fk := nodes[i].TweetID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tweet.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (tlq *TweetLikeQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*TweetLike, init func(*TweetLike), assign func(*TweetLike, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*TweetLike) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (tlq *TweetLikeQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/tweettag_query.go b/entc/integration/edgeschema/ent/tweettag_query.go index e0db5da555..7ea97e335b 100644 --- a/entc/integration/edgeschema/ent/tweettag_query.go +++ 
b/entc/integration/edgeschema/ent/tweettag_query.go @@ -418,60 +418,72 @@ func (ttq *TweetTagQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Tw if len(nodes) == 0 { return nodes, nil } - if query := ttq.withTag; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*TweetTag) - for i := range nodes { - fk := nodes[i].TagID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(tag.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := ttq.loadTag(ctx, query, nodes, nil, + func(n *TweetTag, e *Tag) { n.Edges.Tag = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tag_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tag = n - } + } + if query := ttq.withTweet; query != nil { + if err := ttq.loadTweet(ctx, query, nodes, nil, + func(n *TweetTag, e *Tweet) { n.Edges.Tweet = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := ttq.withTweet; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*TweetTag) +func (ttq *TweetTagQuery) loadTag(ctx context.Context, query *TagQuery, nodes []*TweetTag, init func(*TweetTag), assign func(*TweetTag, *Tag)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*TweetTag) + for i := range nodes { + fk := nodes[i].TagID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tag.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tag_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].TweetID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = 
append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(tweet.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (ttq *TweetTagQuery) loadTweet(ctx context.Context, query *TweetQuery, nodes []*TweetTag, init func(*TweetTag), assign func(*TweetTag, *Tweet)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*TweetTag) + for i := range nodes { + fk := nodes[i].TweetID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tweet = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tweet.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (ttq *TweetTagQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/user_query.go b/entc/integration/edgeschema/ent/user_query.go index a348ec8eea..92902a3477 100644 --- a/entc/integration/edgeschema/ent/user_query.go +++ b/entc/integration/edgeschema/ent/user_query.go @@ -792,476 +792,584 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), 
joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) - } - } } - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - 
s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } - } } - if query := uq.withRelatives; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Relatives = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.RelativesTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.RelativesPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.RelativesPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.RelativesPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadRelatives(ctx, query, nodes, + func(n *User) { n.Edges.Relatives = []*User{} }, + func(n *User, e *User) { n.Edges.Relatives = append(n.Edges.Relatives, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "relatives" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Relatives = append(kn.Edges.Relatives, n) - } - } } - if query := uq.withLikedTweets; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.LikedTweets = []*Tweet{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.LikedTweetsTable) - s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(user.LikedTweetsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.LikedTweetsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.LikedTweetsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadLikedTweets(ctx, query, nodes, + func(n *User) { n.Edges.LikedTweets = []*Tweet{} }, + func(n *User, e *Tweet) { n.Edges.LikedTweets = append(n.Edges.LikedTweets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "liked_tweets" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.LikedTweets = append(kn.Edges.LikedTweets, n) - } + } + if query := uq.withTweets; query != nil { + if err := uq.loadTweets(ctx, query, nodes, + func(n *User) { n.Edges.Tweets = []*Tweet{} }, + func(n *User, e *Tweet) { n.Edges.Tweets = append(n.Edges.Tweets, e) }); err != nil { + return nil, err } } + if query := uq.withRoles; query != nil { + if err := uq.loadRoles(ctx, query, nodes, + func(n *User) { n.Edges.Roles = []*Role{} }, + func(n *User, e *Role) { n.Edges.Roles = append(n.Edges.Roles, e) }); err != nil { + return nil, err + } + } + if query := uq.withJoinedGroups; query != nil { + if err := uq.loadJoinedGroups(ctx, query, nodes, + func(n *User) { n.Edges.JoinedGroups = []*UserGroup{} }, + func(n *User, e *UserGroup) { n.Edges.JoinedGroups = append(n.Edges.JoinedGroups, e) }); err != nil { 
+ return nil, err + } + } + if query := uq.withFriendships; query != nil { + if err := uq.loadFriendships(ctx, query, nodes, + func(n *User) { n.Edges.Friendships = []*Friendship{} }, + func(n *User, e *Friendship) { n.Edges.Friendships = append(n.Edges.Friendships, e) }); err != nil { + return nil, err + } + } + if query := uq.withRelationship; query != nil { + if err := uq.loadRelationship(ctx, query, nodes, + func(n *User) { n.Edges.Relationship = []*Relationship{} }, + func(n *User, e *Relationship) { n.Edges.Relationship = append(n.Edges.Relationship, e) }); err != nil { + return nil, err + } + } + if query := uq.withLikes; query != nil { + if err := uq.loadLikes(ctx, query, nodes, + func(n *User) { n.Edges.Likes = []*TweetLike{} }, + func(n *User, e *TweetLike) { n.Edges.Likes = append(n.Edges.Likes, e) }); err != nil { + return nil, err + } + } + if query := uq.withUserTweets; query != nil { + if err := uq.loadUserTweets(ctx, query, nodes, + func(n *User) { n.Edges.UserTweets = []*UserTweet{} }, + func(n *User, e *UserTweet) { n.Edges.UserTweets = append(n.Edges.UserTweets, e) }); err != nil { + return nil, err + } + } + if query := uq.withRolesUsers; query != nil { + if err := uq.loadRolesUsers(ctx, query, nodes, + func(n *User) { n.Edges.RolesUsers = []*RoleUser{} }, + func(n *User, e *RoleUser) { n.Edges.RolesUsers = append(n.Edges.RolesUsers, e) }); err != nil { + return nil, err + } + } + return nodes, nil +} - if query := uq.withTweets; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Tweets = []*Tweet{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.TweetsTable) - s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(user.TweetsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.TweetsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - 
s.Select(joinT.C(user.TweetsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "tweets" node returned %v`, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := 
sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Tweets = append(kn.Edges.Tweets, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withRoles; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Roles = []*Role{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.RolesTable) - s.Join(joinT).On(s.C(role.FieldID), joinT.C(user.RolesPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.RolesPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.RolesPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadRelatives(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.RelativesTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.RelativesPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.RelativesPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.RelativesPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "relatives" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "roles" node returned %v`, n.ID) + } + return nil +} +func (uq *UserQuery) loadLikedTweets(ctx context.Context, query *TweetQuery, nodes []*User, init func(*User), assign func(*User, *Tweet)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := 
sql.Table(user.LikedTweetsTable) + s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(user.LikedTweetsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.LikedTweetsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.LikedTweetsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Roles = append(kn.Edges.Roles, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withJoinedGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.JoinedGroups = []*UserGroup{} - } - query.Where(predicate.UserGroup(func(s *sql.Selector) { - s.Where(sql.InValues(user.JoinedGroupsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "liked_tweets" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) + for kn := range nodes { + assign(kn, n) + } + 
} + return nil +} +func (uq *UserQuery) loadTweets(ctx context.Context, query *TweetQuery, nodes []*User, init func(*User), assign func(*User, *Tweet)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.TweetsTable) + s.Join(joinT).On(s.C(tweet.FieldID), joinT.C(user.TweetsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.TweetsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.TweetsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.JoinedGroups = append(node.Edges.JoinedGroups, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withFriendships; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Friendships = []*Friendship{} - } - query.Where(predicate.Friendship(func(s *sql.Selector) { - s.Where(sql.InValues(user.FriendshipsColumn, fks...)) - })) - neighbors, err := 
query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "tweets" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadRoles(ctx context.Context, query *RoleQuery, nodes []*User, init func(*User), assign func(*User, *Role)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.RolesTable) + s.Join(joinT).On(s.C(role.FieldID), joinT.C(user.RolesPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.RolesPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.RolesPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err + } + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.Friendships = append(node.Edges.Friendships, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withRelationship; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Relationship = []*Relationship{} - } - query.Where(predicate.Relationship(func(s *sql.Selector) { - s.Where(sql.InValues(user.RelationshipColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "roles" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) - } - node.Edges.Relationship = append(node.Edges.Relationship, n) + for kn := range nodes { + assign(kn, n) } } - - if query := uq.withLikes; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Likes = 
[]*TweetLike{} - } - query.Where(predicate.TweetLike(func(s *sql.Selector) { - s.Where(sql.InValues(user.LikesColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + return nil +} +func (uq *UserQuery) loadJoinedGroups(ctx context.Context, query *UserGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserGroup)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.UserGroup(func(s *sql.Selector) { + s.Where(sql.InValues(user.JoinedGroupsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) - } - node.Edges.Likes = append(node.Edges.Likes, n) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadFriendships(ctx context.Context, query *FriendshipQuery, nodes []*User, init func(*User), assign func(*User, *Friendship)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.Friendship(func(s *sql.Selector) { + s.Where(sql.InValues(user.FriendshipsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - if 
query := uq.withUserTweets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.UserTweets = []*UserTweet{} - } - query.Where(predicate.UserTweet(func(s *sql.Selector) { - s.Where(sql.InValues(user.UserTweetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + return nil +} +func (uq *UserQuery) loadRelationship(ctx context.Context, query *RelationshipQuery, nodes []*User, init func(*User), assign func(*User, *Relationship)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.Relationship(func(s *sql.Selector) { + s.Where(sql.InValues(user.RelationshipColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.UserTweets = append(node.Edges.UserTweets, n) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadLikes(ctx context.Context, query *TweetLikeQuery, nodes []*User, init func(*User), assign func(*User, *TweetLike)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.TweetLike(func(s *sql.Selector) { + s.Where(sql.InValues(user.LikesColumn, fks...)) + })) + neighbors, err := 
query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) } + assign(node, n) } - - if query := uq.withRolesUsers; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.RolesUsers = []*RoleUser{} - } - query.Where(predicate.RoleUser(func(s *sql.Selector) { - s.Where(sql.InValues(user.RolesUsersColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + return nil +} +func (uq *UserQuery) loadUserTweets(ctx context.Context, query *UserTweetQuery, nodes []*User, init func(*User), assign func(*User, *UserTweet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.UserTweet(func(s *sql.Selector) { + s.Where(sql.InValues(user.UserTweetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n.ID) } - for _, n := range neighbors { - fk := n.UserID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) - } - node.Edges.RolesUsers = append(node.Edges.RolesUsers, n) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadRolesUsers(ctx context.Context, query *RoleUserQuery, nodes []*User, init func(*User), assign func(*User, *RoleUser)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = 
append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.Where(predicate.RoleUser(func(s *sql.Selector) { + s.Where(sql.InValues(user.RolesUsersColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v for node %v`, fk, n) } + assign(node, n) } - - return nodes, nil + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/usergroup_query.go b/entc/integration/edgeschema/ent/usergroup_query.go index 212b374bce..c650a47662 100644 --- a/entc/integration/edgeschema/ent/usergroup_query.go +++ b/entc/integration/edgeschema/ent/usergroup_query.go @@ -417,60 +417,72 @@ func (ugq *UserGroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*U if len(nodes) == 0 { return nodes, nil } - if query := ugq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*UserGroup) - for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := ugq.loadUser(ctx, query, nodes, nil, + func(n *UserGroup, e *User) { n.Edges.User = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := ugq.withGroup; query != nil { + if err := ugq.loadGroup(ctx, query, nodes, nil, + func(n *UserGroup, e *Group) { n.Edges.Group = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := ugq.withGroup; query != nil { - ids := make([]int, 0, len(nodes)) 
- nodeids := make(map[int][]*UserGroup) +func (ugq *UserGroupQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserGroup, init func(*UserGroup), assign func(*UserGroup, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*UserGroup) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].GroupID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(group.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (ugq *UserGroupQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*UserGroup, init func(*UserGroup), assign func(*UserGroup, *Group)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*UserGroup) + for i := range nodes { + fk := nodes[i].GroupID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Group = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(group.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (ugq *UserGroupQuery) sqlCount(ctx 
context.Context) (int, error) { diff --git a/entc/integration/edgeschema/ent/usertweet_query.go b/entc/integration/edgeschema/ent/usertweet_query.go index 07d1bdb0b0..0efb0a1062 100644 --- a/entc/integration/edgeschema/ent/usertweet_query.go +++ b/entc/integration/edgeschema/ent/usertweet_query.go @@ -417,60 +417,72 @@ func (utq *UserTweetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*U if len(nodes) == 0 { return nodes, nil } - if query := utq.withUser; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*UserTweet) - for i := range nodes { - fk := nodes[i].UserID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := utq.loadUser(ctx, query, nodes, nil, + func(n *UserTweet, e *User) { n.Edges.User = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.User = n - } + } + if query := utq.withTweet; query != nil { + if err := utq.loadTweet(ctx, query, nodes, nil, + func(n *UserTweet, e *Tweet) { n.Edges.Tweet = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := utq.withTweet; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*UserTweet) +func (utq *UserTweetQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*UserTweet, init func(*UserTweet), assign func(*UserTweet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*UserTweet) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range 
neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } for i := range nodes { - fk := nodes[i].TweetID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(tweet.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (utq *UserTweetQuery) loadTweet(ctx context.Context, query *TweetQuery, nodes []*UserTweet, init func(*UserTweet), assign func(*UserTweet, *Tweet)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*UserTweet) + for i := range nodes { + fk := nodes[i].TweetID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tweet = n - } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tweet.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tweet_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) } } - - return nodes, nil + return nil } func (utq *UserTweetQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/card_query.go b/entc/integration/ent/card_query.go index ef728ce05d..b406deb021 100644 --- a/entc/integration/ent/card_query.go +++ b/entc/integration/ent/card_query.go @@ -431,90 +431,105 @@ func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, e if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Card) - for i := range nodes { - if nodes[i].user_card == nil { - continue - } - fk 
:= *nodes[i].user_card - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Card, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_card" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } + } + if query := cq.withSpec; query != nil { + if err := cq.loadSpec(ctx, query, nodes, + func(n *Card) { n.Edges.Spec = []*Spec{} }, + func(n *Card, e *Spec) { n.Edges.Spec = append(n.Edges.Spec, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := cq.withSpec; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Card) - nids := make(map[int]map[*Card]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Spec = []*Spec{} +func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Card, init func(*Card), assign func(*Card, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Card) + for i := range nodes { + if nodes[i].user_card == nil { + continue } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(card.SpecTable) - s.Join(joinT).On(s.C(spec.FieldID), joinT.C(card.SpecPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(card.SpecPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(card.SpecPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Card]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + fk := *nodes[i].user_card + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "spec" node returned %v`, n.ID) + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_card" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (cq *CardQuery) loadSpec(ctx context.Context, query *SpecQuery, nodes []*Card, init func(*Card), assign func(*Card, *Spec)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Card) + nids := make(map[int]map[*Card]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(card.SpecTable) + s.Join(joinT).On(s.C(spec.FieldID), joinT.C(card.SpecPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(card.SpecPrimaryKey[1]), edgeIDs...)) + columns 
:= s.SelectedColumns() + s.Select(joinT.C(card.SpecPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Spec = append(kn.Edges.Spec, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Card]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "spec" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/file_query.go b/entc/integration/ent/file_query.go index 0137acc37e..9c3bfdf6d0 100644 --- a/entc/integration/ent/file_query.go +++ b/entc/integration/ent/file_query.go @@ -468,95 +468,116 @@ func (fq *FileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*File, e if len(nodes) == 0 { return nodes, nil } - if query := fq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*File) - for i := range nodes { - if nodes[i].user_files == nil { - continue - } - fk := *nodes[i].user_files - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := fq.loadOwner(ctx, query, nodes, nil, + func(n *File, e *User) { 
n.Edges.Owner = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := fq.withType; query != nil { + if err := fq.loadType(ctx, query, nodes, nil, + func(n *File, e *FileType) { n.Edges.Type = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_files" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } + } + if query := fq.withField; query != nil { + if err := fq.loadField(ctx, query, nodes, + func(n *File) { n.Edges.Field = []*FieldType{} }, + func(n *File, e *FieldType) { n.Edges.Field = append(n.Edges.Field, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := fq.withType; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*File) +func (fq *FileQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*File, init func(*File), assign func(*File, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + if nodes[i].user_files == nil { + continue + } + fk := *nodes[i].user_files + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_files" returned %v`, n.ID) + } for i := range nodes { - if nodes[i].file_type_files == nil { - continue - } - fk := *nodes[i].file_type_files - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(filetype.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (fq *FileQuery) 
loadType(ctx context.Context, query *FileTypeQuery, nodes []*File, init func(*File), assign func(*File, *FileType)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + if nodes[i].file_type_files == nil { + continue } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "file_type_files" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Type = n - } + fk := *nodes[i].file_type_files + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } + nodeids[fk] = append(nodeids[fk], nodes[i]) } - - if query := fq.withField; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*File) + query.Where(filetype.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_type_files" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Field = []*FieldType{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.FieldType(func(s *sql.Selector) { - s.Where(sql.InValues(file.FieldColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (fq *FileQuery) loadField(ctx context.Context, query *FieldTypeQuery, nodes []*File, init func(*File), assign func(*File, *FieldType)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.file_field - if fk == nil { - return nil, fmt.Errorf(`foreign-key "file_field" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key 
"file_field" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Field = append(node.Edges.Field, n) + } + query.withFKs = true + query.Where(predicate.FieldType(func(s *sql.Selector) { + s.Where(sql.InValues(file.FieldColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.file_field + if fk == nil { + return fmt.Errorf(`foreign-key "file_field" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_field" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - return nodes, nil + return nil } func (fq *FileQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/filetype_query.go b/entc/integration/ent/filetype_query.go index 88074d5082..7aa0084a4c 100644 --- a/entc/integration/ent/filetype_query.go +++ b/entc/integration/ent/filetype_query.go @@ -386,39 +386,48 @@ func (ftq *FileTypeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Fi if len(nodes) == 0 { return nodes, nil } - if query := ftq.withFiles; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*FileType) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Files = []*File{} - } - query.withFKs = true - query.Where(predicate.File(func(s *sql.Selector) { - s.Where(sql.InValues(filetype.FilesColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := ftq.loadFiles(ctx, query, nodes, + func(n *FileType) { n.Edges.Files = []*File{} }, + func(n *FileType, e *File) { n.Edges.Files = append(n.Edges.Files, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.file_type_files - if fk == nil { - return nil, fmt.Errorf(`foreign-key "file_type_files" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "file_type_files" 
returned %v for node %v`, *fk, n.ID) - } - node.Edges.Files = append(node.Edges.Files, n) - } } - return nodes, nil } +func (ftq *FileTypeQuery) loadFiles(ctx context.Context, query *FileQuery, nodes []*FileType, init func(*FileType), assign func(*FileType, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*FileType) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(filetype.FilesColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.file_type_files + if fk == nil { + return fmt.Errorf(`foreign-key "file_type_files" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "file_type_files" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (ftq *FileTypeQuery) sqlCount(ctx context.Context) (int, error) { _spec := ftq.querySpec() if len(ftq.modifiers) > 0 { diff --git a/entc/integration/ent/group_query.go b/entc/integration/ent/group_query.go index 5f0d86814a..69a86d26f2 100644 --- a/entc/integration/ent/group_query.go +++ b/entc/integration/ent/group_query.go @@ -504,148 +504,181 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withFiles; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Group) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Files = []*File{} - } - query.withFKs = true - query.Where(predicate.File(func(s *sql.Selector) { - s.Where(sql.InValues(group.FilesColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := gq.loadFiles(ctx, query, nodes, + 
func(n *Group) { n.Edges.Files = []*File{} }, + func(n *Group, e *File) { n.Edges.Files = append(n.Edges.Files, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.group_files - if fk == nil { - return nil, fmt.Errorf(`foreign-key "group_files" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_files" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Files = append(node.Edges.Files, n) - } } - if query := gq.withBlocked; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Group) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Blocked = []*User{} - } - query.withFKs = true - query.Where(predicate.User(func(s *sql.Selector) { - s.Where(sql.InValues(group.BlockedColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := gq.loadBlocked(ctx, query, nodes, + func(n *Group) { n.Edges.Blocked = []*User{} }, + func(n *Group, e *User) { n.Edges.Blocked = append(n.Edges.Blocked, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.group_blocked - if fk == nil { - return nil, fmt.Errorf(`foreign-key "group_blocked" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_blocked" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Blocked = append(node.Edges.Blocked, n) - } } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), 
edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) - } + } + if query := gq.withInfo; query != nil { + if err := gq.loadInfo(ctx, query, nodes, nil, + func(n *Group, e *GroupInfo) { n.Edges.Info = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := gq.withInfo; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Group) - for i := range nodes { - if nodes[i].group_info == nil { - continue - } - fk := *nodes[i].group_info - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) +func (gq *GroupQuery) loadFiles(ctx context.Context, query *FileQuery, nodes []*Group, init func(*Group), assign func(*Group, *File)) error { + fks := 
make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(groupinfo.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + query.withFKs = true + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(group.FilesColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.group_files + if fk == nil { + return fmt.Errorf(`foreign-key "group_files" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_files" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (gq *GroupQuery) loadBlocked(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Group) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(group.BlockedColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.group_blocked + if fk == nil { + return fmt.Errorf(`foreign-key "group_blocked" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_blocked" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) 
+ for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_info" returned %v`, n.ID) + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for i := range nodes { - nodes[i].Edges.Info = n + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (gq *GroupQuery) loadInfo(ctx context.Context, query *GroupInfoQuery, nodes []*Group, init func(*Group), assign func(*Group, *GroupInfo)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Group) + for i := range nodes { + if nodes[i].group_info == nil { + continue + } + fk := *nodes[i].group_info + if 
_, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(groupinfo.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_info" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/groupinfo_query.go b/entc/integration/ent/groupinfo_query.go index b0c06cdd89..7ea26b19fe 100644 --- a/entc/integration/ent/groupinfo_query.go +++ b/entc/integration/ent/groupinfo_query.go @@ -386,39 +386,48 @@ func (giq *GroupInfoQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*G if len(nodes) == 0 { return nodes, nil } - if query := giq.withGroups; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*GroupInfo) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Groups = []*Group{} - } - query.withFKs = true - query.Where(predicate.Group(func(s *sql.Selector) { - s.Where(sql.InValues(groupinfo.GroupsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := giq.loadGroups(ctx, query, nodes, + func(n *GroupInfo) { n.Edges.Groups = []*Group{} }, + func(n *GroupInfo, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.group_info - if fk == nil { - return nil, fmt.Errorf(`foreign-key "group_info" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_info" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Groups = append(node.Edges.Groups, n) - } } - return nodes, nil } +func (giq *GroupInfoQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes 
[]*GroupInfo, init func(*GroupInfo), assign func(*GroupInfo, *Group)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*GroupInfo) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Group(func(s *sql.Selector) { + s.Where(sql.InValues(groupinfo.GroupsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.group_info + if fk == nil { + return fmt.Errorf(`foreign-key "group_info" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_info" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (giq *GroupInfoQuery) sqlCount(ctx context.Context) (int, error) { _spec := giq.querySpec() if len(giq.modifiers) > 0 { diff --git a/entc/integration/ent/node_query.go b/entc/integration/ent/node_query.go index fde335acb0..474a00b0b9 100644 --- a/entc/integration/ent/node_query.go +++ b/entc/integration/ent/node_query.go @@ -429,65 +429,77 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e if len(nodes) == 0 { return nodes, nil } - if query := nq.withPrev; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Node) - for i := range nodes { - if nodes[i].node_next == nil { - continue - } - fk := *nodes[i].node_next - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(node.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := nq.loadPrev(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Prev = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_next" returned %v`, n.ID) - } - 
for i := range nodes { - nodes[i].Edges.Prev = n - } + } + if query := nq.withNext; query != nil { + if err := nq.loadNext(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Next = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := nq.withNext; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Node) +func (nq *NodeQuery) loadPrev(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Node) + for i := range nodes { + if nodes[i].node_next == nil { + continue + } + fk := *nodes[i].node_next + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(node.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_next" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Node(func(s *sql.Selector) { - s.Where(sql.InValues(node.NextColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (nq *NodeQuery) loadNext(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Node(func(s *sql.Selector) { + s.Where(sql.InValues(node.NextColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.node_next + if fk == nil { + return fmt.Errorf(`foreign-key 
"node_next" is nil for node %v`, n.ID) } - for _, n := range neighbors { - fk := n.node_next - if fk == nil { - return nil, fmt.Errorf(`foreign-key "node_next" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_next" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Next = n + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_next" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - return nodes, nil + return nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/pet_query.go b/entc/integration/ent/pet_query.go index 3f3ec41c64..a0fcd04e9b 100644 --- a/entc/integration/ent/pet_query.go +++ b/entc/integration/ent/pet_query.go @@ -429,66 +429,78 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withTeam; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - if nodes[i].user_team == nil { - continue - } - fk := *nodes[i].user_team - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadTeam(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Team = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_team" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Team = n - } + } + if query := pq.withOwner; query != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := 
make(map[int][]*Pet) +func (pq *PetQuery) loadTeam(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if nodes[i].user_team == nil { + continue + } + fk := *nodes[i].user_team + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_team" returned %v`, n.ID) + } for i := range nodes { - if nodes[i].user_pets == nil { - continue - } - fk := *nodes[i].user_pets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if nodes[i].user_pets == nil { + continue } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } + fk := *nodes[i].user_pets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } + nodeids[fk] = append(nodeids[fk], nodes[i]) } - - return nodes, nil + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + 
} + return nil } func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/spec_query.go b/entc/integration/ent/spec_query.go index 752c0f461d..e84c7caefa 100644 --- a/entc/integration/ent/spec_query.go +++ b/entc/integration/ent/spec_query.go @@ -362,61 +362,70 @@ func (sq *SpecQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Spec, e if len(nodes) == 0 { return nodes, nil } - if query := sq.withCard; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Spec) - nids := make(map[int]map[*Spec]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Card = []*Card{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(spec.CardTable) - s.Join(joinT).On(s.C(card.FieldID), joinT.C(spec.CardPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(spec.CardPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(spec.CardPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Spec]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := sq.loadCard(ctx, query, nodes, + func(n *Spec) { n.Edges.Card = []*Card{} }, + func(n *Spec, e *Card) { n.Edges.Card = append(n.Edges.Card, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "card" node returned %v`, n.ID) + } + return nodes, nil +} + +func (sq *SpecQuery) loadCard(ctx context.Context, query *CardQuery, nodes []*Spec, init func(*Spec), assign func(*Spec, *Card)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Spec) + nids := make(map[int]map[*Spec]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(spec.CardTable) + s.Join(joinT).On(s.C(card.FieldID), joinT.C(spec.CardPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(spec.CardPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(spec.CardPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Card = append(kn.Edges.Card, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Spec]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "card" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (sq *SpecQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/ent/user_query.go b/entc/integration/ent/user_query.go index 1c856a1f93..d08b514303 100644 --- a/entc/integration/ent/user_query.go +++ b/entc/integration/ent/user_query.go @@ -757,421 +757,508 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withCard; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - } - query.withFKs = true - query.Where(predicate.Card(func(s *sql.Selector) { - s.Where(sql.InValues(user.CardColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadCard(ctx, query, nodes, nil, + func(n *User, e *Card) { n.Edges.Card = e }); err 
!= nil { return nil, err } - for _, n := range neighbors { - fk := n.user_card - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_card" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_card" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Card = n - } } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) - } } - if query := uq.withFiles; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Files = []*File{} - } - query.withFKs = true - query.Where(predicate.File(func(s *sql.Selector) { - s.Where(sql.InValues(user.FilesColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadFiles(ctx, query, nodes, + func(n *User) { n.Edges.Files = []*File{} }, + func(n *User, e *File) { n.Edges.Files = append(n.Edges.Files, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_files - if fk 
== nil { - return nil, fmt.Errorf(`foreign-key "user_files" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_files" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Files = append(node.Edges.Files, n) - } } - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node 
returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) - } - } } - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } - } } - if query := uq.withFollowers; query != nil { - edgeids := 
make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Followers = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowersPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFollowers(ctx, query, nodes, + func(n *User) { n.Edges.Followers = []*User{} }, + func(n *User, e *User) { n.Edges.Followers = append(n.Edges.Followers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Followers = append(kn.Edges.Followers, n) - } - } } - if query := uq.withFollowing; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] 
= node.ID - byid[node.ID] = node - node.Edges.Following = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowingTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowingPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFollowing(ctx, query, nodes, + func(n *User) { n.Edges.Following = []*User{} }, + func(n *User, e *User) { n.Edges.Following = append(n.Edges.Following, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Following = append(kn.Edges.Following, n) - } - } } - if query := uq.withTeam; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.TeamColumn, fks...)) - })) - neighbors, 
err := query.All(ctx) - if err != nil { + if err := uq.loadTeam(ctx, query, nodes, nil, + func(n *User, e *Pet) { n.Edges.Team = e }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_team - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_team" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_team" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Team = n - } } - if query := uq.withSpouse; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_spouse == nil { - continue - } - fk := *nodes[i].user_spouse - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + if err := uq.loadSpouse(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Spouse = e }); err != nil { + return nil, err } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withChildren; query != nil { + if err := uq.loadChildren(ctx, query, nodes, + func(n *User) { n.Edges.Children = []*User{} }, + func(n *User, e *User) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) + } + if query := uq.withParent; query != nil { + if err := uq.loadParent(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Parent = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (uq *UserQuery) loadCard(ctx context.Context, query *CardQuery, nodes []*User, init func(*User), assign func(*User, *Card)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + 
query.Where(predicate.Card(func(s *sql.Selector) { + s.Where(sql.InValues(user.CardColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_card + if fk == nil { + return fmt.Errorf(`foreign-key "user_card" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_card" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_pets + if fk == nil { + return fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadFiles(ctx context.Context, query *FileQuery, nodes []*User, init func(*User), assign func(*User, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(user.FilesColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_files + if fk == nil { + return 
fmt.Errorf(`foreign-key "user_files" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_files" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for i := range nodes { - nodes[i].Edges.Spouse = n + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] 
= nodes[i] - nodes[i].Edges.Children = []*User{} - } - query.withFKs = true - query.Where(predicate.User(func(s *sql.Selector) { - s.Where(sql.InValues(user.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.user_parent - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_parent" is nil for node %v`, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_parent" returned %v for node %v`, *fk, n.ID) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.Children = append(node.Edges.Children, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_parent == nil { - continue + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadFollowers(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeIDs...)) + columns := 
s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - fk := *nodes[i].user_parent - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) + nids[inValue][byID[outValue]] = struct{}{} + return nil } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_parent" returned %v`, n.ID) + } + return nil +} +func (uq *UserQuery) loadFollowing(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), 
joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for i := range nodes { - nodes[i].Edges.Parent = n + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "following" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadTeam(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.TeamColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_team + if fk == nil { + return fmt.Errorf(`foreign-key "user_team" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return 
fmt.Errorf(`unexpected foreign-key "user_team" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadSpouse(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_spouse == nil { + continue + } + fk := *nodes[i].user_spouse + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (uq *UserQuery) loadChildren(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(user.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_parent + if fk == nil { + return fmt.Errorf(`foreign-key "user_parent" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_parent" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadParent(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes 
{ + if nodes[i].user_parent == nil { + continue + } + fk := *nodes[i].user_parent + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_parent" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/hooks/ent/card_query.go b/entc/integration/hooks/ent/card_query.go index 94851c5aa1..d1fd68b066 100644 --- a/entc/integration/hooks/ent/card_query.go +++ b/entc/integration/hooks/ent/card_query.go @@ -388,39 +388,45 @@ func (cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, e if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Card) - for i := range nodes { - if nodes[i].user_cards == nil { - continue - } - fk := *nodes[i].user_cards - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Card, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_cards" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Card, init func(*Card), assign func(*Card, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Card) + for i := range nodes { + if 
nodes[i].user_cards == nil { + continue + } + fk := *nodes[i].user_cards + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_cards" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/hooks/ent/user_query.go b/entc/integration/hooks/ent/user_query.go index 74d5d2f8fb..2a719617bb 100644 --- a/entc/integration/hooks/ent/user_query.go +++ b/entc/integration/hooks/ent/user_query.go @@ -461,119 +461,143 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withCards; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Cards = []*Card{} + if err := uq.loadCards(ctx, query, nodes, + func(n *User) { n.Edges.Cards = []*Card{} }, + func(n *User, e *Card) { n.Edges.Cards = append(n.Edges.Cards, e) }); err != nil { + return nil, err } - query.withFKs = true - query.Where(predicate.Card(func(s *sql.Selector) { - s.Where(sql.InValues(user.CardsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withFriends; query != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_cards - if fk == nil { - return nil, fmt.Errorf(`foreign-key 
"user_cards" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_cards" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Cards = append(node.Edges.Cards, n) + } + if query := uq.withBestFriend; query != nil { + if err := uq.loadBestFriend(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.BestFriend = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} +func (uq *UserQuery) loadCards(ctx context.Context, query *CardQuery, nodes []*User, init func(*User), assign func(*User, *Card)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + } + query.withFKs = true + query.Where(predicate.Card(func(s *sql.Selector) { + s.Where(sql.InValues(user.CardsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_cards + if fk == nil { + return fmt.Errorf(`foreign-key "user_cards" is nil for node %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_cards" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - if query := uq.withBestFriend; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_best_friend == nil { - continue + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := 
make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - fk := *nodes[i].user_best_friend - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) + nids[inValue][byID[outValue]] = struct{}{} + return nil } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_best_friend" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.BestFriend = n - } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (uq *UserQuery) 
loadBestFriend(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_best_friend == nil { + continue + } + fk := *nodes[i].user_best_friend + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_best_friend" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/idtype/ent/user_query.go b/entc/integration/idtype/ent/user_query.go index 28e3b01d5d..9e8b1625a8 100644 --- a/entc/integration/idtype/ent/user_query.go +++ b/entc/integration/idtype/ent/user_query.go @@ -460,143 +460,167 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withSpouse; query != nil { - ids := make([]uint64, 0, len(nodes)) - nodeids := make(map[uint64][]*User) - for i := range nodes { - if nodes[i].user_spouse == nil { - continue - } - fk := *nodes[i].user_spouse - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadSpouse(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Spouse = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Spouse = n - } - } } - if query := 
uq.withFollowers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[uint64]*User) - nids := make(map[uint64]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Followers = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowersPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := uint64(values[0].(*sql.NullInt64).Int64) - inValue := uint64(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFollowers(ctx, query, nodes, + func(n *User) { n.Edges.Followers = []*User{} }, + func(n *User, e *User) { n.Edges.Followers = append(n.Edges.Followers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Followers = append(kn.Edges.Followers, n) - } + } + if query := uq.withFollowing; query != nil { + if err := uq.loadFollowing(ctx, query, nodes, + func(n *User) { n.Edges.Following = []*User{} }, + 
func(n *User, e *User) { n.Edges.Following = append(n.Edges.Following, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withFollowing; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[uint64]*User) - nids := make(map[uint64]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Following = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowingTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowingPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil +func (uq *UserQuery) loadSpouse(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]uint64, 0, len(nodes)) + nodeids := make(map[uint64][]*User) + for i := range nodes { + if nodes[i].user_spouse == nil { + continue + } + fk := *nodes[i].user_spouse + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (uq *UserQuery) loadFollowers(ctx context.Context, query *UserQuery, nodes []*User, 
init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[uint64]*User) + nids := make(map[uint64]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := uint64(values[0].(*sql.NullInt64).Int64) - inValue := uint64(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := uint64(values[0].(*sql.NullInt64).Int64) + inValue := uint64(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) + } + for 
kn := range nodes { + assign(kn, n) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) + } + return nil +} +func (uq *UserQuery) loadFollowing(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[uint64]*User) + nids := make(map[uint64]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Following = append(kn.Edges.Following, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := uint64(values[0].(*sql.NullInt64).Int64) + inValue := uint64(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "following" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } } - - 
return nodes, nil + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/migrate/entv1/car_query.go b/entc/integration/migrate/entv1/car_query.go index cd25883299..1adfd58c0a 100644 --- a/entc/integration/migrate/entv1/car_query.go +++ b/entc/integration/migrate/entv1/car_query.go @@ -364,39 +364,45 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Car) - for i := range nodes { - if nodes[i].user_car == nil { - continue - } - fk := *nodes[i].user_car - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Car, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_car" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Car, init func(*Car), assign func(*Car, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Car) + for i := range nodes { + if nodes[i].user_car == nil { + continue + } + fk := *nodes[i].user_car + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_car" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} 
+ func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/migrate/entv1/user_query.go b/entc/integration/migrate/entv1/user_query.go index 82e00722b1..062d1a7146 100644 --- a/entc/integration/migrate/entv1/user_query.go +++ b/entc/integration/migrate/entv1/user_query.go @@ -497,123 +497,150 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_children == nil { - continue - } - fk := *nodes[i].user_children - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadParent(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_children" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } - } } - if query := uq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*User{} - } - query.withFKs = true - query.Where(predicate.User(func(s *sql.Selector) { - s.Where(sql.InValues(user.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadChildren(ctx, query, nodes, + func(n *User) { n.Edges.Children = []*User{} }, + func(n *User, e *User) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_children - if 
fk == nil { - return nil, fmt.Errorf(`foreign-key "user_children" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_children" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + if query := uq.withSpouse; query != nil { + if err := uq.loadSpouse(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Spouse = e }); err != nil { + return nil, err } } + if query := uq.withCar; query != nil { + if err := uq.loadCar(ctx, query, nodes, nil, + func(n *User, e *Car) { n.Edges.Car = e }); err != nil { + return nil, err + } + } + return nodes, nil +} - if query := uq.withSpouse; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) +func (uq *UserQuery) loadParent(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_children == nil { + continue + } + fk := *nodes[i].user_children + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_children" returned %v`, n.ID) + } for i := range nodes { - if nodes[i].user_spouse == nil { - continue - } - fk := *nodes[i].user_spouse - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) + assign(nodes[i], n) } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (uq *UserQuery) loadChildren(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + fks := 
make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Spouse = n - } + } + query.withFKs = true + query.Where(predicate.User(func(s *sql.Selector) { + s.Where(sql.InValues(user.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_children + if fk == nil { + return fmt.Errorf(`foreign-key "user_children" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_children" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - if query := uq.withCar; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) + return nil +} +func (uq *UserQuery) loadSpouse(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_spouse == nil { + continue + } + fk := *nodes[i].user_spouse + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - } - query.withFKs = true - query.Where(predicate.Car(func(s *sql.Selector) { - s.Where(sql.InValues(user.CarColumn, fks...)) - })) - neighbors, err := 
query.All(ctx) - if err != nil { - return nil, err + assign(nodes[i], n) } - for _, n := range neighbors { - fk := n.user_car - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_car" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_car" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Car = n + } + return nil +} +func (uq *UserQuery) loadCar(ctx context.Context, query *CarQuery, nodes []*User, init func(*User), assign func(*User, *Car)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Car(func(s *sql.Selector) { + s.Where(sql.InValues(user.CarColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_car + if fk == nil { + return fmt.Errorf(`foreign-key "user_car" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_car" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - return nodes, nil + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/migrate/entv2/car_query.go b/entc/integration/migrate/entv2/car_query.go index 4e156ec264..31d1ecc406 100644 --- a/entc/integration/migrate/entv2/car_query.go +++ b/entc/integration/migrate/entv2/car_query.go @@ -388,39 +388,45 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Car) - for i := range nodes { - if nodes[i].user_car == nil { - continue - } - fk := *nodes[i].user_car - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = 
append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Car, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_car" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Car, init func(*Car), assign func(*Car, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Car) + for i := range nodes { + if nodes[i].user_car == nil { + continue + } + fk := *nodes[i].user_car + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_car" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/entc/integration/migrate/entv2/pet_query.go b/entc/integration/migrate/entv2/pet_query.go index df24fc4fe4..813dd3783d 100644 --- a/entc/integration/migrate/entv2/pet_query.go +++ b/entc/integration/migrate/entv2/pet_query.go @@ -388,39 +388,45 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - if nodes[i].owner_id == nil { - continue - } - fk := *nodes[i].owner_id - if _, ok := nodeids[fk]; !ok { - ids = append(ids, 
fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if nodes[i].owner_id == nil { + continue + } + fk := *nodes[i].owner_id + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Columns = pq.fields diff --git a/entc/integration/migrate/entv2/user_query.go b/entc/integration/migrate/entv2/user_query.go index 074a73ec34..14aa6bf92e 100644 --- a/entc/integration/migrate/entv2/user_query.go +++ b/entc/integration/migrate/entv2/user_query.go @@ -454,118 +454,142 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withCar; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Car = 
[]*Car{} + if err := uq.loadCar(ctx, query, nodes, + func(n *User) { n.Edges.Car = []*Car{} }, + func(n *User, e *Car) { n.Edges.Car = append(n.Edges.Car, e) }); err != nil { + return nil, err } - query.withFKs = true - query.Where(predicate.Car(func(s *sql.Selector) { - s.Where(sql.InValues(user.CarColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + } + if query := uq.withPets; query != nil { + if err := uq.loadPets(ctx, query, nodes, nil, + func(n *User, e *Pet) { n.Edges.Pets = e }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_car - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_car" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_car" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Car = append(node.Edges.Car, n) + } + if query := uq.withFriends; query != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] +func (uq *UserQuery) loadCar(ctx context.Context, query *CarQuery, nodes []*User, init func(*User), assign func(*User, *Car)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + query.withFKs = true + query.Where(predicate.Car(func(s *sql.Selector) { + 
s.Where(sql.InValues(user.CarColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_car + if fk == nil { + return fmt.Errorf(`foreign-key "user_car" is nil for node %v`, n.ID) } - for _, n := range neighbors { - fk := n.owner_id - if fk == nil { - return nil, fmt.Errorf(`foreign-key "owner_id" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = n + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_car" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} + return nil +} +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.owner_id + if fk == nil { + return fmt.Errorf(`foreign-key "owner_id" is nil for node %v`, n.ID) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, *fk, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/multischema/ent/group_query.go b/entc/integration/multischema/ent/group_query.go index b79fb5c911..64f59b71ba 100644 --- a/entc/integration/multischema/ent/group_query.go +++ b/entc/integration/multischema/ent/group_query.go @@ -391,61 +391,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) - 
s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + 
s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/multischema/ent/pet_query.go b/entc/integration/multischema/ent/pet_query.go index b0721b7908..47d51c34d6 100644 --- a/entc/integration/multischema/ent/pet_query.go +++ b/entc/integration/multischema/ent/pet_query.go @@ -390,36 +390,42 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - fk := nodes[i].OwnerID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := 
pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + fk := nodes[i].OwnerID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Schema = pq.schemaConfig.Pet diff --git a/entc/integration/multischema/ent/user_query.go b/entc/integration/multischema/ent/user_query.go index dc4539d93c..a0f37e0023 100644 --- a/entc/integration/multischema/ent/user_query.go +++ b/entc/integration/multischema/ent/user_query.go @@ -431,86 +431,104 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := 
uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.OwnerID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) + } + if query := uq.withGroups; query != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + } + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.OwnerID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "owner_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + 
s.Select(joinT.C(user.GroupsPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/privacy/ent/task_query.go b/entc/integration/privacy/ent/task_query.go index ffda6160d2..effe52c85d 100644 --- a/entc/integration/privacy/ent/task_query.go +++ b/entc/integration/privacy/ent/task_query.go @@ -433,90 +433,105 @@ func (tq *TaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Task, e if len(nodes) == 0 { return nodes, nil } - if query := tq.withTeams; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Task) - nids := make(map[int]map[*Task]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Teams = []*Team{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(task.TeamsTable) - s.Join(joinT).On(s.C(team.FieldID), 
joinT.C(task.TeamsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(task.TeamsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(task.TeamsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Task]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := tq.loadTeams(ctx, query, nodes, + func(n *Task) { n.Edges.Teams = []*Team{} }, + func(n *Task, e *Team) { n.Edges.Teams = append(n.Edges.Teams, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Teams = append(kn.Edges.Teams, n) - } + } + if query := tq.withOwner; query != nil { + if err := tq.loadOwner(ctx, query, nodes, nil, + func(n *Task, e *User) { n.Edges.Owner = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := tq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Task) - for i := range nodes { - if nodes[i].user_tasks == nil { - continue +func (tq *TaskQuery) loadTeams(ctx context.Context, query *TeamQuery, nodes []*Task, init func(*Task), assign func(*Task, *Team)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := 
make(map[int]*Task) + nids := make(map[int]map[*Task]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(task.TeamsTable) + s.Join(joinT).On(s.C(team.FieldID), joinT.C(task.TeamsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(task.TeamsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(task.TeamsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - fk := *nodes[i].user_tasks - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Task]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) + nids[inValue][byID[outValue]] = struct{}{} + return nil } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_tasks" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (tq *TaskQuery) 
loadOwner(ctx context.Context, query *UserQuery, nodes []*Task, init func(*Task), assign func(*Task, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Task) + for i := range nodes { + if nodes[i].user_tasks == nil { + continue + } + fk := *nodes[i].user_tasks + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_tasks" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (tq *TaskQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/privacy/ent/team_query.go b/entc/integration/privacy/ent/team_query.go index c6fc5b05ad..6618b29923 100644 --- a/entc/integration/privacy/ent/team_query.go +++ b/entc/integration/privacy/ent/team_query.go @@ -425,114 +425,132 @@ func (tq *TeamQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Team, e if len(nodes) == 0 { return nodes, nil } - if query := tq.withTasks; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Team) - nids := make(map[int]map[*Team]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Tasks = []*Task{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(team.TasksTable) - s.Join(joinT).On(s.C(task.FieldID), joinT.C(team.TasksPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(team.TasksPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(team.TasksPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Team]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := tq.loadTasks(ctx, query, nodes, + func(n *Team) { n.Edges.Tasks = []*Task{} }, + func(n *Team, e *Task) { n.Edges.Tasks = append(n.Edges.Tasks, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "tasks" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Tasks = append(kn.Edges.Tasks, n) - } + } + if query := tq.withUsers; query != nil { + if err := tq.loadUsers(ctx, query, nodes, + func(n *Team) { n.Edges.Users = []*User{} }, + func(n *Team, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := tq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Team) - nids := make(map[int]map[*Team]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} +func (tq *TeamQuery) loadTasks(ctx context.Context, query *TaskQuery, nodes []*Team, init func(*Team), assign func(*Team, *Task)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Team) + nids := make(map[int]map[*Team]struct{}) + for i, 
node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(team.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(team.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(team.UsersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(team.UsersPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(team.TasksTable) + s.Join(joinT).On(s.C(task.FieldID), joinT.C(team.TasksPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(team.TasksPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(team.TasksPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Team]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Team]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "tasks" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (tq *TeamQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Team, init func(*Team), assign func(*Team, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Team) + nids := make(map[int]map[*Team]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(team.UsersTable) + 
s.Join(joinT).On(s.C(user.FieldID), joinT.C(team.UsersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(team.UsersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(team.UsersPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Team]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (tq *TeamQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/privacy/ent/user_query.go b/entc/integration/privacy/ent/user_query.go index e80e8192ab..5e82f95902 100644 --- a/entc/integration/privacy/ent/user_query.go +++ b/entc/integration/privacy/ent/user_query.go @@ -425,90 +425,108 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withTeams; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - 
byid[node.ID] = node - node.Edges.Teams = []*Team{} + if err := uq.loadTeams(ctx, query, nodes, + func(n *User) { n.Edges.Teams = []*Team{} }, + func(n *User, e *Team) { n.Edges.Teams = append(n.Edges.Teams, e) }); err != nil { + return nil, err } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.TeamsTable) - s.Join(joinT).On(s.C(team.FieldID), joinT.C(user.TeamsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.TeamsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.TeamsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + } + if query := uq.withTasks; query != nil { + if err := uq.loadTasks(ctx, query, nodes, + func(n *User) { n.Edges.Tasks = []*Task{} }, + func(n *User, e *Task) { n.Edges.Tasks = append(n.Edges.Tasks, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) + } + return nodes, nil +} + +func (uq *UserQuery) loadTeams(ctx context.Context, query *TeamQuery, nodes []*User, init func(*User), assign func(*User, *Team)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := 
make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.TeamsTable) + s.Join(joinT).On(s.C(team.FieldID), joinT.C(user.TeamsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.TeamsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.TeamsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Teams = append(kn.Edges.Teams, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withTasks; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Tasks = []*Task{} + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "teams" node returned %v`, n.ID) } - query.withFKs = true - query.Where(predicate.Task(func(s *sql.Selector) { - s.Where(sql.InValues(user.TasksColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for kn := range nodes { + assign(kn, n) } - for _, n := range neighbors 
{ - fk := n.user_tasks - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_tasks" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_tasks" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Tasks = append(node.Edges.Tasks, n) + } + return nil +} +func (uq *UserQuery) loadTasks(ctx context.Context, query *TaskQuery, nodes []*User, init func(*User), assign func(*User, *Task)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } } - - return nodes, nil + query.withFKs = true + query.Where(predicate.Task(func(s *sql.Selector) { + s.Where(sql.InValues(user.TasksColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_tasks + if fk == nil { + return fmt.Errorf(`foreign-key "user_tasks" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_tasks" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/entc/integration/template/ent/pet_query.go b/entc/integration/template/ent/pet_query.go index 30431df2a7..bbcc133d26 100644 --- a/entc/integration/template/ent/pet_query.go +++ b/entc/integration/template/ent/pet_query.go @@ -394,39 +394,45 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - if nodes[i].user_pets == nil { - continue - } - fk := *nodes[i].user_pets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - 
query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if nodes[i].user_pets == nil { + continue + } + fk := *nodes[i].user_pets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() if len(pq.modifiers) > 0 { diff --git a/entc/integration/template/ent/user_query.go b/entc/integration/template/ent/user_query.go index b7d90368d8..9a036c6ae5 100644 --- a/entc/integration/template/ent/user_query.go +++ b/entc/integration/template/ent/user_query.go @@ -423,90 +423,108 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.withFKs = true - 
query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) + } + if query := uq.withFriends; query != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_pets + if fk == nil { + return fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + 
s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/edgeindex/ent/city_query.go b/examples/edgeindex/ent/city_query.go index b00f205026..e7d796aed7 100644 --- a/examples/edgeindex/ent/city_query.go +++ b/examples/edgeindex/ent/city_query.go @@ -381,39 +381,48 @@ func (cq *CityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*City, e if len(nodes) == 0 { return nodes, nil } - if query := cq.withStreets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*City) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Streets = []*Street{} - } - query.withFKs = true - 
query.Where(predicate.Street(func(s *sql.Selector) { - s.Where(sql.InValues(city.StreetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadStreets(ctx, query, nodes, + func(n *City) { n.Edges.Streets = []*Street{} }, + func(n *City, e *Street) { n.Edges.Streets = append(n.Edges.Streets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.city_streets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "city_streets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "city_streets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Streets = append(node.Edges.Streets, n) - } } - return nodes, nil } +func (cq *CityQuery) loadStreets(ctx context.Context, query *StreetQuery, nodes []*City, init func(*City), assign func(*City, *Street)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*City) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Street(func(s *sql.Selector) { + s.Where(sql.InValues(city.StreetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.city_streets + if fk == nil { + return fmt.Errorf(`foreign-key "city_streets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "city_streets" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (cq *CityQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/examples/edgeindex/ent/street_query.go b/examples/edgeindex/ent/street_query.go index 6f98255041..12feac4e1e 100644 --- a/examples/edgeindex/ent/street_query.go +++ b/examples/edgeindex/ent/street_query.go @@ -388,39 
+388,45 @@ func (sq *StreetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Stree if len(nodes) == 0 { return nodes, nil } - if query := sq.withCity; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Street) - for i := range nodes { - if nodes[i].city_streets == nil { - continue - } - fk := *nodes[i].city_streets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(city.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := sq.loadCity(ctx, query, nodes, nil, + func(n *Street, e *City) { n.Edges.City = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "city_streets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.City = n - } - } } - return nodes, nil } +func (sq *StreetQuery) loadCity(ctx context.Context, query *CityQuery, nodes []*Street, init func(*Street), assign func(*Street, *City)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Street) + for i := range nodes { + if nodes[i].city_streets == nil { + continue + } + fk := *nodes[i].city_streets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(city.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "city_streets" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (sq *StreetQuery) sqlCount(ctx context.Context) (int, error) { _spec := sq.querySpec() _spec.Node.Columns = sq.fields diff --git a/examples/fs/ent/file_query.go b/examples/fs/ent/file_query.go index ea57dabdfa..ae2f738f1c 100644 --- a/examples/fs/ent/file_query.go +++ b/examples/fs/ent/file_query.go 
@@ -416,59 +416,74 @@ func (fq *FileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*File, e if len(nodes) == 0 { return nodes, nil } - if query := fq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*File) - for i := range nodes { - fk := nodes[i].ParentID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(file.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := fq.loadParent(ctx, query, nodes, nil, + func(n *File, e *File) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := fq.withChildren; query != nil { + if err := fq.loadChildren(ctx, query, nodes, + func(n *File) { n.Edges.Children = []*File{} }, + func(n *File, e *File) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := fq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*File) +func (fq *FileQuery) loadParent(ctx context.Context, query *FileQuery, nodes []*File, init func(*File), assign func(*File, *File)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*File) + for i := range nodes { + fk := nodes[i].ParentID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(file.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - 
nodes[i].Edges.Children = []*File{} + assign(nodes[i], n) } - query.Where(predicate.File(func(s *sql.Selector) { - s.Where(sql.InValues(file.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (fq *FileQuery) loadChildren(ctx context.Context, query *FileQuery, nodes []*File, init func(*File), assign func(*File, *File)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*File) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.ParentID - node, ok := nodeids[fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.Where(predicate.File(func(s *sql.Selector) { + s.Where(sql.InValues(file.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.ParentID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "parent_id" returned %v for node %v`, fk, n.ID) } + assign(node, n) } - - return nodes, nil + return nil } func (fq *FileQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/m2m2types/ent/group_query.go b/examples/m2m2types/ent/group_query.go index d48c03c8a6..9607f10524 100644 --- a/examples/m2m2types/ent/group_query.go +++ b/examples/m2m2types/ent/group_query.go @@ -381,61 +381,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - 
query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), 
joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/m2m2types/ent/user_query.go b/examples/m2m2types/ent/user_query.go index aef4f45948..317a5b3dba 100644 --- a/examples/m2m2types/ent/user_query.go +++ b/examples/m2m2types/ent/user_query.go @@ -381,61 +381,70 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} - } 
- query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + return nodes, nil +} + +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), 
joinT.C(user.GroupsPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/m2mbidi/ent/user_query.go b/examples/m2mbidi/ent/user_query.go index 3aff07661f..b17076becb 100644 --- a/examples/m2mbidi/ent/user_query.go +++ b/examples/m2mbidi/ent/user_query.go @@ -380,61 +380,70 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} - } - 
query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + return nodes, nil +} + +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), 
joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/m2mrecur/ent/user_query.go b/examples/m2mrecur/ent/user_query.go index 26b41fbfcd..effc53a04f 100644 --- a/examples/m2mrecur/ent/user_query.go +++ b/examples/m2mrecur/ent/user_query.go @@ -416,114 +416,132 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withFollowers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Followers = 
[]*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowersPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFollowers(ctx, query, nodes, + func(n *User) { n.Edges.Followers = []*User{} }, + func(n *User, e *User) { n.Edges.Followers = append(n.Edges.Followers, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Followers = append(kn.Edges.Followers, n) - } + } + if query := uq.withFollowing; query != nil { + if err := uq.loadFollowing(ctx, query, nodes, + func(n *User) { n.Edges.Following = []*User{} }, + func(n *User, e *User) { n.Edges.Following = append(n.Edges.Following, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withFollowing; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := 
make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Following = []*User{} +func (uq *UserQuery) loadFollowers(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FollowingTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FollowingPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FollowersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.FollowersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowersPrimaryKey[1])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - }) - if err != nil { - return nil, err + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "followers" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "following" node returned %v`, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadFollowing(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := 
sql.Table(user.FollowingTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FollowingPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FollowingPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FollowingPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Following = append(kn.Edges.Following, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "following" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/o2m2types/ent/pet_query.go b/examples/o2m2types/ent/pet_query.go index 5b0a84095e..e7f6a34131 100644 --- a/examples/o2m2types/ent/pet_query.go +++ b/examples/o2m2types/ent/pet_query.go @@ -388,39 +388,45 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - if nodes[i].user_pets == nil { - continue - } - fk 
:= *nodes[i].user_pets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if nodes[i].user_pets == nil { + continue + } + fk := *nodes[i].user_pets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { _spec := pq.querySpec() _spec.Node.Columns = pq.fields diff --git a/examples/o2m2types/ent/user_query.go b/examples/o2m2types/ent/user_query.go index a822e0274d..f781bdb8ca 100644 --- a/examples/o2m2types/ent/user_query.go +++ b/examples/o2m2types/ent/user_query.go @@ -381,39 +381,48 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - 
nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) - } } - return nodes, nil } +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_pets + if fk == nil { + return fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() _spec.Node.Columns = uq.fields diff --git a/examples/o2mrecur/ent/node_query.go b/examples/o2mrecur/ent/node_query.go index f3ff81e9dc..8f055b2175 100644 --- a/examples/o2mrecur/ent/node_query.go +++ 
b/examples/o2mrecur/ent/node_query.go @@ -424,66 +424,81 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e if len(nodes) == 0 { return nodes, nil } - if query := nq.withParent; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Node) - for i := range nodes { - if nodes[i].node_children == nil { - continue - } - fk := *nodes[i].node_children - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(node.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := nq.loadParent(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Parent = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_children" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Parent = n - } + } + if query := nq.withChildren; query != nil { + if err := nq.loadChildren(ctx, query, nodes, + func(n *Node) { n.Edges.Children = []*Node{} }, + func(n *Node, e *Node) { n.Edges.Children = append(n.Edges.Children, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := nq.withChildren; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Node) +func (nq *NodeQuery) loadParent(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Node) + for i := range nodes { + if nodes[i].node_children == nil { + continue + } + fk := *nodes[i].node_children + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(node.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected 
foreign-key "node_children" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Children = []*Node{} + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Node(func(s *sql.Selector) { - s.Where(sql.InValues(node.ChildrenColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (nq *NodeQuery) loadChildren(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - for _, n := range neighbors { - fk := n.node_children - if fk == nil { - return nil, fmt.Errorf(`foreign-key "node_children" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_children" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Children = append(node.Edges.Children, n) + } + query.withFKs = true + query.Where(predicate.Node(func(s *sql.Selector) { + s.Where(sql.InValues(node.ChildrenColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.node_children + if fk == nil { + return fmt.Errorf(`foreign-key "node_children" is nil for node %v`, n.ID) } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_children" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) } - - return nodes, nil + return nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/o2o2types/ent/card_query.go b/examples/o2o2types/ent/card_query.go index 8291f76936..703598bf45 100644 --- a/examples/o2o2types/ent/card_query.go +++ b/examples/o2o2types/ent/card_query.go @@ -388,39 +388,45 @@ func 
(cq *CardQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Card, e if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Card) - for i := range nodes { - if nodes[i].user_card == nil { - continue - } - fk := *nodes[i].user_card - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Card, e *User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_card" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CardQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Card, init func(*Card), assign func(*Card, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Card) + for i := range nodes { + if nodes[i].user_card == nil { + continue + } + fk := *nodes[i].user_card + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_card" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CardQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/examples/o2o2types/ent/user_query.go b/examples/o2o2types/ent/user_query.go index b3e39ec2b6..d0cacf43c2 100644 --- a/examples/o2o2types/ent/user_query.go +++ b/examples/o2o2types/ent/user_query.go @@ -381,38 +381,44 
@@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withCard; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - } - query.withFKs = true - query.Where(predicate.Card(func(s *sql.Selector) { - s.Where(sql.InValues(user.CardColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadCard(ctx, query, nodes, nil, + func(n *User, e *Card) { n.Edges.Card = e }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_card - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_card" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_card" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Card = n - } } - return nodes, nil } +func (uq *UserQuery) loadCard(ctx context.Context, query *CardQuery, nodes []*User, init func(*User), assign func(*User, *Card)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Card(func(s *sql.Selector) { + s.Where(sql.InValues(user.CardColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_card + if fk == nil { + return fmt.Errorf(`foreign-key "user_card" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_card" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() _spec.Node.Columns = uq.fields diff --git a/examples/o2obidi/ent/user_query.go 
b/examples/o2obidi/ent/user_query.go index 5534754a03..5ee3cc8eaf 100644 --- a/examples/o2obidi/ent/user_query.go +++ b/examples/o2obidi/ent/user_query.go @@ -387,39 +387,45 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withSpouse; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - if nodes[i].user_spouse == nil { - continue - } - fk := *nodes[i].user_spouse - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadSpouse(ctx, query, nodes, nil, + func(n *User, e *User) { n.Edges.Spouse = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Spouse = n - } - } } - return nodes, nil } +func (uq *UserQuery) loadSpouse(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + if nodes[i].user_spouse == nil { + continue + } + fk := *nodes[i].user_spouse + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_spouse" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() _spec.Node.Columns = uq.fields diff --git 
a/examples/o2orecur/ent/node_query.go b/examples/o2orecur/ent/node_query.go index f10b145163..12bc2028a3 100644 --- a/examples/o2orecur/ent/node_query.go +++ b/examples/o2orecur/ent/node_query.go @@ -424,65 +424,77 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e if len(nodes) == 0 { return nodes, nil } - if query := nq.withPrev; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Node) - for i := range nodes { - if nodes[i].node_next == nil { - continue - } - fk := *nodes[i].node_next - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(node.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := nq.loadPrev(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Prev = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_next" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Prev = n - } + } + if query := nq.withNext; query != nil { + if err := nq.loadNext(ctx, query, nodes, nil, + func(n *Node, e *Node) { n.Edges.Next = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := nq.withNext; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*Node) +func (nq *NodeQuery) loadPrev(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Node) + for i := range nodes { + if nodes[i].node_next == nil { + continue + } + fk := *nodes[i].node_next + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(node.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if 
!ok { + return fmt.Errorf(`unexpected foreign-key "node_next" returned %v`, n.ID) + } for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] + assign(nodes[i], n) } - query.withFKs = true - query.Where(predicate.Node(func(s *sql.Selector) { - s.Where(sql.InValues(node.NextColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + } + return nil +} +func (nq *NodeQuery) loadNext(ctx context.Context, query *NodeQuery, nodes []*Node, init func(*Node), assign func(*Node, *Node)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + query.withFKs = true + query.Where(predicate.Node(func(s *sql.Selector) { + s.Where(sql.InValues(node.NextColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.node_next + if fk == nil { + return fmt.Errorf(`foreign-key "node_next" is nil for node %v`, n.ID) } - for _, n := range neighbors { - fk := n.node_next - if fk == nil { - return nil, fmt.Errorf(`foreign-key "node_next" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "node_next" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Next = n + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_next" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } - - return nodes, nil + return nil } func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/privacytenant/ent/group_query.go b/examples/privacytenant/ent/group_query.go index 916376993e..683eb06b2a 100644 --- a/examples/privacytenant/ent/group_query.go +++ b/examples/privacytenant/ent/group_query.go @@ -425,87 +425,102 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 
{ return nodes, nil } - if query := gq.withTenant; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Group) - for i := range nodes { - fk := nodes[i].TenantID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(tenant.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := gq.loadTenant(ctx, query, nodes, nil, + func(n *Group, e *Tenant) { n.Edges.Tenant = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tenant = n - } + } + if query := gq.withUsers; query != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} +func (gq *GroupQuery) loadTenant(ctx context.Context, query *TenantQuery, nodes []*Group, init func(*Group), assign func(*Group, *Tenant)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Group) + for i := range nodes { + fk := nodes[i].TenantID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tenant.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[1])) + 
s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/privacytenant/ent/user_query.go b/examples/privacytenant/ent/user_query.go index 2a9b7502d0..03c3b9e014 100644 --- a/examples/privacytenant/ent/user_query.go +++ b/examples/privacytenant/ent/user_query.go @@ -425,87 +425,102 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withTenant; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*User) - for i := range nodes { - fk := nodes[i].TenantID - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(tenant.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadTenant(ctx, query, nodes, nil, + func(n *User, e *Tenant) { 
n.Edges.Tenant = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Tenant = n - } + } + if query := uq.withGroups; query != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} +func (uq *UserQuery) loadTenant(ctx context.Context, query *TenantQuery, nodes []*User, init func(*User), assign func(*User, *Tenant)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*User) + for i := range nodes { + fk := nodes[i].TenantID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(tenant.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "tenant_id" returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[0])) + 
s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/start/ent/car_query.go b/examples/start/ent/car_query.go index 1334486599..f51d1e84cb 100644 --- a/examples/start/ent/car_query.go +++ b/examples/start/ent/car_query.go @@ -388,39 +388,45 @@ func (cq *CarQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Car, err if len(nodes) == 0 { return nodes, nil } - if query := cq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Car) - for i := range nodes { - if nodes[i].user_cars == nil { - continue - } - fk := *nodes[i].user_cars - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) - } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { + if err := cq.loadOwner(ctx, query, nodes, nil, + func(n *Car, e 
*User) { n.Edges.Owner = e }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_cars" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } - } } - return nodes, nil } +func (cq *CarQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Car, init func(*Car), assign func(*Car, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Car) + for i := range nodes { + if nodes[i].user_cars == nil { + continue + } + fk := *nodes[i].user_cars + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_cars" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + func (cq *CarQuery) sqlCount(ctx context.Context) (int, error) { _spec := cq.querySpec() _spec.Node.Columns = cq.fields diff --git a/examples/start/ent/group_query.go b/examples/start/ent/group_query.go index 2a345f3f07..cdb9c97bc5 100644 --- a/examples/start/ent/group_query.go +++ b/examples/start/ent/group_query.go @@ -381,61 +381,70 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), 
edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + return nodes, nil +} + +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/start/ent/user_query.go b/examples/start/ent/user_query.go index c0395874b7..f619e38d4d 100644 --- a/examples/start/ent/user_query.go +++ b/examples/start/ent/user_query.go @@ -418,90 +418,108 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withCars; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Cars = []*Car{} - } - query.withFKs = true - query.Where(predicate.Car(func(s *sql.Selector) { - s.Where(sql.InValues(user.CarsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadCars(ctx, query, nodes, + func(n *User) { n.Edges.Cars 
= []*Car{} }, + func(n *User, e *Car) { n.Edges.Cars = append(n.Edges.Cars, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_cars - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_cars" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_cars" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Cars = append(node.Edges.Cars, n) + } + if query := uq.withGroups; query != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { + return nil, err } } + return nodes, nil +} - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} +func (uq *UserQuery) loadCars(ctx context.Context, query *CarQuery, nodes []*User, init func(*User), assign func(*User, *Car)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[1])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { - return nil, err + } + query.withFKs = true + query.Where(predicate.Car(func(s *sql.Selector) { + s.Where(sql.InValues(user.CarsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_cars + if fk == nil { + return fmt.Errorf(`foreign-key "user_cars" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_cars" returned %v for node %v`, *fk, n.ID) } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) + 
s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - return nodes, nil + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) + } + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/traversal/ent/group_query.go b/examples/traversal/ent/group_query.go index be413f77ff..dc61a9bfe0 100644 --- a/examples/traversal/ent/group_query.go +++ b/examples/traversal/ent/group_query.go @@ -425,90 +425,105 @@ func (gq *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, if len(nodes) == 0 { return nodes, nil } - if query := gq.withUsers; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Group) - nids := make(map[int]map[*Group]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Users = []*User{} - } - query.Where(func(s 
*sql.Selector) { - joinT := sql.Table(group.UsersTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(group.UsersPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Group]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := gq.loadUsers(ctx, query, nodes, + func(n *Group) { n.Edges.Users = []*User{} }, + func(n *Group, e *User) { n.Edges.Users = append(n.Edges.Users, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "users" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Users = append(kn.Edges.Users, n) - } + } + if query := gq.withAdmin; query != nil { + if err := gq.loadAdmin(ctx, query, nodes, nil, + func(n *Group, e *User) { n.Edges.Admin = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := gq.withAdmin; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Group) - for i := range nodes { - if nodes[i].group_admin == nil { - continue +func (gq *GroupQuery) loadUsers(ctx context.Context, query *UserQuery, nodes []*Group, init 
func(*Group), assign func(*Group, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Group) + nids := make(map[int]map[*Group]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(group.UsersTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(group.UsersPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(group.UsersPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(group.UsersPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - fk := *nodes[i].group_admin - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Group]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) + nids[inValue][byID[outValue]] = struct{}{} + return nil } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "users" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_admin" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Admin = n - } + 
for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (gq *GroupQuery) loadAdmin(ctx context.Context, query *UserQuery, nodes []*Group, init func(*Group), assign func(*Group, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Group) + for i := range nodes { + if nodes[i].group_admin == nil { + continue + } + fk := *nodes[i].group_admin + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_admin" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (gq *GroupQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/traversal/ent/pet_query.go b/examples/traversal/ent/pet_query.go index ec54a7cffb..190b9e20e3 100644 --- a/examples/traversal/ent/pet_query.go +++ b/examples/traversal/ent/pet_query.go @@ -425,90 +425,105 @@ func (pq *PetQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Pet, err if len(nodes) == 0 { return nodes, nil } - if query := pq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*Pet) - nids := make(map[int]map[*Pet]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*Pet{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(pet.FriendsTable) - s.Join(joinT).On(s.C(pet.FieldID), joinT.C(pet.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(pet.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) 
- s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*Pet]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := pq.loadFriends(ctx, query, nodes, + func(n *Pet) { n.Edges.Friends = []*Pet{} }, + func(n *Pet, e *Pet) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } + } + if query := pq.withOwner; query != nil { + if err := pq.loadOwner(ctx, query, nodes, nil, + func(n *Pet, e *User) { n.Edges.Owner = e }); err != nil { + return nil, err } } + return nodes, nil +} - if query := pq.withOwner; query != nil { - ids := make([]int, 0, len(nodes)) - nodeids := make(map[int][]*Pet) - for i := range nodes { - if nodes[i].user_pets == nil { - continue +func (pq *PetQuery) loadFriends(ctx context.Context, query *PetQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *Pet)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*Pet) + nids := make(map[int]map[*Pet]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + 
joinT := sql.Table(pet.FriendsTable) + s.Join(joinT).On(s.C(pet.FieldID), joinT.C(pet.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(pet.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(pet.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - fk := *nodes[i].user_pets - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*Pet]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) + nids[inValue][byID[outValue]] = struct{}{} + return nil } - query.Where(user.IDIn(ids...)) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) - } - for i := range nodes { - nodes[i].Edges.Owner = n - } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (pq *PetQuery) loadOwner(ctx context.Context, query *UserQuery, nodes []*Pet, init func(*Pet), assign func(*Pet, *User)) error { + ids := make([]int, 0, len(nodes)) + nodeids := make(map[int][]*Pet) + for i := range nodes { + if 
nodes[i].user_pets == nil { + continue + } + fk := *nodes[i].user_pets + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil } func (pq *PetQuery) sqlCount(ctx context.Context) (int, error) { diff --git a/examples/traversal/ent/user_query.go b/examples/traversal/ent/user_query.go index 029563d348..441b65bea7 100644 --- a/examples/traversal/ent/user_query.go +++ b/examples/traversal/ent/user_query.go @@ -490,172 +490,208 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e if len(nodes) == 0 { return nodes, nil } - if query := uq.withPets; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Pets = []*Pet{} - } - query.withFKs = true - query.Where(predicate.Pet(func(s *sql.Selector) { - s.Where(sql.InValues(user.PetsColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { + if err := uq.loadPets(ctx, query, nodes, + func(n *User) { n.Edges.Pets = []*Pet{} }, + func(n *User, e *Pet) { n.Edges.Pets = append(n.Edges.Pets, e) }); err != nil { return nil, err } - for _, n := range neighbors { - fk := n.user_pets - if fk == nil { - return nil, fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) - } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) - } - node.Edges.Pets = append(node.Edges.Pets, n) - } } - if query := uq.withFriends; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := 
make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Friends = []*User{} - } - query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.FriendsTable) - s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) - s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.FriendsPrimaryKey[0])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadFriends(ctx, query, nodes, + func(n *User) { n.Edges.Friends = []*User{} }, + func(n *User, e *User) { n.Edges.Friends = append(n.Edges.Friends, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) - } - for kn := range nodes { - kn.Edges.Friends = append(kn.Edges.Friends, n) - } - } } - if query := uq.withGroups; query != nil { - edgeids := make([]driver.Value, len(nodes)) - byid := make(map[int]*User) - nids := make(map[int]map[*User]struct{}) - for i, node := range nodes { - edgeids[i] = node.ID - byid[node.ID] = node - node.Edges.Groups = []*Group{} - } 
- query.Where(func(s *sql.Selector) { - joinT := sql.Table(user.GroupsTable) - s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) - s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeids...)) - columns := s.SelectedColumns() - s.Select(joinT.C(user.GroupsPrimaryKey[1])) - s.AppendSelect(columns...) - s.SetDistinct(false) - }) - neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { - assign := spec.Assign - values := spec.ScanValues - spec.ScanValues = func(columns []string) ([]interface{}, error) { - values, err := values(columns[1:]) - if err != nil { - return nil, err - } - return append([]interface{}{new(sql.NullInt64)}, values...), nil - } - spec.Assign = func(columns []string, values []interface{}) error { - outValue := int(values[0].(*sql.NullInt64).Int64) - inValue := int(values[1].(*sql.NullInt64).Int64) - if nids[inValue] == nil { - nids[inValue] = map[*User]struct{}{byid[outValue]: struct{}{}} - return assign(columns[1:], values[1:]) - } - nids[inValue][byid[outValue]] = struct{}{} - return nil - } - }) - if err != nil { + if err := uq.loadGroups(ctx, query, nodes, + func(n *User) { n.Edges.Groups = []*Group{} }, + func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil { + return nil, err + } + } + if query := uq.withManage; query != nil { + if err := uq.loadManage(ctx, query, nodes, + func(n *User) { n.Edges.Manage = []*Group{} }, + func(n *User, e *Group) { n.Edges.Manage = append(n.Edges.Manage, e) }); err != nil { return nil, err } - for _, n := range neighbors { - nodes, ok := nids[n.ID] - if !ok { - return nil, fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + return nodes, nil +} + +func (uq *UserQuery) loadPets(ctx context.Context, query *PetQuery, nodes []*User, init func(*User), assign func(*User, *Pet)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, 
nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + query.Where(predicate.Pet(func(s *sql.Selector) { + s.Where(sql.InValues(user.PetsColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.user_pets + if fk == nil { + return fmt.Errorf(`foreign-key "user_pets" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_pets" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} +func (uq *UserQuery) loadFriends(ctx context.Context, query *UserQuery, nodes []*User, init func(*User), assign func(*User, *User)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.FriendsTable) + s.Join(joinT).On(s.C(user.FieldID), joinT.C(user.FriendsPrimaryKey[1])) + s.Where(sql.InValues(joinT.C(user.FriendsPrimaryKey[0]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.FriendsPrimaryKey[0])) + s.AppendSelect(columns...) 
+ s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - for kn := range nodes { - kn.Edges.Groups = append(kn.Edges.Groups, n) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } + nids[inValue][byID[outValue]] = struct{}{} + return nil } + }) + if err != nil { + return err } - - if query := uq.withManage; query != nil { - fks := make([]driver.Value, 0, len(nodes)) - nodeids := make(map[int]*User) - for i := range nodes { - fks = append(fks, nodes[i].ID) - nodeids[nodes[i].ID] = nodes[i] - nodes[i].Edges.Manage = []*Group{} - } - query.withFKs = true - query.Where(predicate.Group(func(s *sql.Selector) { - s.Where(sql.InValues(user.ManageColumn, fks...)) - })) - neighbors, err := query.All(ctx) - if err != nil { - return nil, err + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "friends" node returned %v`, n.ID) } - for _, n := range neighbors { - fk := n.group_admin - if fk == nil { - return nil, fmt.Errorf(`foreign-key "group_admin" is nil for node %v`, n.ID) + for kn := range nodes { + assign(kn, n) + } + } + return nil +} +func (uq *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + edgeIDs := make([]driver.Value, len(nodes)) + byID := make(map[int]*User) + nids := make(map[int]map[*User]struct{}) + for i, node := range nodes { + edgeIDs[i] = node.ID + 
byID[node.ID] = node + if init != nil { + init(node) + } + } + query.Where(func(s *sql.Selector) { + joinT := sql.Table(user.GroupsTable) + s.Join(joinT).On(s.C(group.FieldID), joinT.C(user.GroupsPrimaryKey[0])) + s.Where(sql.InValues(joinT.C(user.GroupsPrimaryKey[1]), edgeIDs...)) + columns := s.SelectedColumns() + s.Select(joinT.C(user.GroupsPrimaryKey[1])) + s.AppendSelect(columns...) + s.SetDistinct(false) + }) + neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) { + assign := spec.Assign + values := spec.ScanValues + spec.ScanValues = func(columns []string) ([]interface{}, error) { + values, err := values(columns[1:]) + if err != nil { + return nil, err } - node, ok := nodeids[*fk] - if !ok { - return nil, fmt.Errorf(`unexpected foreign-key "group_admin" returned %v for node %v`, *fk, n.ID) + return append([]interface{}{new(sql.NullInt64)}, values...), nil + } + spec.Assign = func(columns []string, values []interface{}) error { + outValue := int(values[0].(*sql.NullInt64).Int64) + inValue := int(values[1].(*sql.NullInt64).Int64) + if nids[inValue] == nil { + nids[inValue] = map[*User]struct{}{byID[outValue]: struct{}{}} + return assign(columns[1:], values[1:]) } - node.Edges.Manage = append(node.Edges.Manage, n) + nids[inValue][byID[outValue]] = struct{}{} + return nil + } + }) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nids[n.ID] + if !ok { + return fmt.Errorf(`unexpected "groups" node returned %v`, n.ID) + } + for kn := range nodes { + assign(kn, n) } } - - return nodes, nil + return nil +} +func (uq *UserQuery) loadManage(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + query.withFKs = true + 
query.Where(predicate.Group(func(s *sql.Selector) { + s.Where(sql.InValues(user.ManageColumn, fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.group_admin + if fk == nil { + return fmt.Errorf(`foreign-key "group_admin" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected foreign-key "group_admin" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil } func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) {