Description
🐛 Describe the bug
I ran the following minimal scripts on both the CPU and MPS devices. On CPU, memory usage stays stable, but on MPS it keeps creeping up, which looks like a memory leak. I ran into this while working on a more complex Linformer DQN model and narrowed it down to the two issues below.
import torch
import torch.nn as nn
import time

device = torch.device("mps")

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

model = SimpleModel().to(device)
target = SimpleModel().to(device)

while True:
    target.load_state_dict(model.state_dict())
    time.sleep(0.01)
Affected line: target.load_state_dict(model.state_dict())
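To put numbers on the growth rather than watching a system monitor, something like the following sketch could be used. It samples `torch.mps.driver_allocated_memory()` every few hundred iterations of the `load_state_dict` loop. The helper name `sample_growth` and the iteration counts are made up for illustration, and a bare `nn.Linear` stands in for `SimpleModel` to keep it short. It also assumes the leak is visible through the MPS allocator counters; if the growth is in host-side process memory, these readings may stay flat.

```python
def sample_growth(step, read_bytes, iterations=2000, sample_every=500):
    """Call step() repeatedly; record read_bytes() at the start and
    after every `sample_every` iterations. (Hypothetical helper.)"""
    samples = [read_bytes()]
    for i in range(1, iterations + 1):
        step()
        if i % sample_every == 0:
            samples.append(read_bytes())
    return samples


def main():
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        print("torch is not installed; skipping")
        return
    if not torch.backends.mps.is_available():
        print("MPS is not available; skipping")
        return

    device = torch.device("mps")
    # Assumption: a bare Linear layer behaves like SimpleModel for this check.
    model = nn.Linear(10, 10).to(device)
    target = nn.Linear(10, 10).to(device)

    def step():
        target.load_state_dict(model.state_dict())

    def read_bytes():
        # Wait for queued MPS work so the reading reflects finished ops.
        torch.mps.synchronize()
        return torch.mps.driver_allocated_memory()

    samples = sample_growth(step, read_bytes)
    print("driver_allocated_memory samples (bytes):", samples)
    print("growth (bytes):", samples[-1] - samples[0])


if __name__ == "__main__":
    main()
```

A steadily increasing sequence of samples would support the leak report; a sequence that rises once and then flattens is more likely allocator warm-up.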
import torch
import torch.nn as nn
import time

device = torch.device("mps")

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
input_tensor = torch.randn(32, 10, device=device)
target = torch.randn(32, 10, device=device)
loss_fn = nn.MSELoss()

while True:
    optimizer.zero_grad()
    output = model(input_tensor)
    loss = loss_fn(output, target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    time.sleep(0.01)
Affected line: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
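One way to check that the growth really comes from `clip_grad_norm_` and not from the backward pass or the optimizer is to run the same training loop with and without the call and compare the change in MPS memory. The sketch below does that; `run_loop`, `compare`, and the iteration count are hypothetical names and values chosen for illustration, and a bare `nn.Linear` again stands in for `SimpleModel`. Allocator caching can inflate whichever variant runs first, so it may be worth repeating with the order reversed.

```python
def run_loop(step, read_bytes, iterations):
    """Return the change in read_bytes() across `iterations` calls to step()."""
    before = read_bytes()
    for _ in range(iterations):
        step()
    return read_bytes() - before


def compare(make_step, read_bytes, iterations=2000):
    """Measure growth with the suspected call enabled (True) and disabled (False)."""
    return {
        flag: run_loop(make_step(flag), read_bytes, iterations)
        for flag in (True, False)
    }


def main():
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        print("torch is not installed; skipping")
        return
    if not torch.backends.mps.is_available():
        print("MPS is not available; skipping")
        return

    device = torch.device("mps")

    def make_step(clip):
        # Fresh model and optimizer per variant so the runs do not share state.
        model = nn.Linear(10, 10).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        x = torch.randn(32, 10, device=device)
        y = torch.randn(32, 10, device=device)
        loss_fn = nn.MSELoss()

        def step():
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            if clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        return step

    def read_bytes():
        torch.mps.synchronize()
        return torch.mps.driver_allocated_memory()

    growth = compare(make_step, read_bytes)
    print("growth with clip_grad_norm_ (bytes):", growth[True])
    print("growth without clip_grad_norm_ (bytes):", growth[False])


if __name__ == "__main__":
    main()
```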
Versions
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.4 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 4.0.1
Libc version: N/A
Python version: 3.13.1 (v3.13.1:06714517797, Dec 3 2024, 14:00:22) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-15.4-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M4 Max
Versions of relevant libraries:
[pip3] numpy==2.2.2
[pip3] performer-pytorch==1.1.4
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchvision==0.22.0
[conda] Could not collect