fix: Repair integer inputs in dynamic shape cases by gs-olive · Pull Request #2876 · pytorch/TensorRT · GitHub

fix: Repair integer inputs in dynamic shape cases #2876


Merged
merged 11 commits on Jun 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build-test-linux.yml
@@ -264,4 +264,4 @@ jobs:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
cancel-in-progress: true
cancel-in-progress: true
19 changes: 14 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.context = self.engine.create_execution_context()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
contiguous_inputs: List[torch.Tensor] = [
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]
with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
@@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
logger.warning(
@@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
# Shape tensor inputs are casted to int32 explicitly.
# Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
# Shape tensor inputs are cast to int64 explicitly
# Currently, Torch CPU pointers are not working; numpy pointers are used instead
# to refer to the underlying memory
inputs_cpu = (
contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
)
self.context.set_tensor_address(
input_name, inputs_cpu.data_ptr()
input_name, inputs_cpu.ctypes.data
)
else:
self.context.set_input_shape(
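Taken together, the runtime changes above do two things: promote non-tensor inputs to CUDA tensors up front, and hand shape-tensor inputs to TensorRT through a host-side numpy pointer rather than a Torch CPU pointer. A minimal sketch of that pattern, assuming a CUDA-capable setup (the helper names `normalize_trt_inputs` and `shape_tensor_host_buffer` are illustrative, not part of the module):

```python
import torch


def normalize_trt_inputs(inputs):
    # Mirrors the new forward() preamble: plain integers (e.g. symbolic
    # batch sizes produced by a preceding Torch subgraph) are promoted to
    # CUDA tensors, while real tensors are only made contiguous.
    return [
        i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()
        for i in inputs
    ]


def shape_tensor_host_buffer(t: torch.Tensor):
    # Mirrors the shape-tensor branch: TensorRT expects a host (CPU)
    # pointer for shape-inference inputs. The value is cast to int64 and
    # copied into a numpy array, whose ctypes.data address would be passed
    # to context.set_tensor_address.
    host = t.cpu().to(torch.int64).numpy().copy()
    return host, host.ctypes.data


# The buffer must be kept alive while TensorRT reads from the address.
host, addr = shape_tensor_host_buffer(torch.tensor([2, 3]))

if torch.cuda.is_available():
    tensors = normalize_trt_inputs([torch.randn(2, 3, device="cuda"), 4])
    print([t.shape for t in tensors])  # [torch.Size([2, 3]), torch.Size([])]
```

The copy matters: `ctypes.data` is only valid while the numpy array is alive, so the host buffer has to outlive the enqueue call.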
26 changes: 11 additions & 15 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
"""Implementation of the forward pass for a TensorRT engine

Args:
*inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
*inputs (Union[torch.Tensor, int]): Inputs to the forward function

Returns:
torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
@@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_binding_names
), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]

try:
assert all(types)
except AssertionError:

def is_non_tensor(i: Tuple[Any, bool]) -> bool:
return not i[1]

non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
raise RuntimeError(
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
)
# If an input is not a Torch Tensor, which can occur in scenarios such as shape tensors
# that are outputs of a preceding Torch subgraph (where the dynamic input may be an integer),
# cast it to a Torch Tensor directly.
#
# This also avoids the need for type-checking inputs, since they are now explicitly cast to Torch Tensors
input_tensors: List[torch.Tensor] = [
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
list(inputs), self.engine
list(input_tensors), self.engine
)

if len(outputs) == 1:
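The C++ runtime wrapper gets the same treatment: instead of rejecting non-tensor inputs with a RuntimeError, it now wraps them as CUDA tensors before calling execute_engine. One detail worth noting is that `torch.tensor(i)` on a Python int produces a zero-dimensional int64 tensor, which lines up with the int64 shape-tensor convention used in the Python runtime above. A hedged sketch (`wrap_engine_inputs` is an illustrative name):

```python
import torch


def wrap_engine_inputs(inputs):
    # Replaces the old type check that raised RuntimeError on non-tensor
    # inputs: each plain int (e.g. a dynamic dimension emitted by a
    # fallback Torch subgraph) is wrapped as a CUDA tensor instead.
    return [
        i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda() for i in inputs
    ]


# A Python int becomes a zero-dimensional int64 tensor, matching the
# int64 shape-tensor convention in the Python runtime above.
assert torch.tensor(8).dtype == torch.int64

if torch.cuda.is_available():
    wrapped = wrap_engine_inputs([torch.ones(8, 3, device="cuda"), 8])
    assert all(isinstance(i, torch.Tensor) for i in wrapped)
```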
55 changes: 55 additions & 0 deletions tests/py/dynamo/models/test_dyn_models.py
@@ -310,3 +310,58 @@ def forward(self, x):
cos_sim > COSINE_THRESHOLD,
msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_dynamic_with_fallback_shape_tensor_pass_through(ir):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
x = x + 2
x = x * 2
out = torch.reshape(x, (-1, 224 * 224))
out = self.relu(out)
return out

model = MyModule().eval().cuda()
input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")

compile_spec = {
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"min_block_size": 1,
"torch_executed_ops": {"torch.ops.aten.add.Tensor"},
}

# Compile the model
if ir == "torch_compile":
torch._dynamo.mark_dynamic(input_bs4, 0, min=4, max=1024)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
trt_model(input_bs4)
elif ir == "dynamo":
compile_spec["inputs"] = [
torchtrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(4, 3, 224, 224),
max_shape=(1024, 3, 224, 224),
dtype=torch.float32,
name="x",
)
]
trt_model = torchtrt.compile(model, **compile_spec)

trt_model(input_bs4)

input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
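The test above exercises exactly this path: keeping `aten.add` in Torch splits the graph, so the downstream TensorRT engine can receive the dynamic batch dimension as a plain integer rather than a tensor. A condensed sketch of the same scenario, assuming a CUDA device and the `torch_compile` front end shown in the test (`TinyModel` is illustrative):

```python
import torch
import torch_tensorrt  # registers the "tensorrt" torch.compile backend


class TinyModel(torch.nn.Module):
    def forward(self, x):
        x = x + 2  # forced to run in Torch via torch_executed_ops
        # The -1 dimension depends on the runtime batch size, which reaches
        # the TRT subgraph as an integer when the add falls back to Torch.
        return torch.relu(torch.reshape(x, (-1, 224 * 224)))


model = TinyModel().eval().cuda()
x = torch.randn(4, 3, 224, 224, device="cuda")
torch._dynamo.mark_dynamic(x, 0, min=4, max=1024)

trt_model = torch.compile(
    model,
    backend="tensorrt",
    options={
        "min_block_size": 1,
        "torch_executed_ops": {"torch.ops.aten.add.Tensor"},
    },
)
trt_model(x)  # batch 4
trt_model(torch.randn(6, 3, 224, 224, device="cuda"))  # batch 6, within the marked range
```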
2 changes: 1 addition & 1 deletion tests/py/requirements.txt
@@ -9,5 +9,5 @@ pytest-xdist>=3.6.1
pyyaml
tensorrt==10.0.1
timm>=1.0.3
transformers==4.39.3
transformers==4.40.2
--extra-index-url https://pypi.nvidia.com