🐛 Describe the bug
This can be repro'd with the following code. We are unrolling a for loop, so we end up with a really big graph. There are better ways to resolve the compile-time problem, such as hierarchical compilation or regional compilation (a rough sketch of what regional compilation could look like here follows the repro script). But even then, we should prevent quadratic behavior in fusion. Increase the value of seq_len to observe the quadratic behavior: at seq_len = 1500, fusion alone takes roughly an hour for the forward graph and about 4 hours for the backward graph.
```python
import torch
import torch.nn as nn


class CustomGRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super(CustomGRUCell, self).__init__()
        self.hidden_size = hidden_size
        # Linear layers for update gate
        self.linear_xz = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hz = nn.Linear(hidden_size, hidden_size, bias=True)
        # Linear layers for reset gate
        self.linear_xr = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hr = nn.Linear(hidden_size, hidden_size, bias=True)
        # Linear layers for candidate hidden state
        self.linear_xn = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hn = nn.Linear(hidden_size, hidden_size, bias=True)
        self.register_buffer("minus_ones", -torch.ones(1, 1, hidden_size))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        x = inp[0]
        h_prev = inp[1]
        # Update gate
        z = torch.sigmoid(self.linear_xz(x) + self.linear_hz(h_prev))
        # Reset gate
        r = torch.sigmoid(self.linear_xr(x) + self.linear_hr(h_prev))
        # Candidate hidden state
        n = torch.tanh(self.linear_xn(x) + r * self.linear_hn(h_prev))
        # New hidden state
        # h_next = (self.ones - z) * n + z * h_prev
        # self.ones - z doesn't work - as it will lead to ones.sub(z) op, which doesn't quantize ones.
        # so, we go with z - self.ones instead, which will quantize self.ones tensor.
        h_next = z * h_prev - (z + self.minus_ones) * n
        return h_next


class CustomGRU(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_dim: int,
        dropout: float,
        apply_softmax: bool = True,
        apply_sigmoid: bool = False,
    ) -> None:
        super(CustomGRU, self).__init__()
        self.gru_layers = nn.ModuleList()
        self.num_layers = num_layers
        if num_layers == 1 and dropout > 0:
            raise ValueError("Dropout in GRU requires at least 2 layers")
        self.dropout_layer = nn.Dropout(dropout)
        for layer_num in range(num_layers):
            layer_input_size = input_size if layer_num == 0 else hidden_size
            self.gru_layers.append(
                CustomGRUCell(
                    input_size=layer_input_size,
                    hidden_size=hidden_size,
                )
            )
        # setup outputs
        self.apply_sigmoid = apply_sigmoid
        self.head: nn.Module
        if apply_softmax:
            self.head = nn.Sequential(
                nn.Linear(hidden_size, output_dim),
                nn.LogSoftmax(dim=-1),
            )
        else:
            self.head = nn.Linear(hidden_size, output_dim)

    def forward(self, inp: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
        """expects a tuple of (input, hidden states...) as input
        Input should only contain one time step, shape should be (B, 1, input_size)"""
        # loop through each layer and update output and hidden states
        hidden_outputs = []
        x = inp[0]
        for layer_idx in range(self.num_layers):
            gru_in = (x, inp[1 + layer_idx])
            x = self.gru_layers[layer_idx](gru_in)
            if layer_idx < self.num_layers - 1:
                x = self.dropout_layer(x)
            hidden_outputs.append(x)
        # -> (B, output_dim)
        x = x.flatten(start_dim=1)
        x = self.head(x)
        if self.apply_sigmoid:
            x[:, 0] = torch.sigmoid(x[:, 0])
        return x, *hidden_outputs


class CustomGRUAggregationModel(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_dim: int,
        dropout: float,
        apply_softmax: bool = True,
        apply_sigmoid: bool = False,
    ) -> None:
        """Wrapper class for CustomGRU, applying GRU to entire sequence."""
        super(CustomGRUAggregationModel, self).__init__()
        self.gru: nn.Module = CustomGRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            output_dim=output_dim,
            dropout=dropout,
            apply_softmax=apply_softmax,
            apply_sigmoid=apply_sigmoid,
        )
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        # pass to the GRU (input, hidden_states...)
        batch_size, seq_len, _ = x.shape
        hidden_outputs = [
            torch.zeros(
                batch_size, 1, self.hidden_size, dtype=torch.float32, device=x.device
            )
            for _ in range(self.num_layers)
        ]
        # extract all the output values
        pred = []
        for seq_idx in range(seq_len):
            cur_values = self.gru([x[:, seq_idx : seq_idx + 1, :], *hidden_outputs])
            pred.append(cur_values[0])
            hidden_outputs = cur_values[1:]
        # during training we use the entire sequence
        if self.training:
            return torch.stack(pred, dim=1)
        else:
            return pred[-1]


torch.random.manual_seed(0)

# compile time is quadratic with seq_len
seq_len = 500
input_size = 200
hidden_size = 32
output_dim = 61
batch_size = 5

mdl = CustomGRUAggregationModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=2,
    output_dim=output_dim,
    dropout=0.1,
    apply_softmax=True,
)
mdl = torch.compile(mdl, fullgraph=True)

x = torch.randn(batch_size, seq_len, input_size)
mdl.eval()
output_test = mdl(x)
print(output_test)
```
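For reference, the regional-compilation workaround mentioned above could look roughly like the sketch below (my own sketch, not part of the original repro; it assumes it runs in the same script as the code above, reusing `CustomGRUAggregationModel` and the hyperparameters). Compiling only the per-step `gru` submodule keeps the compiled graph size independent of `seq_len`, since the Python loop over time steps stays in eager mode. This sidesteps the huge unrolled graph, but the quadratic fusion cost itself still needs to be fixed for the fullgraph case.

```python
# Hedged sketch of regional compilation (not from the original report):
# compile only the repeated per-step module instead of the whole model,
# so graph size (and hence fusion work) no longer grows with seq_len.
mdl_regional = CustomGRUAggregationModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=2,
    output_dim=output_dim,
    dropout=0.1,
    apply_softmax=True,
)
# Only the per-time-step GRU stack is compiled; the seq_len loop runs in eager.
mdl_regional.gru = torch.compile(mdl_regional.gru, fullgraph=True)
mdl_regional.eval()
out_regional = mdl_regional(torch.randn(batch_size, seq_len, input_size))
```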
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @eellison @BoyuanFeng
Error logs
No response
Versions
nightly