fix: Fix for torch scripted module failure with DLFW · pytorch/TensorRT@88c02d9 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Commit 88c02d9

Browse files
author
Anurag Dixit
committed
fix: Fix for torch scripted module failure with DLFW
Signed-off-by: Anurag Dixit <anuragd@nvidia.com>
1 parent a10613e commit 88c02d9

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tests/py/test_api.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def test_compile_traced(self):
3030
self.assertTrue(same < 2e-2)
3131

3232
def test_compile_script(self):
33-
trt_mod = torchtrt.ts.compile(self.scripted_model,
33+
with torch.no_grad():
34+
trt_mod = torchtrt.ts.compile(self.scripted_model,
3435
inputs=[self.input],
3536
device=torchtrt.Device(gpu_id=0),
3637
enabled_precisions={torch.float})
37-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
38-
self.assertTrue(same < 2e-2)
38+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
39+
self.assertTrue(same < 2e-2)
3940

4041
def test_compile_global(self):
4142
trt_mod = torchtrt.compile(self.scripted_model,
@@ -46,12 +47,13 @@ def test_compile_global(self):
4647
self.assertTrue(same < 2e-2)
4748

4849
def test_compile_global_nn_mod(self):
49-
trt_mod = torchtrt.compile(self.model,
50+
with torch.no_grad():
51+
trt_mod = torchtrt.compile(self.model,
5052
inputs=[self.input],
5153
device=torchtrt.Device(gpu_id=0),
5254
enabled_precisions={torch.float})
53-
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
54-
self.assertTrue(same < 2e-2)
55+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
56+
self.assertTrue(same < 2e-2)
5557

5658
def test_from_torch_tensor(self):
5759
compile_spec = {

0 commit comments

Comments
 (0)
0