feat: support embedding_bag converter (1D input) by zewenli98 · Pull Request #2395 · pytorch/TensorRT · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: support embedding_bag converter (1D input) #2395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,51 @@ def aten_ops_embedding(
)


def embedding_bag_validator(node: Node) -> bool:
    """Report whether an ``embedding_bag`` node is eligible for conversion.

    A node is accepted only when every condition below holds:

    * the indices argument (``node.args[1]``) carries ``tensor_meta`` and the
      recorded shape is 1D;
    * the offsets argument (``node.args[2]``) is a ``get_attr`` node;
    * the mode, read via ``args_bounds_check(node.args, 4, 0)``, is 0, 1 or 2.
    """
    mode = args_bounds_check(node.args, 4, 0)
    indices_meta = node.args[1].meta.get("tensor_meta")
    if indices_meta is None:
        # Without shape metadata the rank of the indices cannot be checked.
        return False

    offsets_is_get_attr = bool(node.args[2].op == "get_attr")
    if not offsets_is_get_attr:
        return False
    if mode not in (0, 1, 2):
        return False
    return len(indices_meta.shape) == 1


@dynamo_tensorrt_converter(torch.ops.aten.embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten._embedding_bag.default, capability_validator=embedding_bag_validator)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        1: (TRTTensor,),
        2: (np.ndarray, torch.Tensor),
    }
)  # type: ignore[misc]
def aten_ops_embedding_bag(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.embedding_bag`` / ``aten._embedding_bag`` nodes.

    Unpacks the positional ATen arguments and forwards them to
    ``impl.embedding.embedding_bag`` with ``SourceIR.ATEN``. The first three
    positions (weight, indices, offsets) are read directly; positions 3-7 are
    read through ``args_bounds_check`` with the fallback values shown below.
    ``kwargs`` is not consulted.
    """
    weight, indices, offsets = args[0], args[1], args[2]

    # Trailing schema arguments, each with the fallback used when absent.
    # The padding index is deliberately not forwarded: it is useful for
    # training only.
    trailing = {
        "scale_grad_by_freq": args_bounds_check(args, 3, False),
        "mode": args_bounds_check(args, 4, 0),
        "sparse": args_bounds_check(args, 5, False),
        "per_sample_weights": args_bounds_check(args, 6, None),
        "include_last_offset": args_bounds_check(args, 7, False),
    }

    return impl.embedding.embedding_bag(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        weight=weight,
        indices=indices,
        offsets=offsets,
        **trailing,
    )


@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) # type: ignore[misc]
def aten_ops_fmod(
Expand Down
138 changes: 135 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Optional
import functools
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

Expand Down Expand Up @@ -40,5 +43,134 @@ def embedding(

# Implement embedding lookup with gather layer
gather_layer = ctx.net.add_gather(embedding_tensor, indices_tensor, axis=0)
set_layer_name(gather_layer, target, name + "_gather", source_ir)
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
return gather_layer.get_output(0)


def embedding_bag(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    weight: TRTTensor,
    indices: TRTTensor,
    offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
    scale_grad_by_freq: bool,
    mode: int,
    sparse: bool,
    per_sample_weights: Optional[TRTTensor],
    include_last_offset: bool,
) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
    """
    Build the layers that compute embedding bags for 1D ``indices``.

    The embeddings of all indices are looked up once, optionally scaled by
    ``per_sample_weights``, then each bag ``[offsets[i], offsets[i + 1])`` is
    sliced out along dim 0, reduced along dim 0 and the per-bag results are
    concatenated along dim 0.

    In PyTorch, `offsets` is only used when input is 1D. If input is 2D of shape (B, N),
    it will be treated as B bags (sequences) each of fixed length N, and `offsets` is
    ignored and required to be None in this case. The ATen schema, however, requires
    `offsets` for input of any rank. Flattening N-D input to 1D is not implemented
    here yet (see the TODO below), so only 1D `indices` are handled.

    Args:
        ctx: conversion context; passed through to the helper implementations.
        target: FX target, used when naming layers.
        source_ir: IR the op originates from, used when naming layers.
        name: prefix for the names of all layers created here.
        weight: embedding table, forwarded to ``embedding``.
        indices: 1D indices into the embedding table.
        offsets: start position of each bag in ``indices``; converted with
            ``to_numpy`` and iterated in Python, so the values must be known
            when this function runs.
        scale_grad_by_freq: forwarded to ``embedding``.
        mode: reduction applied to each bag: 0 (sum), 1 (mean) or 2 (max).
        sparse: forwarded to ``embedding``.
        per_sample_weights: optional weights with the same shape as ``indices``;
            each looked-up embedding row is multiplied by its weight.
        include_last_offset: if true, the last element of ``offsets`` is treated
            as the end of the last bag (and replaced by ``indices.shape[0]``);
            otherwise ``indices.shape[0]`` is appended as the end of the last bag.

    Returns:
        A 4-tuple whose first element is the concatenated per-bag result; the
        remaining three elements are always ``None``.

    Raises:
        ValueError: if ``mode`` is not 0, 1 or 2.
    """

    # TODO: support 2D inputs by flattening `indices` to 1D first, e.g. with
    # impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_indices", indices, (-1,))

    # Resolve the reduction up front so an unsupported mode fails with a clear
    # error instead of an unbound `reduce_op` further down.
    if mode == 0:  # sum
        reduce_op = functools.partial(
            impl.reduce.sum, ctx=ctx, target=target, source_ir=source_ir
        )
        reduce_name = "sum"
    elif mode == 1:  # mean
        reduce_op = functools.partial(
            impl.reduce.mean, ctx=ctx, target=target, source_ir=source_ir
        )
        reduce_name = "mean"
    elif mode == 2:  # max
        reduce_op = functools.partial(
            impl.reduce.max,
            ctx=ctx,
            target=target,
            source_ir=source_ir,
            return_indices=False,
        )
        reduce_name = "max"
    else:
        raise ValueError(
            f"Unsupported embedding_bag mode {mode!r}; expected 0 (sum), 1 (mean) or 2 (max)"
        )

    # calculate embedding
    embed = embedding(
        ctx,
        target,
        source_ir,
        f"{name}_embedding",
        indices,
        weight,
        scale_grad_by_freq,
        sparse,
    )

    # give weights to embedding
    if per_sample_weights is not None:
        assert (
            per_sample_weights.shape == indices.shape
        ), f"`per_sample_weights` (shape: {per_sample_weights.shape}) must have exactly the same shape as indices/input (shape: {indices.shape})!"
        per_sample_weights = get_trt_tensor(
            ctx, per_sample_weights, f"{name}_per_sample_weights", np.float32
        )
        # Reshape to a column so the multiply applies one weight per looked-up
        # embedding row.
        per_sample_weights = impl.shuffle.reshape(
            ctx,
            target,
            source_ir,
            f"{name}_reshape_per_sample_weights",
            per_sample_weights,
            (-1, 1),
        )
        embed = impl.elementwise.mul(
            ctx,
            target,
            source_ir,
            f"{name}_mul_per_sample_weights",
            embed,
            per_sample_weights,
        )

    # Work on a private copy: the in-place write below must not leak into the
    # caller's offsets (to_numpy may hand back the caller's own buffer).
    offsets = np.array(to_numpy(offsets))
    num_indices = indices.shape[0]

    if not include_last_offset:
        # add the end index to offsets
        offsets = np.append(offsets, num_indices)
    else:
        # Modify the last index of offsets to the end index. The PyTorch doc
        # says that if `include_last_offset` is True, the size of offsets is
        # equal to the number of bags + 1, and the last element is the size of
        # the input, or the ending index position of the last bag (sequence).
        offsets[-1] = num_indices

    # Separately reduce embeddings for different bags.
    # NOTE(review): bags with offsets[i] >= offsets[i + 1] are skipped, so the
    # output has fewer rows than the number of bags when any bag is empty;
    # PyTorch documents zero-filled rows for empty bags. Confirm whether that
    # case needs to be supported here.
    reduced_embed = []
    for i, (start, stop) in enumerate(zip(offsets[:-1], offsets[1:])):
        if start < stop:
            sliced_embed = impl.slice.slice_op(
                ctx,
                target,
                source_ir,
                f"{name}_slice_embed_{i}",
                embed,
                0,
                start,
                stop,
                1,
            )
            reduced_sliced_embed = reduce_op(
                name=f"{name}_{reduce_name}_{i}",
                input_val=sliced_embed,
                dim=0,
                keepdim=True,
            )
            reduced_embed.append(reduced_sliced_embed)

    out = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", reduced_embed, 0)

    # Only the first of the four schema outputs is produced.
    return out, None, None, None
141 changes: 141 additions & 0 deletions tests/py/dynamo/conversion/test_embedding_bag_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestEmbeddingBagConverter(DispatchTestCase):
    """Conversion tests for ``aten._embedding_bag`` with 1D indices.

    The three active cases use mode 1, 0 and 2 respectively, and between them
    cover both values of ``include_last_offset`` and one case with
    ``per_sample_weights``. The 2D cases are kept commented out because 2D
    input is not supported yet.
    """

    @parameterized.expand(
        [
            # 1D input
            param(
                test_name="1d_indices_1",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
                offsets=torch.tensor([0, 3], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=1,
                sparse=False,
                per_sample_weights=None,
                include_last_offset=True,
                padding_idx=-1,
            ),
            param(
                test_name="1d_indices_2",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3], dtype=torch.int32),
                offsets=torch.tensor([0, 5], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=0,
                sparse=False,
                per_sample_weights=torch.randn((6,)),
                include_last_offset=False,
                padding_idx=-1,
            ),
            param(
                test_name="1d_indices_3",
                weight=torch.randn((10, 3), dtype=torch.float32),
                indices=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.int32),
                offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
                scale_grad_by_freq=False,
                mode=2,
                sparse=False,
                per_sample_weights=None,
                include_last_offset=False,
                padding_idx=-1,
            ),
            # 2D input
            # param(
            #     test_name="2d_indices_1",
            #     weight=torch.randn((5, 10), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1], [4, 3]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 1], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=0,
            #     sparse=False,
            #     per_sample_weights=torch.randn((4,)),
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_3",
            #     weight=torch.tensor([
            #         [0.0, 0.0, 0.0],
            #         [1.0, 1.0, 1.0],
            #         [2.0, 2.0, 2.0],
            #         [3.0, 3.0, 3.0],
            #         [4.0, 4.0, 4.0],
            #         [5.0, 5.0, 5.0],
            #     ], dtype=torch.float32),
            #     indices=torch.tensor([[0, 2, 1], [3, 5, 4]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 1], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=2,
            #     sparse=False,
            #     per_sample_weights=None,
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_2",
            #     weight=torch.randn((5, 5), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1, 2], [4, 2, 3]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 2], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=1,
            #     sparse=False,
            #     per_sample_weights=None,
            #     include_last_offset=False,
            #     padding_idx=-1,
            # ),
            # param(
            #     test_name="2d_indices_2",
            #     weight=torch.randn((5, 10), dtype=torch.float32),
            #     indices=torch.tensor([[3, 1, 2, 4], [4, 1, 3, 1]], dtype=torch.int32),
            #     offsets=torch.tensor([0, 2], dtype=torch.int32),
            #     scale_grad_by_freq=False,
            #     mode=0,
            #     sparse=False,
            #     per_sample_weights=torch.randn((8,)),
            #     include_last_offset=True,
            #     padding_idx=-1,
            # ),
        ]
    )
    def test_embedding_bag(
        self,
        test_name,
        weight,
        indices,
        offsets,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset,
        padding_idx,
    ):
        """Run one parameterized case through ``self.run_test``."""

        class TestEmbeddingBag(torch.nn.Module):
            def forward(self, weight, indices):
                # Only weight and indices are module inputs; every other
                # argument is captured from the enclosing test case. Element
                # [0] of the op's result is the value under test.
                return torch.ops.aten._embedding_bag.default(
                    weight,
                    indices,
                    offsets,
                    scale_grad_by_freq,
                    mode,
                    sparse,
                    per_sample_weights,
                    include_last_offset,
                    padding_idx,
                )[0]

        self.run_test(
            TestEmbeddingBag(),
            inputs=[weight, indices],
            enable_passes=True,
        )


# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    run_tests()
0