🐛 Describe the bug
OS: Linux Ubuntu 22.04
GPU: Nvidia T4
pytorch: 2.1.0.dev20230817+cu118
torch.jit.trace() generates C++ CUDA code that contains malformed floating-point constants, for example -3.402823466385289e+38.f. The problem is in the exponent part of the number: e+38.f is invalid and should be e+38f (no dot between 38 and f).
The error occurs in the PyTorch 2.1.0 nightly build (2.0.1 works fine).
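The malformed suffix can be modeled with a small formatter. This is only a sketch under an assumption: it guesses that the code generator appends ".f" whenever the value is integral, without checking whether the printed text already uses exponent notation. The function names are hypothetical and this is not PyTorch's actual code.

```python
import math


def format_float_buggy(v: float) -> str:
    """Hypothetical model of the faulty formatter (an assumption, not PyTorch code)."""
    text = f"{v:.16g}"  # 16 significant digits, exponent form for large values
    # Integral values get ".f" even when the text is already in exponent form.
    suffix = ".f" if v == math.ceil(v) else "f"
    return text + suffix


def format_float_fixed(v: float) -> str:
    """Same formatter, but ".f" is used only when the text is a bare integer."""
    text = f"{v:.16g}"
    needs_dot = "." not in text and "e" not in text
    return text + (".f" if needs_dot else "f")


fp32_min = -3.4028234663852886e38  # most negative finite float32 value
print(format_float_buggy(fp32_min))  # literal with the invalid "e+38.f" ending
print(format_float_fixed(fp32_min))  # literal with the valid "e+38f" ending
print(format_float_fixed(8.0))       # small integral values still get ".f"
```

Under this model, small integral constants such as 8.f stay valid, which is consistent with only very large constants triggering the compile error.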
Example code:
import torch
from transformers import RobertaTokenizer, RobertaModel

torch.set_grad_enabled(False)

class RobertaTraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return (out['last_hidden_state'], out['pooler_output'])

model = RobertaModel.from_pretrained('roberta-base').cuda().eval()
wrap_model = RobertaTraceWrapper(model).cuda().eval()

input_ids = torch.tensor([[0,9064,6406,4,2]], dtype=torch.int64).cuda()
traced_model = torch.jit.trace(wrap_model, input_ids).eval().cuda()
# the following line fails
out = traced_model(input_ids)
Error message
Traceback (most recent call last):
  File "/home/ubuntu/workspace/models/roberta-large/error.py", line 21, in <module>
    out = traced_model(input_ids)
  File "/home/ubuntu/workspace/virtenv-pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/workspace/virtenv-pytorch-nightly/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: default_program(22): error: extra text after expected end of number
default_program(26): error: extra text after expected end of number
2 errors detected in the compilation of "default_program".
nvrtc compilation failed:
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}
extern "C" __global__
void fused_mul_div_add(float* tattention_scores_1, float* tv_, float* aten_add, float* aten_mul) {
{
  if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<5ll ? 1 : 0) {
    float v = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v * -3.402823466385289e+38.f;
  }if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<300ll ? 1 : 0) {
    float v_1 = __ldg(tattention_scores_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
    float v_2 = __ldg(tv_ + ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)) % 5ll);
    aten_add[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = v_1 / 8.f + v_2 * -3.402823466385289e+38.f;
  }}
}
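To locate the offending literals in a kernel dump like the one above, a small scanner can help. This is a minimal sketch, assuming the only malformed pattern is a decimal exponent followed directly by ".f"; the helper name find_bad_literals and the shortened sample lines are hypothetical.

```python
import re

# A mantissa, an exponent, then ".f": e.g. "3.402823466385289e+38.f".
# Valid literals such as "8.f" or "1e+38f" do not match.
BAD_LITERAL = re.compile(r"[0-9.]+[eE][+-]?[0-9]+\.f\b")


def find_bad_literals(source: str):
    """Return (line_number, literal) pairs for each malformed float literal."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for match in BAD_LITERAL.finditer(line):
            hits.append((lineno, match.group(0)))
    return hits


# Two statements shortened from the failing kernel, for illustration only.
sample = (
    "aten_mul[i] = v * -3.402823466385289e+38.f;\n"
    "aten_add[i] = v_1 / 8.f + v_2 * -3.402823466385289e+38.f;\n"
)

for lineno, literal in find_bad_literals(sample):
    print(f"line {lineno}: {literal}")
```

The line numbers it reports are relative to the text passed in, so they will match the default_program(N) numbers from nvrtc only if the full generated source, including any blank lines, is scanned.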
Versions
PyTorch version: 2.1.0.dev20230817+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.2.0-1009-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 7
BogoMIPS: 4999.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke avx512_vnni
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 4 MiB (4 instances)
L3 cache: 35.8 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e6216047b8
[pip3] torch==2.1.0.dev20230817+cu118
[pip3] torchaudio==2.1.0.dev20230817+cu118
[pip3] torchvision==0.16.0.dev20230817+cu118
[conda] Could not collect