feat: support prod, max, min, and mean via reduce layer by zewenli98 · Pull Request #2355 · pytorch/TensorRT · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: support prod, max, min, and mean via reduce layer #2355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 97 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import operator
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -137,12 +138,12 @@ def aten_ops_sigmoid(
)


@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) # type: ignore[misc]
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
) # type: ignore[misc]
def aten_ops_index(
ctx: ConversionContext,
target: Target,
Expand Down Expand Up @@ -667,7 +668,7 @@ def aten_ops_amax(
SourceIR.ATEN,
name,
args[0],
args[1],
args_bounds_check(args, 1, replacement=[]),
args_bounds_check(args, 2, replacement=False),
)

Expand All @@ -692,6 +693,97 @@ def aten_ops_sum(
)


@dynamo_tensorrt_converter(torch.ops.aten.prod.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)  # type: ignore[misc]
def aten_ops_prod(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.prod.default`` / ``aten.prod.dim_int`` via ``impl.reduce.prod``.

    Positional arguments read here:
        args[0]: the tensor to reduce.
        args[1]: optional reduction dim; ``None`` when not supplied.
        args[2]: optional ``keepdim`` flag; ``False`` when not supplied.

    ``kwargs`` is not consulted.
    """
    source_tensor = args[0]
    reduce_dim = args_bounds_check(args, 1, replacement=None)
    keep_dims = args_bounds_check(args, 2, replacement=False)
    return impl.reduce.prod(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        source_tensor,
        dim=reduce_dim,
        keepdim=keep_dims,
    )


def one_user_validator(node: Node) -> bool:
    """Report whether ``node`` is consumed only as ``operator.getitem(node, 0)``.

    This is the ``capability_validator`` for the ``aten.max.dim`` /
    ``aten.min.dim`` converters. Those converters fill the indices slot with
    ``None``, so they are only safe when the graph reads element 0 alone.

    Returns:
        True iff ``node`` has exactly one user, that user is an
        ``operator.getitem`` call, and the index it reads is ``0``.
    """
    consumers = list(node.users)
    if len(consumers) != 1:
        return False
    (sole_consumer,) = consumers
    return sole_consumer.target == operator.getitem and sole_consumer.args[1] == 0


@dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_max(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.max.default`` and ``aten.max.dim`` via ``impl.reduce.max``.

    Positional arguments read here:
        args[0]: the tensor to reduce.
        args[1]: optional reduction dim; ``None`` when not supplied.
        args[2]: optional ``keepdim`` flag; ``False`` when not supplied.

    Only the ``max.dim`` overload asks for the indices slot. That overload is
    registered with ``one_user_validator``, so it is converted only when the
    graph reads just the values output. ``kwargs`` is not consulted.
    """
    source_tensor = args[0]
    reduce_dim = args_bounds_check(args, 1, replacement=None)
    keep_dims = args_bounds_check(args, 2, replacement=False)
    wants_indices = target == torch.ops.aten.max.dim
    return impl.reduce.max(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        source_tensor,
        dim=reduce_dim,
        keepdim=keep_dims,
        return_indices=wants_indices,
    )


@dynamo_tensorrt_converter(torch.ops.aten.min.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.min.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_min(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.min.default`` and ``aten.min.dim`` via ``impl.reduce.min``.

    Positional arguments read here:
        args[0]: the tensor to reduce.
        args[1]: optional reduction dim; ``None`` when not supplied.
        args[2]: optional ``keepdim`` flag; ``False`` when not supplied.

    Only the ``min.dim`` overload asks for the indices slot. That overload is
    registered with ``one_user_validator``, so it is converted only when the
    graph reads just the values output. ``kwargs`` is not consulted.
    """
    source_tensor = args[0]
    reduce_dim = args_bounds_check(args, 1, replacement=None)
    keep_dims = args_bounds_check(args, 2, replacement=False)
    wants_indices = target == torch.ops.aten.min.dim
    return impl.reduce.min(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        source_tensor,
        dim=reduce_dim,
        keepdim=keep_dims,
        return_indices=wants_indices,
    )


@dynamo_tensorrt_converter(torch.ops.aten.mean.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.mean.dim)  # type: ignore[misc]
def aten_ops_mean(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.mean.default`` / ``aten.mean.dim`` via ``impl.reduce.mean``.

    Positional arguments read here:
        args[0]: the tensor to reduce.
        args[1]: optional reduction dim(s); ``None`` when not supplied.
        args[2]: optional ``keepdim`` flag; ``False`` when not supplied.

    ``kwargs`` is not consulted.
    """
    source_tensor = args[0]
    reduce_dim = args_bounds_check(args, 1, replacement=None)
    keep_dims = args_bounds_check(args, 2, replacement=False)
    return impl.reduce.mean(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        source_tensor,
        dim=reduce_dim,
        keepdim=keep_dims,
    )


@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc]
def aten_ops_exp(
ctx: ConversionContext,
Expand Down Expand Up @@ -1118,7 +1210,7 @@ def aten_ops_mul(


@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) # type: ignore[misc]
def aten_ops_max(
def aten_ops_maximum(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -1136,7 +1228,7 @@ def aten_ops_max(


@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) # type: ignore[misc]
def aten_ops_min(
def aten_ops_minimum(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand Down
130 changes: 126 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Target
Expand Down Expand Up @@ -27,6 +27,9 @@ def amax(
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
input_val,
trt.ReduceOperation.MAX,
Expand All @@ -43,16 +46,17 @@ def sum(
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: Optional[Union[int, Sequence[int]]] = None,
keepdim: bool = False,
dim: Optional[Union[int, Sequence[int]]],
keepdim: bool,
) -> TRTTensor:
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

if dim is None:
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
input_val,
trt.ReduceOperation.SUM,
Expand All @@ -61,3 +65,121 @@ def sum(
)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def prod(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
) -> TRTTensor:
    """Lower a product reduction to a TensorRT reduce layer (``ReduceOperation.PROD``).

    Args:
        ctx: Conversion context whose ``net`` receives the new layer.
        target: FX target, used for layer naming.
        source_ir: Originating IR, used for layer naming.
        name: Base name for the layer (also passed to the int->float cast).
        input_val: Tensor to reduce.
        dim: Axis or axes to reduce; negative values are normalized against the
            input rank. ``None`` or an empty tuple/list reduces over every axis.
            This matches how ``sum``, ``mean`` and ``amax`` treat ``dim`` in
            this module.
        keepdim: Forwarded to TensorRT as ``keep_dims``.

    Returns:
        The reduce layer's output tensor.
    """
    # int8/int32 inputs are cast to float32 before the reduction.
    if (isinstance(input_val, TRTTensor)) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    ):
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    # Previously only `None` meant "all axes". An empty sequence would have
    # produced an empty axes mask instead of a full reduction.
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        dim = tuple(range(len(input_val.shape)))

    layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.PROD,
        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
        keep_dims=keepdim,
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


# Lower a max reduction to a TensorRT reduce layer (trt.ReduceOperation.MAX).
#
# Parameters:
#   dim: None reduces over every axis of `input_val`. Otherwise it is an int or
#        a sequence of ints, normalized by get_positive_dim.
#   keepdim: forwarded to TensorRT as keep_dims.
#   return_indices: when true, the result is returned as a (values, None) tuple.
# NOTE(review): the indices slot is None, not a tensor, despite the
# Tuple[TRTTensor, TRTTensor] annotation. This is only safe when nothing reads
# the indices output; confirm callers guarantee that.
def max(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
dim: Optional[Union[int, Sequence[int]]],
keepdim: bool,
return_indices: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
# int8/int32 inputs are cast to float32 before the reduction.
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

# Only None selects "all axes"; an empty dim sequence is passed through unchanged.
if dim is None:
dim = tuple(range(len(input_val.shape)))

layer = ctx.net.add_reduce(
input_val,
trt.ReduceOperation.MAX,
axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
keep_dims=keepdim,
)
set_layer_name(layer, target, name, source_ir)

if return_indices:
# No layer is added for indices; only the values output is materialized.
return layer.get_output(0), None
Comment on lines +123 to +124
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@narendasan, @apbose - there are certain converter cases where indices are returned from an operator but never used nor accessed in the graph (confirmed via validator). In these cases, we wouldn't want to use extra computation time to add layers for an unused tensor. Which of these seems best?

  • Return None for the unused tensor, as here - (data, None)
  • Return only the data, since the unused tensor should never be accessed, as in: (data,)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)


return layer.get_output(0)


def min(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
    return_indices: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
    """Lower a min reduction to a TensorRT reduce layer (``ReduceOperation.MIN``).

    Args:
        ctx: Conversion context whose ``net`` receives the new layer.
        target: FX target, used for layer naming.
        source_ir: Originating IR, used for layer naming.
        name: Base name for the layer (also passed to the int->float cast).
        input_val: Tensor to reduce; int8/int32 inputs are cast to float32 first.
        dim: Axis or axes to reduce. ``None`` reduces over every axis.
        keepdim: Forwarded to TensorRT as ``keep_dims``.
        return_indices: When true, return ``(values, None)``; no indices
            tensor is built.

    Returns:
        The reduced tensor, or ``(reduced, None)`` if ``return_indices``.
    """
    needs_float_cast = isinstance(input_val, TRTTensor) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    )
    if needs_float_cast:
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    rank = len(input_val.shape)
    reduce_dims = tuple(range(rank)) if dim is None else dim

    reduce_layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.MIN,
        axes=get_axes_for_reduce_op(get_positive_dim(reduce_dims, rank)),
        keep_dims=keepdim,
    )
    set_layer_name(reduce_layer, target, name, source_ir)

    reduced = reduce_layer.get_output(0)
    return (reduced, None) if return_indices else reduced


def mean(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
) -> TRTTensor:
    """Lower a mean reduction to a TensorRT reduce layer (``ReduceOperation.AVG``).

    Args:
        ctx: Conversion context whose ``net`` receives the new layer.
        target: FX target, used for layer naming.
        source_ir: Originating IR, used for layer naming.
        name: Base name for the layer (also passed to the int->float cast).
        input_val: Tensor to reduce; int8/int32 inputs are cast to float32 first.
        dim: Axis or axes to reduce. ``None`` or an empty tuple/list reduces
            over every axis.
        keepdim: Forwarded to TensorRT as ``keep_dims``.

    Returns:
        The reduce layer's output tensor.
    """
    needs_float_cast = isinstance(input_val, TRTTensor) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    )
    if needs_float_cast:
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    rank = len(input_val.shape)
    reduce_everything = dim is None or (isinstance(dim, (tuple, list)) and not dim)
    reduce_dims = tuple(range(rank)) if reduce_everything else dim

    reduce_layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.AVG,
        axes=get_axes_for_reduce_op(get_positive_dim(reduce_dims, rank)),
        keep_dims=keepdim,
    )
    set_layer_name(reduce_layer, target, name, source_ir)
    return reduce_layer.get_output(0)
2 changes: 2 additions & 0 deletions tests/py/dynamo/conversion/test_amax_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True),
((3, 2, 4), [1], True),
((2, 1, 4, 5), [0, 3], True),
((2, 3, 4, 5), [0, 1, 2, 3], False),
Expand Down Expand Up @@ -69,6 +70,7 @@ def forward(self, x):

@parameterized.expand(
[
((1, 2, 4), [], True, torch.int, 0, 5),
((3, 2, 4), [1], True, torch.int, 0, 5),
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
Expand Down
59 changes: 49 additions & 10 deletions tests/py/dynamo/conversion/test_max_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,67 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestMaxConverter(DispatchTestCase):
    """Converter tests for ``aten.max.default`` and ``aten.max.dim``.

    ``aten.max.dim`` returns ``(values, indices)``. The modules below index
    ``[0]`` so that only the values output is consumed.
    """

    @parameterized.expand(
        [
            ((1, 2),),
            ((3, 2, 4),),
            ((2, 3, 4, 5),),
            ((6, 7, 5, 4, 5),),
        ]
    )
    def test_max_dim_int_default(self, input_shape):
        # Covers the full reduction aten.max.default (no dim argument).
        class Max(nn.Module):
            def forward(self, x):
                return torch.ops.aten.max.default(x)

        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Max(),
            inputs,
        )

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True),
            ((2, 3, 4, 5), 3, True),
            ((6, 7, 5, 4, 5), 4, False),
            ((1, 5, 2, 1), -3, False),
            ((1, 5, 2, 3), -2, True),
        ]
    )
    def test_max_dim_int(self, input_shape, dim, keep_dims):
        # Single-dim reduction, including negative dims and keepdim on/off.
        class Max(nn.Module):
            def forward(self, x):
                return torch.ops.aten.max.dim(x, dim, keep_dims)[0]

        inputs = [torch.randn(*input_shape)]
        self.run_test(
            Max(),
            inputs,
        )

    @parameterized.expand(
        [
            ((3, 2, 4), 1, True, torch.int, 0, 5),
            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
        ]
    )
    def test_max_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
        # Integer inputs: the converter casts int8/int32 to float32 before
        # reducing, so the output dtype is not compared.
        class Max(nn.Module):
            def forward(self, x):
                return torch.ops.aten.max.dim(x, dim, keep_dims)[0]

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            Max(),
            inputs,
            check_dtype=False,
        )


Expand Down
Loading
0