From 05fe728488ce5ef3e916f37fadbb4d04f2b2d49c Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Thu, 28 Sep 2023 20:54:20 -0700
Subject: [PATCH 01/11] feat: support flatten and reshape via shuffle_layer

---
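Note: aten.reshape.default and aten.view.default are registered against the
same converter, so both ops lower to a single TensorRT shuffle layer. As a
minimal sketch of the kind of graph this converter handles (the module name is
illustrative; the shapes mirror the cases in test_reshape_aten.py):

    import torch

    class Reshape(torch.nn.Module):
        def forward(self, x):
            # Reshape/view calls reach the converter as aten.view.default or
            # aten.reshape.default; either one becomes one shuffle layer.
            return torch.ops.aten.view.default(x, (1, 20))

    out = Reshape()(torch.randn(1, 2, 10))  # (1, 2, 10) -> (1, 20)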
 .../dynamo/conversion/aten_ops_converters.py  | 24 +++++++++
 .../dynamo/conversion/impl/__init__.py        |  1 +
 .../dynamo/conversion/impl/shuffle.py         | 52 +++++++++++++++++++
 .../py/dynamo/conversion/test_reshape_aten.py |  1 +
 4 files changed, 78 insertions(+)
 create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 67ce83469f..f9bdec4509 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1555,3 +1555,27 @@ def tensorrt_scaled_dot_product_attention(
     return impl.attention.scaled_dot_product_attention(
         ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.reshape.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.view.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_reshape(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shuffle.reshape(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        shape=args[1],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 3f49377619..9e83c6657e 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -17,6 +17,7 @@
     reduce,
     select,
     shape,
+    shuffle,
     slice,
     split,
     squeeze,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
new file mode 100644
index 0000000000..5d078c146a
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -0,0 +1,52 @@
+from typing import List, Optional, Union
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.fx.converters.converter_utils import (
+    get_positive_dim,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def reshape(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    shape: List[int],
+) -> TRTTensor:
+    layer = network.add_shuffle(input)
+    layer.reshape_dims = tuple(shape)
+    set_layer_name(layer, target, f"{name}_reshape", source_ir)
+    return layer.get_output(0)
+
+
+def flatten(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    start_dim: int,
+    end_dim: int,
+) -> TRTTensor:
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = (
+        tuple(shape[:start_dim])
+        + (num_elements,)
+        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
+    )
+    layer = network.add_shuffle(input)
+    layer.reshape_dims = new_shape
+    set_layer_name(layer, target, f"{name}_flatten", source_ir)
+    return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_reshape_aten.py b/tests/py/dynamo/conversion/test_reshape_aten.py
index 6be138303d..1d137efec5 100644
--- a/tests/py/dynamo/conversion/test_reshape_aten.py
+++ b/tests/py/dynamo/conversion/test_reshape_aten.py
@@ -65,6 +65,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(target_shape),
             input_specs,
+            expected_ops={torch.ops.aten.view.default},
         )

From 5f9f7abb268761ec1d861d0aa15a1a288832baac Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 29 Sep 2023 18:36:39 -0700
Subject: [PATCH 02/11] update

---
 .../dynamo/conversion/impl/shuffle.py | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 5d078c146a..4624b25571 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,11 +1,14 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Sequence, Union
 
+import numpy as np
+import torch
 from torch.fx.node import Target
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
     get_positive_dim,
-    set_layer_name,
+    get_trt_tensor,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
@@ -14,9 +17,11 @@ def reshape(
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     shape: List[int],
 ) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        input = get_trt_tensor(network, input, f"{name}_input")
     layer = network.add_shuffle(input)
     layer.reshape_dims = tuple(shape)
     set_layer_name(layer, target, f"{name}_reshape", source_ir)
@@ -28,7 +33,7 @@ def flatten(
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
-    input: TRTTensor,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     start_dim: int,
     end_dim: int,
 ) -> TRTTensor:
@@ -37,6 +42,9 @@ def flatten(
     start_dim = get_positive_dim(start_dim, dim_size)
     end_dim = get_positive_dim(end_dim, dim_size)
 
+    if not isinstance(input, TRTTensor):
+        input = get_trt_tensor(network, input, f"{name}_input")
+
     num_elements = 1
     for i in range(start_dim, end_dim + 1):
         num_elements *= shape[i]

From 8c58b9ff43ed06751d6a93b1fc25609315b83011 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 29 Sep 2023 18:47:04 -0700
Subject: [PATCH 03/11] modify TRTNetwork to ConversionContext

---
 .../dynamo/conversion/impl/shuffle.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 4624b25571..ee2fa80880 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -3,17 +3,18 @@
 import numpy as np
 import torch
 from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     SourceIR,
     get_positive_dim,
     get_trt_tensor,
 )
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 
 def reshape(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
@@ -21,15 +22,15 @@ def reshape(
     shape: List[int],
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
-    layer = network.add_shuffle(input)
+        input = get_trt_tensor(ctx, input, f"{name}_input")
+    layer = ctx.net.add_shuffle(input)
     layer.reshape_dims = tuple(shape)
     set_layer_name(layer, target, f"{name}_reshape", source_ir)
     return layer.get_output(0)
 
 
 def flatten(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
@@ -43,7 +44,7 @@ def flatten(
     end_dim = get_positive_dim(end_dim, dim_size)
 
     if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
+        input = get_trt_tensor(ctx, input, f"{name}_input")
 
     num_elements = 1
     for i in range(start_dim, end_dim + 1):
@@ -54,7 +55,7 @@ def flatten(
         + (num_elements,)
         + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
     )
-    layer = network.add_shuffle(input)
+    layer = ctx.net.add_shuffle(input)
     layer.reshape_dims = new_shape
     set_layer_name(layer, target, f"{name}_flatten", source_ir)
     return layer.get_output(0)
From 9f1d7a80ac75b96bdd08eb7b6b6a328b095a1cfa Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Mon, 2 Oct 2023 15:03:04 -0700
Subject: [PATCH 04/11] update and add tests

---
 tests/py/dynamo/conversion/test_reshape_aten.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/py/dynamo/conversion/test_reshape_aten.py b/tests/py/dynamo/conversion/test_reshape_aten.py
index 1d137efec5..a2f62a57c5 100644
--- a/tests/py/dynamo/conversion/test_reshape_aten.py
+++ b/tests/py/dynamo/conversion/test_reshape_aten.py
@@ -12,6 +12,8 @@ class TestReshapeConverter(DispatchTestCase):
     @parameterized.expand(
         [
+            ((-1,),),
+            ((20,),),
             ((1, 20),),
             ((1, 10, -1),),
         ]
     )
@@ -37,6 +39,7 @@ def forward(self, x):
 
     @parameterized.expand(
         [
+            ((-1,),),
             ((-1, 10),),
             ((-1, 5),),
             ((2, 2, -1),),

From dfaa36db9c4d4135a57c040ec104db390d2a25f2 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Mon, 2 Oct 2023 16:59:50 -0700
Subject: [PATCH 05/11] remove flatten converter, instead, add flatten_dims utility

---
 .../dynamo/conversion/aten_ops_converters.py |  4 +-
 .../dynamo/conversion/converter_utils.py     | 31 +++++++++
 .../dynamo/conversion/impl/shuffle.py        | 68 +++++++++----------
 3 files changed, 65 insertions(+), 38 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f9bdec4509..5b5cfc4e60 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -506,7 +506,7 @@ def aten_ops_slice(
     {
         0: (TRTTensor,),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_permute(
     ctx: ConversionContext,
     target: Target,
@@ -1394,7 +1394,7 @@ def conv_param_validator(conv_node: Node) -> bool:
         1: (np.ndarray, torch.Tensor, TRTTensor),
         2: (np.ndarray, torch.Tensor, TRTTensor),
     }
-)  # type: ignore[misc]
+)
 def aten_ops_convolution(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index d1d94d40cc..fd7fc83058 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -9,6 +9,7 @@
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     ConverterRegistry,
@@ -511,3 +512,33 @@ def to_numpy(
     raise AssertionError(
         f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
     )
+
+
+def flatten_dims(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    start_dim: int,
+    end_dim: int,
+) -> TRTTensor:
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    if not isinstance(input, TRTTensor):
+        input = get_trt_tensor(ctx, input, f"{name}_flatten")
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = (
+        tuple(shape[:start_dim])
+        + (num_elements,)
+        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
+    )
+
+    return impl.shuffle.reshape(ctx, target, source_ir, name, input, new_shape)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index ee2fa80880..70484b06f2 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -4,11 +4,7 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    SourceIR,
-    get_positive_dim,
-    get_trt_tensor,
-)
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -25,37 +21,37 @@ def reshape(
         input = get_trt_tensor(ctx, input, f"{name}_input")
     layer = ctx.net.add_shuffle(input)
     layer.reshape_dims = tuple(shape)
-    set_layer_name(layer, target, f"{name}_reshape", source_ir)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
 
 
-def flatten(
-    ctx: ConversionContext,
-    target: Union[Target, str],
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    start_dim: int,
-    end_dim: int,
-) -> TRTTensor:
-    shape = input.shape
-    dim_size = len(shape)
-    start_dim = get_positive_dim(start_dim, dim_size)
-    end_dim = get_positive_dim(end_dim, dim_size)
-
-    if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    num_elements = 1
-    for i in range(start_dim, end_dim + 1):
-        num_elements *= shape[i]
-
-    new_shape = (
-        tuple(shape[:start_dim])
-        + (num_elements,)
-        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
-    )
-    layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = new_shape
-    set_layer_name(layer, target, f"{name}_flatten", source_ir)
-    return layer.get_output(0)
+# def flatten(
+#     ctx: ConversionContext,
+#     target: Union[Target, str],
+#     source_ir: Optional[SourceIR],
+#     name: str,
+#     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+#     start_dim: int,
+#     end_dim: int,
+# ) -> TRTTensor:
+#     shape = input.shape
+#     dim_size = len(shape)
+#     start_dim = get_positive_dim(start_dim, dim_size)
+#     end_dim = get_positive_dim(end_dim, dim_size)
+
+#     if not isinstance(input, TRTTensor):
+#         input = get_trt_tensor(ctx, input, f"{name}_input")
+
+#     num_elements = 1
+#     for i in range(start_dim, end_dim + 1):
+#         num_elements *= shape[i]
+
+#     new_shape = (
+#         tuple(shape[:start_dim])
+#         + (num_elements,)
+#         + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
+#     )
+#     layer = ctx.net.add_shuffle(input)
+#     layer.reshape_dims = new_shape
+#     set_layer_name(layer, target, name, source_ir)
+#     return layer.get_output(0)

From b4678d954a353638a53a16ebba6be106dd3f0184 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 3 Oct 2023 12:56:42 -0700
Subject: [PATCH 06/11] fix circular import error

---
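Note: the cycle came from PATCH 05, which made the two modules need each other
at import time. In outline (both import lines are quoted from the earlier
diffs):

    # py/torch_tensorrt/dynamo/conversion/converter_utils.py (PATCH 05):
    #     from torch_tensorrt.dynamo.conversion import impl
    # py/torch_tensorrt/dynamo/conversion/impl/shuffle.py:
    #     from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor

This patch therefore reverts the converter_utils changes and restores flatten()
inside impl/shuffle.py, where calling reshape() directly needs no cross-module
import.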
 .../dynamo/conversion/converter_utils.py | 31 ---------
 .../dynamo/conversion/impl/shuffle.py    | 68 ++++++++++---------
 2 files changed, 35 insertions(+), 64 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index fd7fc83058..d1d94d40cc 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -9,7 +9,6 @@
 from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_registry import (
     ConverterRegistry,
@@ -512,33 +511,3 @@ def to_numpy(
     raise AssertionError(
         f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
     )
-
-
-def flatten_dims(
-    ctx: ConversionContext,
-    target: Union[Target, str],
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    start_dim: int,
-    end_dim: int,
-) -> TRTTensor:
-    shape = input.shape
-    dim_size = len(shape)
-    start_dim = get_positive_dim(start_dim, dim_size)
-    end_dim = get_positive_dim(end_dim, dim_size)
-
-    if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(ctx, input, f"{name}_flatten")
-
-    num_elements = 1
-    for i in range(start_dim, end_dim + 1):
-        num_elements *= shape[i]
-
-    new_shape = (
-        tuple(shape[:start_dim])
-        + (num_elements,)
-        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
-    )
-
-    return impl.shuffle.reshape(ctx, target, source_ir, name, input, new_shape)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 70484b06f2..53b1a204a8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,10 +1,14 @@
-from typing import List, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -15,7 +19,7 @@ def reshape(
     source_ir: Optional[SourceIR],
     name: str,
     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    shape: List[int],
+    shape: Sequence[int],
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         input = get_trt_tensor(ctx, input, f"{name}_input")
@@ -25,33 +29,31 @@ def reshape(
     return layer.get_output(0)
 
 
-# def flatten(
-#     ctx: ConversionContext,
-#     target: Union[Target, str],
-#     source_ir: Optional[SourceIR],
-#     name: str,
-#     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
-#     start_dim: int,
-#     end_dim: int,
-# ) -> TRTTensor:
-#     shape = input.shape
-#     dim_size = len(shape)
-#     start_dim = get_positive_dim(start_dim, dim_size)
-#     end_dim = get_positive_dim(end_dim, dim_size)
-
-#     if not isinstance(input, TRTTensor):
-#         input = get_trt_tensor(ctx, input, f"{name}_input")
-
-#     num_elements = 1
-#     for i in range(start_dim, end_dim + 1):
-#         num_elements *= shape[i]
-
-#     new_shape = (
-#         tuple(shape[:start_dim])
-#         + (num_elements,)
-#         + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
-#     )
-#     layer = ctx.net.add_shuffle(input)
-#     layer.reshape_dims = new_shape
-#     set_layer_name(layer, target, name, source_ir)
-#     return layer.get_output(0)
+def flatten(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    start_dim: int,
+    end_dim: int,
+) -> TRTTensor:
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    if not isinstance(input, TRTTensor):
+        input = get_trt_tensor(ctx, input, f"{name}_flatten")
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = (
+        tuple(shape[:start_dim])
+        + (num_elements,)
+        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
+    )
+
+    return reshape(ctx, target, source_ir, name, input, new_shape)

From 30306f859abd0ec213812fb2ff0be71419c9f973 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Tue, 3 Oct 2023 19:42:27 -0700
Subject: [PATCH 07/11] change to flatten_dims utility

---
 .../dynamo/conversion/converter_utils.py | 30 ++++++++++++++++
 .../dynamo/conversion/impl/shuffle.py    | 36 +-------------------
 2 files changed, 31 insertions(+), 35 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index d1d94d40cc..4f11daef2c 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -17,6 +17,7 @@
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     get_axes_for_reduce_op,
+    set_layer_name,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor
@@ -511,3 +512,32 @@ def to_numpy(
     raise AssertionError(
         f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
     )
+
+
+def flatten_dims(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    start_dim: int,
+    end_dim: int,
+) -> TRTTensor:
+    shape = input.shape
+    dim_size = len(shape)
+    start_dim = get_positive_dim(start_dim, dim_size)
+    end_dim = get_positive_dim(end_dim, dim_size)
+
+    if not isinstance(input, TRTTensor):
+        input = get_trt_tensor(ctx, input, f"{name}_flatten")
+
+    num_elements = 1
+    for i in range(start_dim, end_dim + 1):
+        num_elements *= shape[i]
+
+    new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])
+
+    layer = ctx.net.add_shuffle(input)
+    layer.reshape_dims = new_shape
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 53b1a204a8..19b01fcc23 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -4,11 +4,7 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    SourceIR,
-    get_positive_dim,
-    get_trt_tensor,
-)
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -27,33 +23,3 @@ def reshape(
     layer.reshape_dims = tuple(shape)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
-
-
-def flatten(
-    ctx: ConversionContext,
-    target: Union[Target, str],
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
-    start_dim: int,
-    end_dim: int,
-) -> TRTTensor:
-    shape = input.shape
-    dim_size = len(shape)
-    start_dim = get_positive_dim(start_dim, dim_size)
-    end_dim = get_positive_dim(end_dim, dim_size)
-
-    if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(ctx, input, f"{name}_flatten")
-
-    num_elements = 1
-    for i in range(start_dim, end_dim + 1):
-        num_elements *= shape[i]
-
-    new_shape = (
-        tuple(shape[:start_dim])
-        + (num_elements,)
-        + (tuple(shape[end_dim + 1 :]) if end_dim + 1 < dim_size else tuple())
-    )
-
-    return reshape(ctx, target, source_ir, name, input, new_shape)

From 0f24a5204f454d2514d00a525869e2cea42a9332 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Wed, 4 Oct 2023 11:29:14 -0700
Subject: [PATCH 08/11] modify flatten_dims()

---
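Note: flatten_dims() is now a pure shape helper. It no longer builds any
TensorRT layer; it only computes the collapsed shape, which callers can pass to
impl.shuffle.reshape. A quick sketch of the expected behavior (the values
mirror the parameterized cases added in the next patch):

    import numpy as np
    from torch_tensorrt.dynamo.conversion.converter_utils import flatten_dims

    x = np.random.randn(2, 3, 4)
    assert flatten_dims(x, 1, 2) == (2, 12)  # collapse dims 1..2
    assert flatten_dims(x, 0, 1) == (6, 4)   # collapse dims 0..1
    assert flatten_dims(x, -3, 2) == (24,)   # negative dims are normalized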
 .../dynamo/conversion/converter_utils.py | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 4f11daef2c..b382c5c329 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -17,7 +17,6 @@
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     get_axes_for_reduce_op,
-    set_layer_name,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTTensor
@@ -515,29 +514,32 @@ def to_numpy(
 
 
 def flatten_dims(
-    ctx: ConversionContext,
-    target: Union[Target, str],
-    source_ir: Optional[SourceIR],
-    name: str,
     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     start_dim: int,
     end_dim: int,
-) -> TRTTensor:
+) -> Tuple[int, ...]:
+    """
+    Given an input, start and end indices of dimension,
+    this function will return a flattened new shape.
+
+    Args:
+        input (Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]]):
+            an input value waiting to be flattened
+        start_dim (int): the first dim to flatten
+        end_dim (int): the last dim to flatten (this dim is included)
+
+    Returns:
+        Tuple[int]: new_shape
+    """
     shape = input.shape
     dim_size = len(shape)
     start_dim = get_positive_dim(start_dim, dim_size)
     end_dim = get_positive_dim(end_dim, dim_size)
 
-    if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(ctx, input, f"{name}_flatten")
-
     num_elements = 1
     for i in range(start_dim, end_dim + 1):
         num_elements *= shape[i]
 
     new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])
 
-    layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = new_shape
-    set_layer_name(layer, target, name, source_ir)
-    return layer.get_output(0)
+    return new_shape

From c0cfeb82eb705f33dd52908d55f6d3ded1e0d93d Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Wed, 4 Oct 2023 15:44:21 -0700
Subject: [PATCH 09/11] clean flatten converter and add tests

---
 .../dynamo/conversion/aten_ops_converters.py  |  4 +-
 .../dynamo/conversion/test_converter_utils.py | 40 ++++++++++++++++++-
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 5b5cfc4e60..f9bdec4509 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -506,7 +506,7 @@ def aten_ops_slice(
     {
         0: (TRTTensor,),
     }
-)
+)  # type: ignore[misc]
 def aten_ops_permute(
     ctx: ConversionContext,
     target: Target,
@@ -1394,7 +1394,7 @@ def conv_param_validator(conv_node: Node) -> bool:
         1: (np.ndarray, torch.Tensor, TRTTensor),
         2: (np.ndarray, torch.Tensor, TRTTensor),
     }
-)
+)  # type: ignore[misc]
 def aten_ops_convolution(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_converter_utils.py b/tests/py/dynamo/conversion/test_converter_utils.py
index b4f1ff2f93..8025eae841 100644
--- a/tests/py/dynamo/conversion/test_converter_utils.py
+++ b/tests/py/dynamo/conversion/test_converter_utils.py
@@ -1,7 +1,11 @@
 import numpy as np
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    enforce_tensor_types,
+    flatten_dims,
+)
 from torch_tensorrt.fx.types import TRTTensor
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -37,5 +41,39 @@ def test_invalid_invocation_type(self):
             enforce_tensor_types({0: (int, bool)})
 
 
+class TestFlattenDimsEnforcement(TestCase):
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_numpy_array(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = np.random.randn(*input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+    @parameterized.expand(
+        [
+            ((1, 2), 0, 0, (1, 2)),
+            ((1, 2), 0, 1, (2,)),
+            ((2, 3, 4), 1, 2, (2, 12)),
+            ((2, 3, 4), 0, 1, (6, 4)),
+            ((2, 3, 4), -3, 2, (24,)),
+            ((2, 3, 4, 5), 0, -2, (24, 5)),
+            ((2, 3, 4, 5), -4, -1, (120,)),
+        ]
+    )
+    def test_torch_tensor(self, input_shape, start_dim, end_dim, true_shape):
+        inputs = torch.randn(input_shape)
+        new_shape = flatten_dims(inputs, start_dim, end_dim)
+        self.assertEqual(new_shape, true_shape)
+
+
 if __name__ == "__main__":
     run_tests()

From 9846f239ed2be1df8e7766be14bcaaf689fc0c85 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Thu, 5 Oct 2023 13:20:43 -0700
Subject: [PATCH 10/11] use enforce_tensor_types decorator to cast type

---
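Note: impl.shuffle.reshape is reached through aten_ops_reshape, which is
already guarded by the type-enforcement decorator added in PATCH 01 (abridged
below):

    @enforce_tensor_types(
        {
            0: (TRTTensor,),  # arg 0 (the input) is cast to a TRTTensor
        }
    )  # type: ignore[misc]
    def aten_ops_reshape(ctx, target, args, kwargs, name): ...

Since the decorator guarantees a TRTTensor input, the manual isinstance /
get_trt_tensor cast inside reshape() is redundant and is dropped here.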
 py/torch_tensorrt/dynamo/conversion/impl/shuffle.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
index 19b01fcc23..3a4c160d77 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,10 +1,8 @@
 from typing import Optional, Sequence, Union
 
-import numpy as np
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -14,11 +12,9 @@ def reshape(
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
-    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    input: TRTTensor,
     shape: Sequence[int],
 ) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        input = get_trt_tensor(ctx, input, f"{name}_input")
     layer = ctx.net.add_shuffle(input)
     layer.reshape_dims = tuple(shape)
     set_layer_name(layer, target, name, source_ir)
From 72342b75097b92769a94b12f9ee8dde4e50927dc Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 6 Oct 2023 12:01:26 -0700
Subject: [PATCH 11/11] rebase

---
 .../py/dynamo/conversion/test_reshape_aten.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/tests/py/dynamo/conversion/test_reshape_aten.py b/tests/py/dynamo/conversion/test_reshape_aten.py
index a2f62a57c5..a8c831f23b 100644
--- a/tests/py/dynamo/conversion/test_reshape_aten.py
+++ b/tests/py/dynamo/conversion/test_reshape_aten.py
@@ -23,17 +23,16 @@ class TestReshapeConverter(DispatchTestCase):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         inputs = [torch.randn(1, 2, 10)]
         self.run_test(
-            TestModule(target_shape),
+            Reshape(),
             inputs,
         )
@@ -50,13 +49,12 @@ def forward(self, x):
         "Shape tensor supported well in TensorRT 8.5 and later",
     )
     def test_reshape_with_dynamic_shape(self, target_shape):
-        class TestModule(torch.nn.Module):
-            def __init__(self, target_shape):
+        class Reshape(torch.nn.Module):
+            def __init__(self):
                 super().__init__()
-                self.target_shape = target_shape
 
             def forward(self, x):
-                return torch.ops.aten.view.default(x, self.target_shape)
+                return torch.ops.aten.view.default(x, target_shape)
 
         input_specs = [
             Input(
@@ -66,9 +64,8 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            TestModule(target_shape),
+            Reshape(),
             input_specs,
-            expected_ops={torch.ops.aten.view.default},
         )