[PT2.1] SIGSEGV seen with view + sgn operator inside torch.compile #111027
Verified reproducible with pytorch 2.1, but not on the nightly or my local build (hash …).
Hey @jay746, just confirming: did you file this as a regression? I also tried installing torch 2.0.0 and confirmed that I see the same segfault. So this looks like a hi-pri bug, although it doesn't seem to be a regression.
This looks like a bad interaction between efficientzerotensor and functionalization. Here's a minimal repro:
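(The original code block did not survive the page scrape. Below is a hedged sketch, not the exact repro: it combines the two pieces named above, an efficient zero tensor and functionalization, by passing a torch._efficientzerotensor through a view under torch.func.functionalize. On current releases, where the thread says this is fixed, it should run cleanly.)

```python
import torch
from torch.func import functionalize

# Hedged sketch (not the original repro, which was lost from this page):
# an "efficient" zero tensor -- zeros whose storage is never materialized --
# flowing through a view while functionalization is active.
def f(x):
    z = torch._efficientzerotensor(x.shape)  # lazily-materialized zeros
    return (x + z).view(-1)                  # view exercises functionalization

out = functionalize(f)(torch.ones(2, 2))     # adding zeros is the identity
```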
When I run under a debug build with …
You can see that …

The right thing to do here is probably not to copy the …
Hi @bdhirsh,
I'm tentatively removing hi-pri, since this is "fixed" on the tip of main (you cannot repro the segfault there). Since this is not a regression, I'm not sure it's a candidate for a fix in the 2.1 release. The main issue is that the segfault was incidentally "fixed" by updating AOTAutograd to use python functionalization, which was a pretty large change (and risky to put into a patch release). Alternatively, we could try to fix the segfault directly in C++ functionalization, but that would mean adding an entirely new set of changes to the 2.1 branch that are not in main, which I'm not sure we want either.
I wanted to leave this issue open and mark it for triage review, though, because the state of ZeroTensor with torch.compile is still broken. For example, this code
errors with:
What's going on? It's a bit clearer with this slightly larger repro:
this "runs", and prints the following FX graph (I've annotated the forward and backward bits myself):
The backward graph is particularly weird: there are calls to …

I'm marking this with triage review, because I'd like to understand what we actually want to have happen with …

(a) Should we fix the compile-time errors with zerotensor that I mentioned above? Probably yes.

(b) Should we try to avoid directly tracing the meta-tensor calls from ZeroTensor into the graph, which will force inductor to handle them? Probably yes.

(c) What do we actually want inductor to see in the graph when we're tracing zerotensor code from eager mode? Today, inductor has a fallback for … By the same logic, we probably want …

cc @ezyang, @zou3519, since "tracing ZeroTensor" came up a few times in conversation when I added python functionalization.
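(For context on the decomposition question in (c), here is a small demonstration, not code from the thread, of how a ZeroTensor behaves in eager mode: torch._efficientzerotensor creates a tensor flagged as all-zeros without materializing storage, and arithmetic treats it exactly like literal zeros, which is why decomposing it to ordinary zeros in the graph is plausible.)

```python
import torch

# Eager-mode ZeroTensor behavior (demonstration only, not from the thread).
z = torch._efficientzerotensor((3,))
print(z._is_zerotensor())        # flagged as a zero tensor

x = torch.arange(3.0)
print(torch.equal(x + z, x))     # adding zeros is the identity
```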
Decompose is 100% right. If you want to be fancy, we might want to sometimes return zero tensors as outputs when we know some outputs must be zero, but that should be done as an add-on.
@ezyang I don't see this issue after the PT2.2 upgrade. Did we fix it?
idk, maybe @bdhirsh knows! |
Hi @bdhirsh |
I believe this issue is already fixed in PT2.2. Closing. |
🐛 Describe the bug
When the view operator is used together with sgn inside torch.compile, a segmentation fault (SIGSEGV) occurs.
Please use the code below to reproduce the issue.
Error logs
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fffeb281440 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() ()
from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
(gdb) bt
#0 0x00007fffeb281440 in c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_() ()
from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#1 0x00007fffeb2a9999 in at::FunctionalTensorWrapper::replace_(at::Tensor const&) () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#2 0x00007fffeb2aa48c in at::FunctionalTensorWrapper::regenerate_from_base() () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#3 0x00007ffff6710e3b in torch::autograd::THPVariable__sync(_object*, _object*, _object*) ()
from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#4 0x00000000005f6939 in PyCFunction_Call ()
#5 0x00000000005f7506 in _PyObject_MakeTpCall ()
#6 0x0000000000570b8e in _PyEval_EvalFrameDefault ()
#7 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#8 0x00000000005f6ce6 in _PyFunction_Vectorcall ()
#9 0x000000000056b4ed in _PyEval_EvalFrameDefault ()
#10 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#11 0x00000000005697da in _PyEval_EvalCodeWithName ()
#12 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
#13 0x000000000056b4ed in _PyEval_EvalFrameDefault ()
#14 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#15 0x00000000005697da in _PyEval_EvalCodeWithName ()
#16 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
#17 0x0000000000570556 in _PyEval_EvalFrameDefault ()
#18 0x00007ffff673afcb in custom_eval_frame_shim () from /tmp/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#19 0x00000000005697da in _PyEval_EvalCodeWithName ()
#20 0x00000000005f6ec3 in _PyFunction_Vectorcall ()
Minified repro
No response
Versions
[pip3] numpy==1.24.4
[pip3] torch==2.1.0
[pip3] torchaudio==2.0.1
[pip3] torchdata==0.6.1
[pip3] torchmetrics==1.2.0
[pip3] torchtext==0.15.2a0
[pip3] torchvision==0.15.1a0
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305