fix torch.tensor for functionalization by bdhirsh · Pull Request #76319 · pytorch/pytorch · GitHub
fix torch.tensor for functionalization #76319


Closed
wants to merge 9 commits

Conversation

bdhirsh (Contributor) commented Apr 25, 2022

facebook-github-bot (Contributor) commented Apr 25, 2022

❌ 3 New Failures

As of commit 724cc37 (more details on the Dr. CI page):

  • 3/3 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge) (1/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T16:48:52.4973713Z /var/lib/jenkins/w... virtual member functions can be marked 'override'
2022-05-11T16:48:52.4970459Z In file included from /var/lib/jenkins/workspace/xla/torch_xla/csrc/aten_xla_bridge.cpp:11:
2022-05-11T16:48:52.4970932Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:33:40: error: only virtual member functions can be marked 'override'
2022-05-11T16:48:52.4971260Z   at::IntArrayRef sizes_custom() const override;
2022-05-11T16:48:52.4971484Z                                        ^~~~~~~~
2022-05-11T16:48:52.4971876Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:34:42: error: only virtual member functions can be marked 'override'
2022-05-11T16:48:52.4972198Z   at::IntArrayRef strides_custom() const override;
2022-05-11T16:48:52.4972405Z                                          ^~~~~~~~
2022-05-11T16:48:52.4972796Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:36:30: error: only virtual member functions can be marked 'override'
2022-05-11T16:48:52.4973140Z   int64_t dim_custom() const override;
2022-05-11T16:48:52.4973325Z                              ^~~~~~~~
2022-05-11T16:48:52.4973713Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_impl.h:38:32: error: only virtual member functions can be marked 'override'
2022-05-11T16:48:52.4974008Z   int64_t numel_custom() const override;
2022-05-11T16:48:52.4974207Z                                ^~~~~~~~
2022-05-11T16:48:52.4974373Z 4 errors generated.
2022-05-11T16:49:20.1887972Z [10/179] clang++-8 -MMD -MF /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/RegisterXLA.o.d -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -I/var/lib/jenkins/workspace/xla -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-bin -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/protobuf_archive/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_protobuf/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/eigen_archive -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_absl -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/opt/conda/include/python3.7m -c -c /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp -o /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/RegisterXLA.o -std=c++14 -Wno-sign-compare -Wno-deprecated-declarations -Wno-return-type -Wno-macro-redefined -Wno-return-std-move -DNDEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_clang"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1002"' -DTORCH_EXTENSION_NAME=_XLAC -D_GLIBCXX_USE_CXX11_ABI=1
2022-05-11T16:49:20.1890263Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp:63:6: warning: unused function 'resize_out' [-Wunused-function]
2022-05-11T16:49:20.1890672Z void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
2022-05-11T16:49:20.1890952Z      ^
2022-05-11T16:49:20.1891346Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/RegisterXLA.cpp:82:6: warning: unused function 'check_inplace' [-Wunused-function]
2022-05-11T16:49:20.1891716Z void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
2022-05-11T16:49:20.1891956Z      ^

See GitHub Actions build pull / linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge) (2/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T15:50:55.2569661Z ##[error]Process completed with exit code 1.
2022-05-11T15:50:54.8774762Z SKIPPED [1] test/onnx/test_pytorch_common.py:68: Unsupported opset_version: 10 > 8
2022-05-11T15:50:54.8775111Z SKIPPED [1] test/onnx/test_pytorch_common.py:53: Failing, see https://github.com/pytorch/pytorch/issues/66528
2022-05-11T15:50:54.8775545Z SKIPPED [1] test/onnx/test_pytorch_common.py:53: Unstable loading pretrained quantized mobilenet v3: https://github.com/pytorch/vision/issues/5303
2022-05-11T15:50:54.8775937Z SKIPPED [3] test/onnx/test_pytorch_common.py:68: Unsupported opset_version: 10 < 11
2022-05-11T15:50:54.8776339Z SKIPPED [1] test/onnx/test_pytorch_common.py:53: Bug in ORT, skip test until rel-1.11.
2022-05-11T15:50:54.8776653Z SKIPPED [1] test/onnx/test_pytorch_common.py:112: Unsupported opset_version: 10 < 11
2022-05-11T15:50:54.8777011Z FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset10::test_arange_with_floats_out
2022-05-11T15:50:54.8777389Z FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset10::test_dynamic_arange_out
2022-05-11T15:50:54.8777749Z FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset10::test_dynamic_arange_start_out
2022-05-11T15:50:54.8778218Z ==== 3 failed, 611 passed, 190 skipped, 1461 warnings in 209.83s (0:03:29) =====
2022-05-11T15:50:55.2569661Z ##[error]Process completed with exit code 1.
2022-05-11T15:50:55.2612516Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master
2022-05-11T15:50:55.2612779Z with:
2022-05-11T15:50:55.2613200Z   github-token: ***
2022-05-11T15:50:55.2613359Z env:
2022-05-11T15:50:55.2613512Z   IN_CI: 1
2022-05-11T15:50:55.2613672Z   IS_GHA: 1
2022-05-11T15:50:55.2613839Z   GIT_DEFAULT_BRANCH: master
2022-05-11T15:50:55.2614023Z ##[endgroup]
2022-05-11T15:50:55.2643944Z ##[group]Run nick-fields/retry@71062288b76e2b6214ebde0e673ce0de1755740a
2022-05-11T15:50:55.2644184Z with:

See GitHub Actions build pull / linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge) (3/3)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-05-11T16:32:05.8162069Z RuntimeError: Unab...r) (compile in debug mode for type information)
2022-05-11T16:32:05.8158223Z ――――――――――――――― TestUtilityFuns_opset15.test_constant_fold_shape ―――――――――――――――
2022-05-11T16:32:05.8158595Z Traceback (most recent call last):
2022-05-11T16:32:05.8158900Z   File "/var/lib/jenkins/workspace/test/onnx/test_utility_funs.py", line 638, in test_constant_fold_shape
2022-05-11T16:32:05.8159298Z     ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
2022-05-11T16:32:05.8159644Z   File "/var/lib/jenkins/workspace/test/onnx/test_utility_funs.py", line 71, in _model_to_graph
2022-05-11T16:32:05.8159932Z     dynamic_axes=dynamic_axes,
2022-05-11T16:32:05.8160385Z   File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 732, in _model_to_graph
2022-05-11T16:32:05.8160659Z     module=module,
2022-05-11T16:32:05.8161109Z   File "/opt/conda/lib/python3.7/site-packages/torch/onnx/utils.py", line 301, in _optimize_graph
2022-05-11T16:32:05.8161392Z     torch._C._jit_pass_onnx_lint(graph)
2022-05-11T16:32:05.8162069Z RuntimeError: Unable to cast from non-held to held instance (T& to Holder<T>) (compile in debug mode for type information)
2022-05-11T16:32:05.8162461Z 
2022-05-11T16:32:05.8162466Z 
2022-05-11T16:32:05.8163348Z 
2022-05-11T16:32:05.8459085Z  test/onnx/test_utility_funs.py::TestUtilityFuns_opset15.test_constant_fold_shape ⨯ 45%
2022-05-11T16:32:05.8459634Z 
2022-05-11T16:32:05.8459909Z ――――――――――――――― TestUtilityFuns_opset15.test_constant_fold_slice ―――――――――――――――
2022-05-11T16:32:05.8460199Z Traceback (most recent call last):
2022-05-11T16:32:05.8460503Z   File "/var/lib/jenkins/workspace/test/onnx/test_utility_funs.py", line 225, in test_constant_fold_slice
2022-05-11T16:32:05.8460843Z     NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
2022-05-11T16:32:05.8461177Z   File "/var/lib/jenkins/workspace/test/onnx/test_utility_funs.py", line 71, in _model_to_graph

This comment was automatically generated by Dr. CI.

Right now, using the `torch.Tensor` constructor inside a functionalized function is broken (and there's a request to use it during tracing for mobile: https://fb.workplace.com/groups/1405155842844877/permalink/5805679106125840/).

`torch.Tensor` already has to be handled specially in several other contexts (autograd and functorch). Unfortunately, we can't use the same approach to fix the issue for functionalization - I described the problem in more detail in the code comments. The existing solutions rely on `at::empty()` *not* returning a wrapper (by setting some TLS), and on a later `.to()` call to "promote" the result to a wrapper.

I'm wondering what people's thoughts are on landing this directly, versus trying to be more general and not specializing on functionalization. For example, we could make "wrapper tensor" a first-class concept, e.g. by adding an `unwrap()` function on `TensorImpl` that errors out unless a subclass overrides it (sketched below).




[ghstack-poisoned]
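For illustration, here is a minimal standalone sketch of that last idea - a wrapper type that opts into unwrapping - using hypothetical class names (`BaseTensorImpl`, `WrapperTensorImpl`) rather than the real `c10::TensorImpl` API:

```cpp
#include <memory>
#include <stdexcept>

// Plain tensors are not wrappers, so unwrap() errors unless a subclass overrides it.
struct BaseTensorImpl {
  virtual ~BaseTensorImpl() = default;
  virtual std::shared_ptr<BaseTensorImpl> unwrap() const {
    throw std::runtime_error("unwrap() called on a non-wrapper TensorImpl");
  }
};

// A wrapper subclass opts in by overriding unwrap() to expose the wrapped tensor.
struct WrapperTensorImpl : BaseTensorImpl {
  explicit WrapperTensorImpl(std::shared_ptr<BaseTensorImpl> inner)
      : inner_(std::move(inner)) {}
  std::shared_ptr<BaseTensorImpl> unwrap() const override { return inner_; }

 private:
  std::shared_ptr<BaseTensorImpl> inner_;
};
```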
@@ -292,7 +294,29 @@ Tensor internal_new_from_data(
"Expected a Storage of type ", scalar_type,
" or an _UntypedStorage, but got ", storage_scalar_type);
tensor = at::empty(sizes, at::initialTensorOptions().dtype(is_typed_storage ? storage_scalar_type : inferred_scalar_type).pinned_memory(pin_memory).device(storage.device()));
Contributor:

It occurs to me that this actually is very inefficient, right?! We allocate a sizes storage, and then throw it out immediately after! If we had some sort of new_as_strided (which @albanD was mumbling about at #75994 ) we could do this all in one go, it would be faster, and you could directly implement functionalization there.

Collaborator:

We do have the low level function in c++:

- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

The new version could be added!
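For reference, a hedged sketch of calling that existing op from C++. The sizes and strides below are arbitrary; the point is that the allocation happens with the requested strides in one step, instead of allocating via `at::empty()` and re-striding afterwards:

```cpp
#include <ATen/ATen.h>

at::Tensor strided_buffer_example() {
  // dtype/device are inherited from `self` unless overridden.
  at::Tensor self = at::empty({0}, at::TensorOptions().dtype(at::kFloat));
  // Allocate uninitialized memory with explicit sizes and strides in one call.
  return self.new_empty_strided({2, 3}, {3, 1});
}
```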

// LazyNativeFunctions::empty is explicitly responsible for wrapping its output into a FunctionalTensorWrapper.
// - That leaves us with the problem described here though: at::empty() is going to return a wrapper.
// One way to generalize this would be to make "wrapper tensor" a first class concept,
// e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
Contributor:

This is a long comment saying why the thing doesn't work, but I think what I'd actually want to read about is how, morally, it should work.

if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
} else {
  data_tensor = tensor;
Contributor:

how come data_tensor gets set in one branch but not the other?

if (c10::multiply_integers(tensor.sizes()) != 0) {

// See Note [Functionalization <> torch.Tensor Constructor]
at::Tensor data_tensor;
Contributor:

...shadowing the data_tensor above?

// One way to generalize this would be to make "wrapper tensor" a first class concept,
// e.g. by giving TensorImpl a virtual unwrap() function (guarded to error on normal TensorImpls).
if (at::functionalization::impl::isFunctionalTensor(tensor)) {
  at::functionalization::impl::from_functional_tensor(tensor).set_(storage);
Contributor:

This is the "storage was passed to tensor constructor" code path, doesn't anyone actually need this?

If we want to do this soundly we need to identify if the passed in storage is a functionalization storage or a regular storage, because it seems to me like unwrapping the functional tensor if its a functionalization storage would be the wrong thing to do.

data_tensor = tensor;
}

if (c10::multiply_integers(data_tensor.sizes()) != 0) {
Contributor:

It should be sound to compute this on tensor.sizes() too, right?

ezyang (Contributor) commented Apr 25, 2022

The code at stake here is not big, so I don't think it's too risky to land this as is (esp as an unblocker). However, I think I have an alternative suggestion for how to do this properly. The general concept of what is going on here is that torch.tensor is all about creating a "raw" tensor (directly writing in data) without any transforms in effect, and then "lifting/wrapping" it into a constant in the relevant transform. So what I think should happen is that the "data" manipulations should happen under some RAII guard that disables all transforms, and then we need a way to lift/wrap based on all of the currently active transforms that require mandatory wrapping immediately.
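A hedged sketch of that two-phase shape (the specific guard and key choices here are assumptions for illustration; `at::lift` is the op the PR later adds):

```cpp
#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

at::Tensor tensor_ctor_sketch(at::IntArrayRef sizes) {
  at::Tensor raw;
  {
    // Phase 1: build the "raw" data tensor with the transform disabled.
    c10::impl::ExcludeDispatchKeyGuard guard(c10::DispatchKey::Functionalize);
    raw = at::empty(sizes, at::TensorOptions().dtype(at::kFloat));
    // ...copy the user-provided data into `raw` here...
  }
  // Phase 2: lift/wrap the constant into whatever transforms are currently active.
  return at::lift(raw);
}
```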

bdhirsh (Contributor, Author) commented Apr 26, 2022

So what I think should happen is that the "data" manipulations should happen under some RAII guard that disables all transforms, and then we need a way to lift/wrap based on all of the currently active transforms that require mandatory wrapping immediately.

Hmm ok, I originally wasn't sure if this was the end-state behavior that we wanted (partially because of the weird interaction that it would cause with LTC/XLA), but the description makes sense to me. I'll switch it over to do it this way in the PR. We can also get things to work out on the LTC/XLA side by having their empty kernel explicitly read out the TLS to know whether or not to do the wrapping.
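A hedged sketch of what "read out the TLS" could look like; the helper name and the `wrap` callback are hypothetical, and only the TLS query is the existing c10 API:

```cpp
#include <ATen/ATen.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

// A backend `empty` kernel could call something like this after allocating its
// plain tensor, to decide whether wrapping is wanted for this particular call.
at::Tensor maybe_wrap_output(const at::Tensor& plain,
                             at::Tensor (*wrap)(const at::Tensor&)) {
  const bool functionalize_excluded =
      c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::Functionalize);
  // If the caller excluded Functionalize (e.g. while building a "raw" tensor),
  // skip the wrapping and hand back the plain tensor.
  return functionalize_excluded ? plain : wrap(plain);
}
```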

ezyang (Contributor) commented Apr 26, 2022

I didn't think carefully about the LTC/XLA side of things. But it seems similar? You need to make an honest to goodness tensor with the data you want, and then lower it into the XLA graph as a constant. That's what lift should be doing, I think?

update: I took the approach described in the comments (letting `at::empty()` run directly with the data, and using `.to()` to "lift" it into a wrapper), which required a minor change described below.




[ghstack-poisoned]
bdhirsh (Contributor, Author) commented May 10, 2022

@zhxchen17 I've been working on a big stack of functionalization changes locally, including the feedback from this PR. I actually just pushed out the changes a minute ago.

@ezyang I saw that you just gave the PR an approve - feel free to take a look at the new changes. I added a new `lift` aten op, which currently only has a real implementation for functionalization (plus a default "no-op" implementation).

I'm going to need to add a companion PR for functorch before I can land this, though - the existing wrapper subclasses in functorch need to know about "lifting".
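For context, a hedged sketch of what a default "no-op" lift kernel could look like (the actual kernel name and registration in the PR may differ); the functionalization kernel, by contrast, would return a `FunctionalTensorWrapper` around its input:

```cpp
#include <ATen/ATen.h>

// For dispatch keys with no wrapping semantics, lifting is the identity:
// the constant is already a plain tensor, so it is returned unchanged.
at::Tensor lift_noop_sketch(const at::Tensor& self) {
  return self;
}
```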

c10::impl::ExcludeDispatchKeyGuard functorch_guard(c10::DispatchKey::FuncTorchDynamicLayerBackMode);
// We disable DeferredInit handler for similar reasons as functorch.
c10::impl::ExcludeDispatchKeyGuard deferred_init_guard(c10::DispatchKey::DeferredInit);
Contributor:

It feels like there should be an easy way to just "exclude everything", but we can probably work that out later.
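For what it's worth, the existing guard already accepts a whole `DispatchKeySet`, so several keys can be excluded in one shot; which keys would count as "everything" for this code path is exactly the open question, so the set below (the two keys from the snippet above plus Functionalize) is just illustrative:

```cpp
#include <c10/core/impl/LocalDispatchKeySet.h>

void run_with_handlers_disabled() {
  // Exclude several dispatch keys at once for the duration of this scope.
  c10::impl::ExcludeDispatchKeyGuard guard(
      c10::DispatchKeySet{c10::DispatchKey::Functionalize,
                          c10::DispatchKey::FuncTorchDynamicLayerBackMode,
                          c10::DispatchKey::DeferredInit});
  // ...build the raw data tensor here with these handlers disabled...
}
```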

ezyang (Contributor) left a comment

I reviewed the new code with lift and I like it a lot! Thanks!

bdhirsh requested a review from soulitzer as a code owner May 11, 2022 14:32
bdhirsh (Contributor, Author) commented May 11, 2022

I'm actually pretty sure that this doesn't require functorch changes: I was trying to change functorch locally, and my understanding is that now the code in the FuncTorchDynamicLayerFrontMode fallback kernel is responsible for the wrapping. Since that key is still enabled during the .to() call, I think the wrapping still happens correctly without functorch adding support for lift().

Getting functorch to support lift() seems like the right thing to do, but I'd rather land this PR now to unblock mobile.

It also looks like functorch CI is failing on master, but I was able to run python test/test_eager_transforms.py TestGradTransformCPU locally without any failures (which includes tests that run the torch.tensor() constructor under the grad transform).

bdhirsh (Contributor, Author) commented May 11, 2022

@pytorchbot merge this please

github-actions (Contributor):

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

janeyx99 (Contributor):

@pytorchbot revert this as it breaks ONNX tests (which also show up on the PR) https://hud.pytorch.org/minihud?name_filter=pull%20/%20linux-xenial-py3.7-clang7-onnx%20/%20test%20(default,%202,%202,%20linux.2xlarge)

pytorchmergebot added a commit that referenced this pull request May 11, 2022
bdhirsh added a commit that referenced this pull request May 12, 2022
…alization"

Re-land of #76319 - I needed to tell `onnx` what the new `at::lift` op is, since it technically gets traced by `torch.jit.trace()` into torchscript. I think it should be a no-op though.


This reverts commit 85bd65a.

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request May 12, 2022

bdhirsh added a commit that referenced this pull request May 12, 2022
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/218/head branch May 15, 2022 14:17