[MPS] Correctness issue in backward pass of nn.ReplicationPad1d and nn.ReplicationPad2d #135447
@hvaara

🐛 Describe the bug

This issue was discovered in #134184 and left as follow-up work.

There is a correctness issue in the backward pass of nn.ReplicationPad1d and nn.ReplicationPad2d for certain input shapes. The reproducer below exercises nn.ReplicationPad1d: the first shape produces gradients that match the CPU reference, while the second does not.

import torch

shapes = ([2, 65736, 4], [65736, 2, 4])
pl, pr = 3, 4

for shape in shapes:
    x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
    x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
    
    model = torch.nn.ReplicationPad1d((pl, pr))

    # forward
    out_cpu = model(x_cpu)
    out_mps = model(x_mps)
    print(f"{((x_cpu - x_mps.cpu()).abs() > 1e-5).sum() = }")

    # backward
    g_cpu = torch.randn_like(out_cpu)
    g_mps = g_cpu.clone().detach().to('mps')

    print(f"{((g_cpu - g_mps.cpu()).abs() > 1e-5).sum() = }")

    out_cpu.backward(g_cpu)
    out_mps.backward(g_mps)
    
    print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")
    print()

# Output:
# ((out_cpu - out_mps.cpu()).abs() > 1e-5).sum() = tensor(0)
# ((g_cpu - g_mps.cpu()).abs() > 1e-5).sum() = tensor(0)
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)
#
# ((out_cpu - out_mps.cpu()).abs() > 1e-5).sum() = tensor(0)
# ((g_cpu - g_mps.cpu()).abs() > 1e-5).sum() = tensor(0)
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(524283)
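
For reference, replication padding copies edge elements into the padded region, so in the backward pass the upstream gradients of all replicated positions must accumulate onto the corresponding edge element. Below is a minimal CPU-only sketch of that invariant for nn.ReplicationPad1d; the small shape here is illustrative and unrelated to the failing shapes above:

import torch

pl, pr = 3, 4
x = torch.randn(2, 3, 5, requires_grad=True)
out = torch.nn.ReplicationPad1d((pl, pr))(x)

g = torch.randn_like(out)
out.backward(g)

# Interior elements receive their own gradient slice unchanged.
assert torch.allclose(x.grad[:, :, 1:-1], g[:, :, pl + 1 : -pr - 1])
# The left edge accumulates its own gradient plus those of its pl copies.
assert torch.allclose(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
# The right edge accumulates its own gradient plus those of its pr copies.
assert torch.allclose(x.grad[:, :, -1], g[:, :, -pr - 1 :].sum(-1))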

A reproducer for nn.ReplicationPad2d can be found in pytorch/test/test_nn.py, lines 8612 to 8656 at commit defb515:

def test_ReplicationPad2d_large(self, device):
    shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
    pl, pr, pt, pb = 3, 4, 5, 6
    for shape in shapes:
        x = torch.randn(shape, device=device, requires_grad=True)
        model = torch.nn.ReplicationPad2d((pl, pr, pt, pb))

        # forward center, edge
        out = model(x)
        self.assertEqual(out[:, :, pt : -pb, pl : -pr], x)
        left_padding = out[:, :, pt : -pb, : pl]
        self.assertEqual(left_padding, x[:, :, :, :1].expand_as(left_padding))
        right_padding = out[:, :, pt : -pb, -pr :]
        self.assertEqual(right_padding, x[:, :, :, -1:].expand_as(right_padding))
        top_padding = out[:, :, : pt, pl : -pr]
        self.assertEqual(top_padding, x[:, :, :1, :].expand_as(top_padding))
        bottom_padding = out[:, :, -pb : , pl : -pr]
        self.assertEqual(bottom_padding, x[:, :, -1:, :].expand_as(bottom_padding))

        # forward corner
        tl_padding = out[:, :, : pt + 1, : pl + 1]
        self.assertEqual(tl_padding, x[:, :, :1, :1].expand_as(tl_padding))
        tr_padding = out[:, :, : pt + 1, -pr - 1:]
        self.assertEqual(tr_padding, x[:, :, :1, -1:].expand_as(tr_padding))
        bl_padding = out[:, :, -pb - 1:, : pl + 1]
        self.assertEqual(bl_padding, x[:, :, -1:, :1].expand_as(bl_padding))
        br_padding = out[:, :, -pb - 1:, -pr - 1:]
        self.assertEqual(br_padding, x[:, :, -1:, -1:].expand_as(br_padding))

        # backward center, edge
        g = torch.randn_like(out)
        out.backward(g)
        self.assertEqual(x.grad[:, :, 1:-1, 1:-1], g[:, :, pt + 1 : -pb - 1, pl + 1 : -pr - 1])
        self.assertEqual(x.grad[:, :, 1:-1, 0], g[:, :, pt + 1 : -pb - 1, : pl + 1].sum(-1))
        self.assertEqual(x.grad[:, :, 1:-1, -1], g[:, :, pt + 1 : -pb - 1, -pr - 1 :].sum(-1))
        self.assertEqual(x.grad[:, :, 0, 1:-1], g[:, :, : pt + 1, pl + 1 : -pr - 1].sum(-2))
        self.assertEqual(x.grad[:, :, -1, 1:-1], g[:, :, -pb - 1 :, pl + 1 : -pr - 1].sum(-2))

        # backward corner
        self.assertEqual(x.grad[:, :, 0, 0], g[:, :, : pt + 1, : pl + 1].sum((-2, -1)))
        self.assertEqual(x.grad[:, :, 0, -1], g[:, :, : pt + 1, -pr - 1 :].sum((-2, -1)))
        self.assertEqual(x.grad[:, :, -1, 0], g[:, :, -pb - 1 :, : pl + 1].sum((-2, -1)))
        self.assertEqual(x.grad[:, :, -1, -1], g[:, :, -pb - 1 :, -pr - 1 :].sum((-2, -1)))
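
Assuming a standard PyTorch development checkout, the device-specific instantiations of this test can typically be selected with pytest's -k filter (the exact generated test name depends on how the suite instantiates device types, e.g. a _cpu or _mps suffix):

pytest test/test_nn.py -k test_ReplicationPad2d_large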

Versions

PyTorch version: 2.5.0a0+git042f2f7
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] optree==0.12.1
[pip3] torch==2.5.0a0+git042f2f7
[pip3] torch-tb-profiler==0.4.3
[pip3] torchvision==0.20.0a0+0d80848
[pip3] triton==3.0.0
[conda] numpy 1.26.0 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.5.0a0+git042f2f7 dev_0
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+0d80848 dev_0
[conda] triton 3.0.0 pypi_0 pypi

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @xmfan @kulinseth @malfet @DenisVieriu97 @jhavukainen

Labels
module: autograd - Related to torch.autograd, and the autograd engine in general
module: correctness (silent) - Issue that returns an incorrect result silently
module: mps - Related to Apple Metal Performance Shaders framework
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
