fix torch.tensor for functionalization #76319
❌ 3 new failures as of commit 724cc37 (more details on the Dr. CI page). The 3 failures recognized by patterns do not appear to be due to upstream breakages.
Right now, using the `torch.Tensor` constructor inside of a functionalized function is broken (and there's a request to use it during tracing for mobile: https://fb.workplace.com/groups/1405155842844877/permalink/5805679106125840/).

`torch.Tensor` already has to be handled specially in several other contexts (autograd and functorch). Unfortunately we can't use the same approach to fix the issue for functionalization - I described the problem in more detail in the code comments. The previous solutions rely on `at::empty()` *not* returning a wrapper by setting some TLS, and on a `.to()` call later on to "promote" the result to a wrapper.

I'm wondering what people's thoughts are on landing this directly, versus trying to be more general and not specializing on functionalization. For example, we could make "wrapper tensor" a first-class concept (e.g. by adding an `unwrap()` function on `TensorImpl` that errors out unless you override it).
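For context, here is a minimal sketch of the existing "exclude a dispatch key, then promote via `.to()`" pattern the description refers to. The function name and the data-fill step are illustrative assumptions, not the actual `internal_new_from_data` code:

```cpp
#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

at::Tensor new_from_data_sketch(at::IntArrayRef sizes, at::ScalarType dtype) {
  at::Tensor tensor;
  {
    // While the key is excluded, at::empty() dispatches past the wrapper
    // kernel and returns a plain dense tensor that we can fill with raw data.
    c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::Functionalize);
    tensor = at::empty(sizes, at::TensorOptions().dtype(dtype));
    // ... copy the user-provided data into `tensor` here ...
  }
  // Outside the guard, .to() re-dispatches normally, which gives the
  // wrapper key a chance to "promote" the plain tensor into a wrapper.
  return tensor.to(tensor.options(), /*non_blocking=*/false, /*copy=*/true);
}
```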
torch/csrc/utils/tensor_new.cpp (outdated diff):

```diff
@@ -292,7 +294,29 @@ Tensor internal_new_from_data(
      "Expected a Storage of type ", scalar_type,
      " or an _UntypedStorage, but got ", storage_scalar_type);
  tensor = at::empty(sizes, at::initialTensorOptions().dtype(is_typed_storage ? storage_scalar_type : inferred_scalar_type).pinned_memory(pin_memory).device(storage.device()));
```
It occurs to me that this actually is very inefficient, right?! We allocate a `sizes`-sized storage, and then throw it out immediately after! If we had some sort of `new_as_strided` (which @albanD was mumbling about at #75994), we could do this all in one go; it would be faster, and you could directly implement functionalization there.
We do have the low level function in C++:

```yaml
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
```

The new version could be added!
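A quick sketch of how that existing overload is called from C++ (the function and the sizes/strides here are made up for illustration):

```cpp
#include <ATen/ATen.h>

// new_empty_strided creates the target tensor in a single dispatched call,
// instead of at::empty() followed by a set_() that discards the first storage.
at::Tensor make_like(const at::Tensor& self) {
  return self.new_empty_strided(/*size=*/{2, 3}, /*stride=*/{3, 1});
}
```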
torch/csrc/utils/tensor_new.cpp (outdated diff):

```cpp
// LazyNativeFunctions::empty is explicitly responsible for wrapping its output into a FunctionalTensorWrapper.
// - That leaves us with the problem described here though: at::empty() is going to return a wrapper.
//   One way to generalize this would be to make "wrapper tensor" a first class concept,
//   e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
```
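A minimal, self-contained sketch of that idea (illustrative types only - this is not PyTorch's actual `TensorImpl`):

```cpp
#include <stdexcept>

struct TensorImplSketch {
  virtual ~TensorImplSketch() = default;
  // Errors out unless a wrapper subclass overrides it, mirroring the
  // "guarded to error on normal TensorImpls" behavior described above.
  virtual TensorImplSketch& unwrap() {
    throw std::runtime_error("unwrap() called on a non-wrapper TensorImpl");
  }
};

struct FunctionalWrapperImplSketch : TensorImplSketch {
  explicit FunctionalWrapperImplSketch(TensorImplSketch& value) : value_(value) {}
  TensorImplSketch& unwrap() override { return value_; }
 private:
  TensorImplSketch& value_;  // the wrapped ("unwrapped") impl
};
```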
This is a long comment saying why the thing doesn't work, but I think what I'd actually want to read about is how, morally, it should work.
torch/csrc/utils/tensor_new.cpp (outdated diff):

```cpp
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
} else {
  data_tensor = tensor;
```
How come `data_tensor` gets set in one branch but not the other?
torch/csrc/utils/tensor_new.cpp (outdated diff):

```cpp
if (c10::multiply_integers(tensor.sizes()) != 0) {

// See Note [Functionalization <> torch.Tensor Constructor]
at::Tensor data_tensor;
```
...shadowing the `data_tensor` above?
torch/csrc/utils/tensor_new.cpp (outdated diff):

```cpp
// One way to generalize this would be to make "wrapper tensor" a first class concept,
// e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
```
This is the "storage was passed to the tensor constructor" code path - does anyone actually need this?

If we want to do this soundly, we need to identify whether the passed-in storage is a functionalization storage or a regular storage, because it seems to me like unwrapping the functional tensor when it's a functionalization storage would be the wrong thing to do.
torch/csrc/utils/tensor_new.cpp (outdated diff):

```cpp
  data_tensor = tensor;
}

if (c10::multiply_integers(data_tensor.sizes()) != 0) {
```
It should be sound to compute this on `tensor.sizes()` too, right?
The code at stake here is not big, so I don't think it's too risky to land this as is (especially as an unblocker). However, I think I have an alternative suggestion for how to do this properly. The general concept of what is going on here is that …
Hmm ok, I originally wasn't sure if this was the end-state behavior that we wanted (partially because of the weird interaction that it would cause with LTC/XLA), but the description makes sense to me. I'll switch it over to do it this way in the PR. We can also get things to work out on the LTC/XLA side by having their …
I didn't think carefully about the LTC/XLA side of things, but it seems similar? You need to make an honest-to-goodness tensor with the data you want, and then lower it into the XLA graph as a constant. That's what `lift` should be doing, I think?
Right now, using the `torch.Tensor` constructor inside of a functionalized function is broken (and there's a request to use it during tracing for mobile: https://fb.workplace.com/groups/1405155842844877/permalink/5805679106125840/).

Update: I took the approach described in the comments (letting `at::empty()` run directly with the data, and using `.to()` to "lift" it into a wrapper), which required a minor change described below.
@zhxchen17 I've been working on a big stack of functionalization changes locally, including the feedback from this PR. Actually just pushed out the changes a minute ago.

@ezyang I saw that you just gave the PR an approve - feel free to take a look at the new changes; I added a new `lift` op. I'm going to need to add a companion PR for functorch before I can land this though - the existing wrapper subclasses in functorch need to know about "lifting" …
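For illustration, a hedged sketch of what a functionalization kernel for such a `lift` op could look like. The kernel name and signature are assumptions, not the PR's actual code; `isFunctionalTensor` and `to_functional_tensor` are the existing helpers from `FunctionalTensorWrapper.h`:

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>

// Sketch: "lifting" takes a plain dense tensor created outside of
// functionalization and wraps it into a FunctionalTensorWrapper.
at::Tensor lift_functionalize_sketch(const at::Tensor& self) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(self));
  return at::functionalization::impl::to_functional_tensor(self);
}
```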
```cpp
c10::impl::ExcludeDispatchKeyGuard functorch_guard(c10::DispatchKey::FuncTorchDynamicLayerBackMode);
// We disable DeferredInit handler for similar reasons as functorch.
c10::impl::ExcludeDispatchKeyGuard deferred_init_guard(c10::DispatchKey::DeferredInit);
```
It feels like there should be an easy way to just "exclude everything", but we can probably work that out later.
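One step in that direction (a sketch, not an endorsement of a final design): `ExcludeDispatchKeyGuard` also accepts a `DispatchKeySet`, so at minimum several keys can be masked with a single guard:

```cpp
#include <c10/core/impl/LocalDispatchKeySet.h>

void run_with_keys_excluded() {
  // Exclude multiple dispatch keys at once with one RAII guard.
  c10::impl::ExcludeDispatchKeyGuard guard(
      c10::DispatchKeySet{c10::DispatchKey::FuncTorchDynamicLayerBackMode,
                          c10::DispatchKey::DeferredInit});
  // ... code that should not see either handler runs here ...
}
```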
I reviewed the new code with `lift` and I like it a lot! Thanks!
I'm actually pretty sure that this doesn't require functorch changes: I was trying to change functorch locally, and my understanding is that now the code in the …

Getting functorch to support …

It also looks like functorch CI is failing on master, but I was able to run …
@pytorchbot merge this please
Hey @bdhirsh. …
@pytorchbot revert this as it breaks ONNX tests (which also show up on the PR): https://hud.pytorch.org/minihud?name_filter=pull%20/%20linux-xenial-py3.7-clang7-onnx%20/%20test%20(default,%202,%202,%20linux.2xlarge)
This reverts commit 9edee09. Reverted #76319 on behalf of https://github.com/janeyx99