### 🐛 Describe the bug
Repro:
```python
import torch
from functools import lru_cache
from transformers import AutoConfig, AutoModelForCausalLM
from torch.fx.passes.split_module import split_module

dtype = torch.bfloat16
BS = 2
SEQ_LEN = 4096
HF_TOKEN = None

PARTITION_ID = 0
PARTITION_OPS_CTR = 0

@lru_cache
def callback(node) -> int:
    # Assign nodes to partitions in order, starting a new partition every 5 nodes.
    global PARTITION_ID, PARTITION_OPS_CTR
    if PARTITION_OPS_CTR % 5 == 0:
        PARTITION_ID += 1
    PARTITION_OPS_CTR += 1
    return PARTITION_ID

def backend(gm, inps):
    split_gm = split_module(
        gm,
        root_m=None,
        split_callback=callback,
        keep_original_order=True,
        keep_original_node_name=True,
    )
    return split_gm

model_name = "ibm-granite/granite-3.1-3b-a800m-instruct"
with torch.device("cuda"):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=False)
    config.num_hidden_layers = 1
    model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype, trust_remote_code=False)
    torch.compile(model, backend=backend)(torch.randint(0, 10, (BS, SEQ_LEN)))
```
Error:
```
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/passes/split_module.py", line 395, in split_module
    partitions[dependent].dependencies.pop(root_partition)
torch._dynamo.exc.BackendCompilerFailed: backend='backend' raised:
KeyError: '230'

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
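
For reference, here is a smaller, self-contained sketch of the same usage pattern (split_module driven by a "new partition every N nodes" callback with keep_original_order=True and keep_original_node_name=True) on a toy symbolically traced module instead of the Granite model. This is an untested sketch of my own (the Toy module, toy_callback, and counter are hypothetical, not from the failing run) and is only meant to help narrow down whether the failure depends on the transformers graph.

```python
# Hypothetical minimal sketch, not the failing repro above: same
# split_module + counter-based split_callback pattern, but on a small
# traced module with no transformers dependency. Untested; it may or
# may not hit the same KeyError.
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        for _ in range(12):  # enough nodes to span several partitions
            x = torch.relu(self.lin(x))
        return x

traced = symbolic_trace(Toy())

ctr = {"ops": 0, "part": 0}

def toy_callback(node) -> int:
    # Same bucketing scheme as the repro: start a new partition every 5 nodes.
    if ctr["ops"] % 5 == 0:
        ctr["part"] += 1
    ctr["ops"] += 1
    return ctr["part"]

split_gm = split_module(
    traced,
    None,  # root_m=None, mirroring the repro
    toy_callback,
    keep_original_order=True,
    keep_original_node_name=True,
)
print(split_gm)
```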
### Versions
PyTorch main branch