@@ -30,12 +30,13 @@ def test_compile_traced(self):
30
30
self .assertTrue (same < 2e-2 )
31
31
32
32
def test_compile_script (self ):
33
- trt_mod = torchtrt .ts .compile (self .scripted_model ,
33
+ with torch .no_grad ():
34
+ trt_mod = torchtrt .ts .compile (self .scripted_model ,
34
35
inputs = [self .input ],
35
36
device = torchtrt .Device (gpu_id = 0 ),
36
37
enabled_precisions = {torch .float })
37
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
38
- self .assertTrue (same < 2e-2 )
38
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
39
+ self .assertTrue (same < 2e-2 )
39
40
40
41
def test_compile_global (self ):
41
42
trt_mod = torchtrt .compile (self .scripted_model ,
@@ -46,12 +47,13 @@ def test_compile_global(self):
46
47
self .assertTrue (same < 2e-2 )
47
48
48
49
def test_compile_global_nn_mod (self ):
49
- trt_mod = torchtrt .compile (self .model ,
50
+ with torch .no_grad ():
51
+ trt_mod = torchtrt .compile (self .model ,
50
52
inputs = [self .input ],
51
53
device = torchtrt .Device (gpu_id = 0 ),
52
54
enabled_precisions = {torch .float })
53
- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
54
- self .assertTrue (same < 2e-2 )
55
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
56
+ self .assertTrue (same < 2e-2 )
55
57
56
58
def test_from_torch_tensor (self ):
57
59
compile_spec = {
0 commit comments