Compile + torch.autograd.grad returns no gradients · Issue #132929 · pytorch/pytorch · GitHub
Compile + torch.autograd.grad returns no gradients  #132929
Open
@daniel-z-kaplan

🐛 Describe the bug

This is a (very) minimal example of code that fails under torch.compile. It runs (seemingly) successfully in eager mode, but raises a RuntimeError once the model is compiled.

from torch import nn
import torch

class loss_function():
    
    def __init__(self):
        pass
    
    def loss(self, latents, targets):
        
        grads = torch.autograd.grad(outputs=[(targets).sum()], inputs=[latents], 
                create_graph=True)[0]
        
        print(grads)
        return grads.mean()


class network(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.down = torch.nn.Conv2d(3,4,1,1)
        self.up = torch.nn.Conv2d(4,3,1,1)
    
    def forward(self, x):
        z = self.down(x)
        dec = self.up(z)
        return dec, z

device = "cuda"
model = network().to(device)
loss_function = loss_function()


def run():
    
    data = torch.randn(4,3,256,256).to(device)
    output, latents = model(data)
    loss = loss_function.loss(latents = latents, targets = output)
    loss_two = output.mean() + loss
    loss_two.backward()

run()
# Succeeds in eager mode



compiled = torch.compile(model)
model = compiled
run()
# Fails once the model is compiled, with the following traceback:
  File "test_compile.py", line 51, in <module>
    run()
  File "test_compile.py", line 39, in run
    loss = loss_function.loss(latents = latents, targets = output)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "test_compile.py", line 11, in loss
    grads = torch.autograd.grad(outputs=[(targets).sum()], inputs=[latents], 
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".local/lib/python3.12/site-packages/torch/autograd/__init__.py", line 436, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File ".local/lib/python3.12/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
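
For reference, the same RuntimeError can be reproduced in eager mode whenever the tensor passed as `inputs` did not contribute to `outputs` in the autograd graph. The sketch below is a CPU-only illustration of what the message means; the tensors `x` and `unrelated` are hypothetical and not part of the repro above, and the sketch makes no claim about why torch.compile triggers the error in this issue.

```python
import torch

# Hypothetical tensors for illustration only; not from the original report.
x = torch.randn(3, requires_grad=True)
unrelated = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# `x` was used to compute `y`, so the gradient is defined (all twos).
(gx,) = torch.autograd.grad(y, [x], retain_graph=True)

# `unrelated` never contributed to `y`, so autograd raises the same
# RuntimeError as in the traceback above.
try:
    torch.autograd.grad(y, [unrelated], retain_graph=True)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

# With allow_unused=True the call succeeds and returns None for that input.
(g_unused,) = torch.autograd.grad(y, [unrelated], allow_unused=True)

print(gx)
print(error_message)
print(g_unused)
```

Note that passing `allow_unused=True` only silences the error; it would yield `None` instead of a gradient, which would not help the `grads.mean()` call in the repro.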

Versions

PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Fedora release 39 (Thirty Nine) (x86_64)
GCC version: (GCC) 13.2.1 20231205 (Red Hat 13.2.1-6)
Clang version: 17.0.6 (Fedora 17.0.6-2.fc39)
CMake version: version 3.27.7
Libc version: glibc-2.38

Python version: 3.12.1 (main, Dec 18 2023, 00:00:00) [GCC 13.2.1 20231205 (Red Hat 13.2.1-6)] (64-bit runtime)
Python platform: Linux-6.7.4-100.fc38.x86_64-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.154.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 1
CPU(s) scaling MHz: 54%
CPU max MHz: 4000.0000
CPU min MHz: 1200.0000
BogoMIPS: 7199.98
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 pti intel_ppin ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 1.5 MiB (6 instances)
L3 cache: 15 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable

Versions of relevant libraries:
[pip3] ema-pytorch==0.5.1
[pip3] gigagan-pytorch==0.2.20
[pip3] lion-pytorch==0.1.2
[pip3] numpy==1.26.4
[pip3] open-clip-torch==2.24.0
[pip3] pytorch-lightning==2.2.0.post0
[pip3] torch==2.4.0
[pip3] torch-pitch-shift==1.2.4
[pip3] torchaudio==2.2.0
[pip3] torchaudio-augmentations==0.2.4
[pip3] torcheval==0.0.7
[pip3] torchlibrosa==0.1.0
[pip3] torchmetrics==1.3.1
[pip3] torchvision==0.17.0
[pip3] triton==3.0.0
[conda] Could not collect

cc @ezyang @chauhang @penguinwu

Labels

oncall: pt2, triaged
