functionalization: add native fill() op by bdhirsh · Pull Request #76084 · pytorch/pytorch · GitHub

functionalization: add native fill() op #76084


Closed
wants to merge 9 commits
8 changes: 8 additions & 0 deletions aten/src/ATen/native/Fill.cpp
@@ -61,6 +61,14 @@ Tensor& fill_meta_(Tensor& self, const Tensor& value) {
  return self;
}

Tensor fill(const Tensor& self, const Scalar& value) {
  return at::empty_like(self).fill_(value);
}

Tensor fill(const Tensor& self, const Tensor& value) {
  return at::empty_like(self).fill_(value);
}

DEFINE_DISPATCH(fill_stub);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill_diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
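The new out-of-place kernels above are a straightforward allocate-then-mutate decomposition. A minimal Python sketch of the equivalent semantics (the standalone `fill` helper name here is illustrative, not part of the PR):

```python
import torch

def fill(self, value):
    # mirrors the new CompositeExplicitAutograd kernel:
    # at::empty_like(self).fill_(value)
    return torch.empty_like(self).fill_(value)

x = torch.ones(2, 3)
y = fill(x, 5.0)
assert torch.equal(x, torch.ones(2, 3))         # input is left untouched
assert torch.equal(y, torch.full((2, 3), 5.0))  # result holds the fill value
```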
10 changes: 10 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2116,6 +2116,16 @@
- func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a)
  variants: method

- func: fill.Scalar(Tensor self, Scalar value) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: fill

- func: fill.Tensor(Tensor self, Tensor value) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: fill

- func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
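Both new overloads are registered as function variants under the CompositeExplicitAutograd key, so the dispatcher routes them to the kernel above. Even with the Python bindings skipped (see gen_python_functions.py below), the ops should still be reachable through torch.ops on a build that includes this change; a quick sanity sketch:

```python
import torch

t = torch.ones(2, 2)
out = torch.ops.aten.fill.Scalar(t, 3.0)  # out-of-place: returns a new tensor
assert torch.equal(t, torch.ones(2, 2))   # t is not mutated
assert torch.equal(out, torch.full((2, 2), 3.0))
```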
6 changes: 3 additions & 3 deletions test/jit/test_remove_mutation.py
@@ -146,16 +146,16 @@ def test_successful():

 # full_like is not implemented for a tensor fill value

-def test_unsuccessful():
+def test_successful():
     x = torch.tensor([2, 2])
     y = torch.tensor([2, 4])
     x.fill_(y)
     return x + x

-fn = torch.jit.script(test_unsuccessful)
+fn = torch.jit.script(test_successful)
 graph = fn.graph
 self.run_pass('remove_mutation', graph)
-FileCheck().check('aten::fill_').run(graph)
+FileCheck().check_not('aten::fill_').run(graph)

 def normal():
     return torch.rand(2, 1, 3, 4).normal_()
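The test flips from unsuccessful to successful because remove_mutation previously had no functional op to rewrite fill_ with a tensor fill value into; with fill.Tensor available, it does. A sketch of driving the pass directly, assuming the internal torch._C._jit_pass_remove_mutation entry point that the run_pass test helper wraps:

```python
import torch
from torch.testing import FileCheck

@torch.jit.script
def g():
    x = torch.tensor([2, 2])
    y = torch.tensor(4)  # fill_ requires a 0-dim tensor value at runtime
    x.fill_(y)
    return x + x

torch._C._jit_pass_remove_mutation(g.graph)
# the in-place fill_ has been rewritten to its functional form
FileCheck().check_not("aten::fill_").run(g.graph)
```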
16 changes: 16 additions & 0 deletions test/test_functionalization.py
@@ -444,12 +444,28 @@ def f(x):

    self.assert_functionalization(f, torch.ones(2, 2))
    logs = self.get_logs(f, torch.ones(2, 2))
    # zero() should decompose into zeros_like(), which will show up in the trace
    self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, $0)
$2 = torch._ops.aten.diagonal_copy.default($1)
$3 = torch._ops.aten.zeros_like.default($2)""")

def test_fill_(self):
    def f(x):
        y = x + x
        z = y.diagonal()
        z.fill_(0)
        return y

    self.assert_functionalization(f, torch.ones(2, 2))
    logs = self.get_logs(f, torch.ones(2, 2))
    self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, $0)
$2 = torch._ops.aten.diagonal_copy.default($1)
$3 = torch._ops.aten.fill.Scalar($2, 0)""")

def test_nested_functions_propagate_updates(self):
    def g(x):
        # Create a view of x
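For reference, the trace above is functionalization rewriting a mutation of the diagonal view into out-of-place ops. Conceptually the rewritten program looks like the sketch below; the diagonal_scatter step that writes the update back into the base is part of the pass's bookkeeping, even though the expected log only records the ops through fill.Scalar (this assumes the *_copy and *_scatter ops used by the pass are present in the build):

```python
import torch

def f_functionalized(x):
    y = torch.ops.aten.add.Tensor(x, x)
    z = torch.ops.aten.diagonal_copy.default(y)
    z = torch.ops.aten.fill.Scalar(z, 0)
    # propagate the view update back into the base tensor
    return torch.ops.aten.diagonal_scatter.default(y, z)

out = f_functionalized(torch.ones(2, 2))
assert torch.equal(out, torch.tensor([[0., 2.], [2., 0.]]))
```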
9 changes: 9 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -630,6 +630,15 @@
- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)
  self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)

- name: fill.Scalar(Tensor self, Scalar value) -> Tensor
  self: zeros_like(grad)
  result: at::fill(self_t, 0)

- name: fill.Tensor(Tensor self, Tensor value) -> Tensor
  self: zeros_like(grad)
  value: grad.sum()
  result: at::fill(self_t, value_t)

- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
  self: zeros_like(grad)
  result: self_t.fill_(0)
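These formulas encode the fact that fill broadcasts one value to every element of the output: the incoming gradient contributes nothing to self (zeros_like(grad)), and the gradient for a tensor-valued fill is the sum of the output gradient. A small backward-mode sanity sketch (again via torch.ops, since the Python binding is skipped, and assuming a build with this PR):

```python
import torch

x = torch.randn(2, 3)                    # self: receives zero gradient
v = torch.randn((), requires_grad=True)  # fill value must be a 0-dim tensor
out = torch.ops.aten.fill.Tensor(x, v)
out.sum().backward()
# all 6 output elements depend on v, so dL/dv = grad.sum() = 6
assert torch.allclose(v.grad, torch.tensor(6.0))
```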
2 changes: 2 additions & 0 deletions tools/autograd/gen_python_functions.py
@@ -153,6 +153,8 @@
"replace_", # only used by the functionalization pass, doesn't need to be exposed to python
"zero", # only used by the functionalization pass, doesn't need to be exposed to python
"copy", # only used by the functionalization pass
"fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass
]

SKIP_PYTHON_BINDINGS = list(
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -143,6 +143,7 @@ def get_ignored_functions() -> Set[Callable]:
    torch.fft.rfftfreq,
    torch.from_file,
    torch.full,
    torch.fill,
    torch.hamming_window,
    torch.hann_window,
    torch.kaiser_window,
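get_ignored_functions() is the public query for this list; functions in it are exempt from __torch_function__ override dispatch and testing. A quick illustration, with the PR-specific entry hedged since it only appears on a build that includes this change:

```python
import torch
from torch.overrides import get_ignored_functions

ignored = get_ignored_functions()
assert torch.full in ignored  # pre-existing entry, shown as context above
# on a build that includes this PR, the new entry is present too:
# assert torch.fill in ignored
```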