Description
🐛 Describe the bug
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.sqrt(x)
        dy_dx = torch.autograd.grad(y, x, torch.ones_like(y))[0]
        return y, dy_dx

model = Model()
x = torch.rand([100, 100])
x = x.requires_grad_(True)
inputs = (x,)
y_ref, dy_dx_ref = model.forward(*inputs)

ep = torch.export.export(model, args=inputs, strict=False)
# ep = ep.run_decompositions(torch.export.default_decompositions())  # Can't decompose because we hit 'RuntimeError: CUDA error: an illegal memory access was encountered'
print(ep)

y_test, dy_dx_test = ep.module()(*inputs)
assert torch.allclose(y_ref, y_test), 'Mismatch between Python module and call to its underlying export program for y!'  # Passes
print('Reference: ' + str(dy_dx_ref))
print('Export: ' + str(dy_dx_test))
# assert torch.allclose(dy_dx_ref, dy_dx_test), 'Mismatch between Python module and call to its underlying export program for dy_dx!'  # Fails because dy_dx_test is a FakeTensor

@torch.library.custom_op("mylib::sqrt", mutates_args=())
def sqrt(x: torch.Tensor) -> torch.Tensor:
    return torch.sqrt(x)

@sqrt.register_fake
def _(x):
    return x.new_empty(x.shape, requires_grad=x.requires_grad)

@torch.library.custom_op("mylib::sqrt_backward", mutates_args=())
def sqrt_backward(x: torch.Tensor, y: torch.Tensor, pullback: torch.Tensor) -> torch.Tensor:
    return (1 / (2 * y)) * pullback

@sqrt_backward.register_fake
def _(x, y, pullback):
    return y.new_empty(y.shape)

def backward(ctx, grad_output):
    return sqrt_backward(ctx.x, ctx.y, grad_output)

def setup_context(ctx, inputs, output):
    ctx.x = inputs[0]
    ctx.y = output

sqrt.register_autograd(backward, setup_context=setup_context)

class FixedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.ops.mylib.sqrt(x)
        dy_dx = torch.autograd.grad(y, x, torch.ones_like(y))[0]
        return y, dy_dx

fixed_model = FixedModel()
y_fixed_ref, dy_dx_fixed_ref = fixed_model.forward(*inputs)
assert torch.allclose(y_ref, y_fixed_ref), "Custom op didn't compute sqrt(x) correctly!"
assert torch.allclose(dy_dx_ref, dy_dx_fixed_ref), "Custom op didn't compute dy_dx correctly!"

fixed_ep = torch.export.export(fixed_model, args=inputs, strict=False)
print(fixed_ep)

y_fixed_test, dy_dx_fixed_test = fixed_ep.module()(*inputs)
assert torch.allclose(y_ref, y_fixed_test), 'Mismatch between Python module call and the fixed export program for y!'  # Still passes
assert torch.allclose(dy_dx_ref, dy_dx_fixed_test), 'Mismatch between Python module call and the fixed export program for dy_dx!'  # Now passes
Note that execution of the initial export program returns a FakeTensor for dy_dx, and also note that we can't run decompositions on the initial export program, presumably because the FakeTensor leads to an illegal memory access (I tried this using CUDA; not sure how this would work on the CPU). If I thinly wrap torch.sqrt as a custom op (shown above in FixedModel), I'm able to export the program and have it work correctly. Here's what the output looks like for the above program:
I0603 12:25:44.074000 1531711 torch/fx/experimental/symbolic_shapes.py:3334] create_env
I0603 12:25:44.222000 1531711 torch/fx/experimental/symbolic_shapes.py:4734] produce_guards
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "f32[100, 100]", x: "f32[100, 100]"):
            # File: /u/pgh/mclaugha/./sandbox/sqrt_export_backward_example/test.py:8 in forward, code: y = torch.sqrt(x)
            sqrt: "f32[100, 100]" = torch.ops.aten.sqrt.default(x); x = None

            # File: /u/pgh/mclaugha/./sandbox/sqrt_export_backward_example/test.py:9 in forward, code: dy_dx = torch.autograd.grad(y, x, torch.ones_like(y))[0]
            ones_like: "f32[100, 100]" = torch.ops.aten.ones_like.default(sqrt, pin_memory = False)
            mul: "f32[100, 100]" = torch.ops.aten.mul.Scalar(c_lifted_tensor_0, 2); c_lifted_tensor_0 = None
            div: "f32[100, 100]" = torch.ops.aten.div.Tensor(ones_like, mul); ones_like = mul = None
            return (sqrt, div)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sqrt'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='div'), target=None)])
Range constraints: {}
/u/pgh/mclaugha/.venv/lib/python3.10/site-packages/torch/export/_unlift.py:81: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
getattr_node = gm.graph.get_attr(lifted_node)
/u/pgh/mclaugha/.venv/lib/python3.10/site-packages/torch/fx/graph.py:1772: UserWarning: Node lifted_tensor_0 target lifted_tensor_0 lifted_tensor_0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
/u/pgh/mclaugha/.venv/lib/python3.10/site-packages/torch/export/_unlift.py:330: UserWarning: A model attribute `lifted_tensor_0` requires gradient. but it's not properly registered as a parameter. torch.export will detach it and treat it as a constant tensor but please register it as parameter instead.
warnings.warn(
Reference: tensor([[1.0254, 0.8146, 0.5030, ..., 0.7383, 0.9542, 0.5347],
[0.5092, 1.0499, 0.5585, ..., 0.5496, 1.7775, 1.3391],
[0.5341, 0.8281, 0.5627, ..., 2.9407, 0.6332, 0.8140],
...,
[0.5239, 1.9028, 0.6196, ..., 0.6412, 0.5666, 2.5104],
[0.7582, 0.8188, 0.7271, ..., 0.7210, 1.4510, 1.1716],
[0.5200, 0.5867, 0.7121, ..., 0.5817, 0.5796, 0.8216]])
Export: FakeTensor(..., size=(100, 100))
I0603 12:25:44.272000 1531711 torch/fx/experimental/symbolic_shapes.py:3334] create_env
I0603 12:25:44.283000 1531711 torch/fx/experimental/symbolic_shapes.py:4734] produce_guards
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[100, 100]"):
            # File: /u/pgh/mclaugha/./sandbox/sqrt_export_backward_example/test.py:64 in forward, code: y = torch.ops.mylib.sqrt(x)
            sqrt: "f32[100, 100]" = torch.ops.mylib.sqrt.default(x)

            # File: /u/pgh/mclaugha/./sandbox/sqrt_export_backward_example/test.py:65 in forward, code: dy_dx = torch.autograd.grad(y, x, torch.ones_like(y))[0]
            ones_like: "f32[100, 100]" = torch.ops.aten.ones_like.default(sqrt, pin_memory = False)
            sqrt_backward: "f32[100, 100]" = torch.ops.mylib.sqrt_backward.default(x, sqrt, ones_like); x = ones_like = None
            return (sqrt, sqrt_backward)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sqrt'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sqrt_backward'), target=None)])
Range constraints: {}
We can see that the initial export program has an additional constant lifted_tensor_0 that appears to be equal to the value of torch.sqrt(x). This error is related to #146719, where I added a comment with some speculation about this issue. I decided to open a new issue here because the old one was quite stale.
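As a sanity check on that claim, the lifted constant can be compared against torch.sqrt(x) directly. This is only a sketch: it assumes the constant is exposed in ep.constants under the key 'lifted_tensor_0' (the target name printed in the graph signature above), which may differ in other builds.

# Hedged sketch: inspect the lifted constant from the initial (buggy) export program.
# Assumes ep and x from the repro above, and that ep.constants stores the lifted
# tensor under 'lifted_tensor_0'; adjust the key if your export names it differently.
lifted = ep.constants.get("lifted_tensor_0")
if lifted is not None:
    print(torch.allclose(lifted.detach(), torch.sqrt(x).detach()))  # expected True if it really is sqrt(x)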
Error logs
If I try to run decompositions on the initial ExportedProgram above, I get RuntimeError: CUDA error: an illegal memory access was encountered. If I skip decompositions and try to execute the initial ExportedProgram above, it returns a FakeTensor for the gradient, which should never happen.
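For completeness, here is a quick check of the fake output; a sketch assuming the FakeTensor class is importable from torch._subclasses.fake_tensor, as in 2.7.0.

from torch._subclasses.fake_tensor import FakeTensor

# dy_dx_test comes from running the initial (buggy) export program in the repro above.
# A real execution should never hand a FakeTensor back to the user.
print(isinstance(dy_dx_test, FakeTensor))  # prints True with the initial export program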
Versions
2.7.0
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4