🐛 Describe the bug
Gradients populated in the backward pass differ significantly between MPS and CPU once the shapes of the operations involved are large enough. The repro below uses two consecutive Linear layers (no bias): the gradient of the output layer matches the CPU result, but the gradient of the inner layer differs significantly. Shrinking the shapes reduces the error until, at small shapes, the mismatch drops below the assert_close tolerance and the problem is hidden.
import torch
import copy
from torch.testing import assert_close
torch.manual_seed(123)
#Shapes reproducing the mismatch
n_samples, context_len, emb_size = 2,256,130
vocab_size = 50257
#Example shapes to pass the test
#n_samples, context_len, emb_size = 1,4,8
#vocab_size = 6675
inp_shape = (n_samples, context_len, emb_size)
x_cpu = torch.randn(inp_shape, dtype=torch.float)
x_mps = x_cpu.detach().clone().to('mps')
model_cpu = torch.nn.Sequential(
    torch.nn.Linear(emb_size, emb_size, bias=False),
    torch.nn.Linear(emb_size, vocab_size, bias=False),
)
model_mps = copy.deepcopy(model_cpu).to('mps')
cpu_res = model_cpu(x_cpu)
mps_res = model_mps(x_mps)
cpu_targ = torch.randn_like(cpu_res)
mps_targ = cpu_targ.detach().clone().to('mps')
cpu_loss = (cpu_res-cpu_targ).sum()
mps_loss = (mps_res-mps_targ).sum()
torch.mps.synchronize()
assert_close(cpu_loss, mps_loss.to('cpu'))
cpu_loss.backward()
mps_loss.backward()
torch.mps.synchronize()
cpu_dict = {}
mps_dict = {}
for name, param in model_cpu.named_parameters():
    if param.requires_grad:
        cpu_dict[name] = param.grad
for name, param in model_mps.named_parameters():
    if param.requires_grad:
        mps_dict[name] = param.grad
for name in cpu_dict.keys():
    print(f"Name: {name}")
    assert_close(cpu_dict[name], mps_dict[name].to('cpu'))
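For what it's worth, here is a minimal isolation sketch; the narrowing is my assumption about where the drift originates, not something verified against the MPS kernels. In the backward pass, only the gradient flowing into the input of the second Linear layer involves a matmul that reduces over the large vocab_size dimension, and the inner layer's weight gradient is computed from exactly that quantity. The sketch keeps only the output Linear and prints the maximum absolute CPU/MPS difference for its input gradient and its weight gradient instead of asserting, so the magnitude of any drift is visible:

```python
import copy
import torch

torch.manual_seed(123)

n_samples, context_len, emb_size = 2, 256, 130
vocab_size = 50257

# Only the output Linear: its input gradient (grad_output @ weight) is the one
# matmul in the backward pass that reduces over vocab_size = 50257.
lin_cpu = torch.nn.Linear(emb_size, vocab_size, bias=False)
lin_mps = copy.deepcopy(lin_cpu).to('mps')

h_cpu = torch.randn(n_samples, context_len, emb_size, requires_grad=True)
h_mps = h_cpu.detach().clone().to('mps').requires_grad_(True)

lin_cpu(h_cpu).sum().backward()
lin_mps(h_mps).sum().backward()
torch.mps.synchronize()

# Report the size of the mismatch instead of asserting.
for label, g_cpu, g_mps in (
    ("input grad (reduces over vocab_size)", h_cpu.grad, h_mps.grad),
    ("weight grad (reduces over n_samples*context_len)", lin_cpu.weight.grad, lin_mps.weight.grad),
):
    diff = (g_cpu - g_mps.cpu()).abs().max().item()
    print(f"{label}: max abs diff = {diff}")
```

If the input gradient drifts while the weight gradient stays close, that would point at the large-reduction matmul in the MPS backward rather than at anything specific to stacking two Linear layers.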
Versions
Collecting environment information...
PyTorch version: 2.8.0.dev20250520
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.3.2.7)
CMake version: version 4.0.0
Libc version: N/A
Python version: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Max