entc/gen: move eager-loading to method by a8m · Pull Request #2790 · ent/ent · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

entc/gen: move eager-loading to method #2790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 138 additions & 130 deletions entc/gen/template/dialect/sql/query.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
return nodes, nil
}
{{- range $e := $.Edges }}
{{- with extend $ "Rec" $receiver "Edge" $e }}
{{ template "dialect/sql/query/eagerloading" . }}
{{- end }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
if err := {{ $receiver }}.load{{ $e.StructField }}(ctx, query, nodes, {{ if and (not $e.M2M) (not $e.O2M) }}nil{{ else }}
func(n *{{ $.Name }}){ n.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{} }{{ end }},
func(n *{{ $.Name }}, e *{{ $e.Type.Name }}){ n.Edges.{{ $e.StructField }} = {{ if or $e.OwnFK $e.Unique }}e{{ else }}append(n.Edges.{{ $e.StructField }}, e){{ end }} }); err != nil {
return nil, err
}
}
{{- end }}
{{- /* Allow extensions to inject code using templates to process nodes before they are returned. */}}
{{- with $tmpls := matchTemplate "dialect/sql/query/all/nodes/*" }}
Expand All @@ -88,6 +92,137 @@ func ({{ $receiver }} *{{ $builder }}) sqlAll(ctx context.Context, hooks ...quer
return nodes, nil
}

{{/* Generate a method to eager-load each edge. */}}
{{- range $e := $.Edges }}
{{- /*
	load<Edge> runs the eager-loading query of edge $e for the given parent nodes.
	init, when non-nil, is called once per parent node before loading (the M2M and
	O2M branches call it, e.g. to initialize the edge slice); assign attaches one
	loaded neighbor to one parent node.
*/}}
func ({{ $receiver }} *{{ $builder }}) load{{ $e.StructField }}(ctx context.Context, query *{{ $e.Type.QueryName }}, nodes []*{{ $.Name }}, init func(*{{ $.Name }}), assign func(*{{ $.Name }}, *{{ $e.Type.Name }})) error {
{{- if $e.M2M }}
{{- /*
	M2M: neighbors are joined through the edge table. The join-table column that
	holds the parent id is prepended to the selected columns, so every scanned row
	is (parent id, neighbor columns...). A neighbor shared by several parents is
	scanned into a single node, and that node is then assigned to each parent.
*/}}
edgeIDs := make([]driver.Value, len(nodes))
byID := make(map[{{ $.ID.Type }}]*{{ $.Name }})
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
for i, node := range nodes {
edgeIDs[i] = node.ID
byID[node.ID] = node
if init != nil {
init(node)
}
}
query.Where(func(s *sql.Selector) {
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeIDs...))
columns := s.SelectedColumns()
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
{{- /* The closure-local assign shadows the method parameter only inside this hook. */}}
assign := spec.Assign
values := spec.ScanValues
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]interface{}{new({{ $out }})}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
if nids[inValue] == nil {
nids[inValue] = map[*{{ $.Name }}]struct{}{byID[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
nids[inValue][byID[outValue]] = struct{}{}
return nil
}
})
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nids[n.ID]
if !ok {
return fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
}
for kn := range nodes {
assign(kn, n)
}
}
{{- else if $e.OwnFK }}
{{- /* The parent node holds the foreign key: collect the distinct referenced ids and fetch the neighbors by id. */}}
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
for i := range nodes {
{{- $fk := $e.ForeignKey }}
{{- if $fk.Field.Nillable }}
if nodes[i].{{ $fk.StructField }} == nil {
continue
}
{{- end }}
fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
{{- /* No parent references a neighbor (e.g. every nillable foreign key is nil): skip the round-trip. */}}
if len(ids) == 0 {
return nil
}
query.Where({{ $e.Type.Package }}.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
}
}
{{- else }}
{{- /* The neighbor holds the foreign key: fetch the neighbors whose foreign-key column matches one of the parent ids. */}}
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
{{- if $e.O2M }}
if init != nil {
init(nodes[i])
}
{{- end }}
}
{{- with $e.Type.UnexportedForeignKeys }}
query.withFKs = true
{{- end }}
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return err
}
for _, n := range neighbors {
{{- $fk := $e.ForeignKey }}
fk := n.{{ $fk.StructField }}
{{- if $fk.Field.Nillable }}
if fk == nil {
return fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
}
{{- end }}
node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
if !ok {
return fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
}
assign(node, n)
}
{{- end }}
return nil
}
{{- end }}

func ({{ $receiver }} *{{ $builder }}) sqlCount(ctx context.Context) (int, error) {
_spec := {{ $receiver }}.querySpec()
{{- /* Allow mutating the sqlgraph.QuerySpec by ent extensions or user templates. */}}
Expand Down Expand Up @@ -286,133 +421,6 @@ func ({{ $receiver }} *{{ $builder }}) sqlQuery(ctx context.Context) *sql.Select
{{ $ident }} = sqlgraph.Neighbors({{ $receiver }}.driver.Dialect(), step)
{{ end }}

{{ define "dialect/sql/query/eagerloading" }}
{{- /*
	Inline eager-loading of a single edge, expanded inside sqlAll. The edge is read
	from $.Scope.Edge and the query receiver name from $.Scope.Rec. The generated
	code uses the enclosing `nodes` slice and returns (nil, err) on failure, so it
	must be expanded inside a function with two results.
*/}}
{{- $e := $.Scope.Edge }}
{{- $receiver := $.Scope.Rec }}
if query := {{ $receiver }}.{{ $e.EagerLoadField }}; query != nil {
{{- if $e.M2M }}
{{- /*
	M2M: join through the edge table and prepend the join-table column holding the
	parent id to the selected columns; a neighbor shared by several parents is
	scanned once and appended to each parent's edge slice.
*/}}
edgeids := make([]driver.Value, len(nodes))
byid := make(map[{{ $.ID.Type }}]*{{ $.Name }})
nids := make(map[{{ $e.Type.ID.Type }}]map[*{{ $.Name }}]struct{})
for i, node := range nodes {
edgeids[i] = node.ID
byid[node.ID] = node
node.Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
}
query.Where(func(s *sql.Selector) {
joinT := sql.Table({{ $.Package }}.{{ $e.TableConstant }})
{{- $edgeid := print $e.Type.Package "." $e.Type.ID.Constant }}
{{- $fk1idx := 1 }}{{- $fk2idx := 0 }}{{ if $e.IsInverse }}{{ $fk1idx = 0 }}{{ $fk2idx = 1 }}{{ end }}
s.Join(joinT).On(s.C({{ $edgeid }}), joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk1idx }}]))
s.Where(sql.InValues(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]), edgeids...))
columns := s.SelectedColumns()
s.Select(joinT.C({{ $.Package }}.{{ $e.PKConstant }}[{{ $fk2idx }}]))
s.AppendSelect(columns...)
s.SetDistinct(false)
})
neighbors, err := query.sqlAll(ctx, func(_ context.Context, spec *sqlgraph.QuerySpec) {
assign := spec.Assign
values := spec.ScanValues
{{- $out := "sql.NullInt64" }}{{ if $.ID.UserDefined }}{{ $out = $.ID.ScanType }}{{ end }}
{{- $in := "sql.NullInt64" }}{{ if $e.Type.ID.UserDefined }}{{ $in = $e.Type.ID.ScanType }}{{ end }}
spec.ScanValues = func(columns []string) ([]interface{}, error) {
values, err := values(columns[1:])
if err != nil {
return nil, err
}
return append([]interface{}{new({{ $out }})}, values...), nil
}
spec.Assign = func(columns []string, values []interface{}) error {
outValue := {{ with extend $ "Arg" "values[0]" "Field" $.ID "ScanType" $out }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
inValue := {{ with extend $ "Arg" "values[1]" "Field" $e.Type.ID "ScanType" $in }}{{ template "dialect/sql/query/eagerloading/m2massign" . }}{{ end }}
if nids[inValue] == nil {
nids[inValue] = map[*{{ $.Name }}]struct{}{byid[outValue]: struct{}{}}
return assign(columns[1:], values[1:])
}
nids[inValue][byid[outValue]] = struct{}{}
return nil
}
})
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected "{{ $e.Name }}" node returned %v`, n.ID)
}
for kn := range nodes {
kn.Edges.{{ $e.StructField }} = append(kn.Edges.{{ $e.StructField }}, n)
}
}
{{- else if $e.OwnFK }}
{{- /* The parent node holds the foreign key: collect the distinct referenced ids and fetch the neighbors by id. */}}
ids := make([]{{ $e.Type.ID.Type }}, 0, len(nodes))
nodeids := make(map[{{ $e.Type.ID.Type }}][]*{{ $.Name }})
for i := range nodes {
{{- $fk := $e.ForeignKey }}
{{- if $fk.Field.Nillable }}
if nodes[i].{{ $fk.StructField }} == nil {
continue
}
{{- end }}
fk := {{ if $fk.Field.Nillable }}*{{ end }}nodes[i].{{ $fk.StructField }}
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
{{- /* NOTE(review): when every nillable foreign key is nil, ids is empty but the IDIn query below still runs; consider skipping it. */}}
query.Where({{ $e.Type.Package }}.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
nodes, ok := nodeids[n.ID]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v`, n.ID)
}
for i := range nodes {
nodes[i].Edges.{{ $e.StructField }} = n
}
}
{{- else }}
{{- /* The neighbor holds the foreign key: fetch the neighbors whose foreign-key column matches one of the parent ids. */}}
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[{{ $.ID.Type }}]*{{ $.Name }})
for i := range nodes {
fks = append(fks, nodes[i].ID)
nodeids[nodes[i].ID] = nodes[i]
{{- if $e.O2M }}
nodes[i].Edges.{{ $e.StructField }} = []*{{ $e.Type.Name }}{}
{{- end }}
}
{{- with $e.Type.UnexportedForeignKeys }}
query.withFKs = true
{{- end }}
query.Where(predicate.{{ $e.Type.Name }}(func(s *sql.Selector) {
s.Where(sql.InValues({{ $.Package }}.{{ $e.ColumnConstant }}, fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
return nil, err
}
for _, n := range neighbors {
{{- $fk := $e.ForeignKey }}
fk := n.{{ $fk.StructField }}
{{- if $fk.Field.Nillable }}
if fk == nil {
return nil, fmt.Errorf(`foreign-key "{{ $fk.Field.Name }}" is nil for node %v`, n.ID)
}
{{- end }}
node, ok := nodeids[{{ if $fk.Field.Nillable }}*{{ end }}fk]
if !ok {
return nil, fmt.Errorf(`unexpected foreign-key "{{ $fk.Field.Name }}" returned %v for node %v`, {{ if $fk.Field.Nillable }}*{{ end }}fk, n{{ if $e.Type.HasOneFieldID }}.ID{{ end }})
}
node.Edges.{{ $e.StructField }} = {{ if $e.Unique }}n{{ else }}append(node.Edges.{{ $e.StructField }}, n){{ end }}
}
{{- end }}
}
{{ end }}

{{ define "dialect/sql/query/eagerloading/m2massign" }}
{{- $arg := $.Scope.Arg }}
{{- $field := $.Scope.Field }}
Expand Down
52 changes: 29 additions & 23 deletions entc/integration/cascadelete/ent/comment_query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
0