[compile time][inductor] Quadratic compile time observed in Inductor fusion #154652
Closed
@anijain2305

Description

🐛 Describe the bug


This can be repro'd with the following code. We are unrolling a for loop, so the graph is very large. There are better ways to address the compile time problem, such as hierarchical compilation or regional compilation, but even then we should prevent quadratic behavior in fusion. You can increase the value of seq_len to observe the quadratic behavior: at seq_len = 1500, fusion alone takes roughly an hour in the forward pass and about 4 hours in the backward pass. A rough timing harness and a sketch of the regional-compilation workaround follow the repro below.

import torch
import torch.nn as nn



class CustomGRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super(CustomGRUCell, self).__init__()
        self.hidden_size = hidden_size

        # Linear layers for update gate
        self.linear_xz = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hz = nn.Linear(hidden_size, hidden_size, bias=True)

        # Linear layers for reset gate
        self.linear_xr = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hr = nn.Linear(hidden_size, hidden_size, bias=True)

        # Linear layers for candidate hidden state
        self.linear_xn = nn.Linear(input_size, hidden_size, bias=True)
        self.linear_hn = nn.Linear(hidden_size, hidden_size, bias=True)
        self.register_buffer("minus_ones", -torch.ones(1, 1, hidden_size))

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        x = inp[0]
        h_prev = inp[1]

        # Update gate
        z = torch.sigmoid(self.linear_xz(x) + self.linear_hz(h_prev))
        # Reset gate
        r = torch.sigmoid(self.linear_xr(x) + self.linear_hr(h_prev))
        # Candidate hidden state
        n = torch.tanh(self.linear_xn(x) + r * self.linear_hn(h_prev))
        # New hidden state
        # h_next = (self.ones - z) * n + z * h_prev
        # self.ones - z doesn't work, as it leads to a ones.sub(z) op, which doesn't quantize ones.
        # So we go with z - self.ones instead, which will quantize the self.ones tensor.
        h_next = z * h_prev - (z + self.minus_ones) * n
        return h_next



class CustomGRU(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_dim: int,
        dropout: float,
        apply_softmax: bool = True,
        apply_sigmoid: bool = False,
    ) -> None:
        super(CustomGRU, self).__init__()
        self.gru_layers = nn.ModuleList()
        self.num_layers = num_layers
        if num_layers == 1 and dropout > 0:
            raise ValueError("Dropout in GRU requires at least 2 layers")

        self.dropout_layer = nn.Dropout(dropout)
        for layer_num in range(num_layers):
            layer_input_size = input_size if layer_num == 0 else hidden_size
            self.gru_layers.append(
                CustomGRUCell(
                    input_size=layer_input_size,
                    hidden_size=hidden_size,
                )
            )

        # setup outputs
        self.apply_sigmoid = apply_sigmoid
        self.head: nn.Module
        if apply_softmax:
            self.head = nn.Sequential(
                nn.Linear(hidden_size, output_dim),
                nn.LogSoftmax(dim=-1),
            )
        else:
            self.head = nn.Linear(hidden_size, output_dim)

    def forward(self, inp: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
        """expects a tuple of (input, hidden states...) as input
        Input should only contain one time step, shape should be (B, 1, input_size)"""

        # loop through each layer and update output and hidden states
        hidden_outputs = []
        x = inp[0]
        for layer_idx in range(self.num_layers):
            gru_in = (x, inp[1 + layer_idx])
            x = self.gru_layers[layer_idx](gru_in)
            if layer_idx < self.num_layers - 1:
                x = self.dropout_layer(x)
            hidden_outputs.append(x)

        # -> (B, output_dim)
        x = x.flatten(start_dim=1)
        x = self.head(x)

        if self.apply_sigmoid:
            x[:, 0] = torch.sigmoid(x[:, 0])

        return x, *hidden_outputs


class CustomGRUAggregationModel(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_dim: int,
        dropout: float,
        apply_softmax: bool = True,
        apply_sigmoid: bool = False,
    ) -> None:
        """Wrapper class for CustomGRU, applying GRU to entire sequence."""
        super(CustomGRUAggregationModel, self).__init__()
        self.gru: nn.Module = CustomGRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            output_dim=output_dim,
            dropout=dropout,
            apply_softmax=apply_softmax,
            apply_sigmoid=apply_sigmoid,
        )
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        # pass to the GRU (input, hidden_states...)
        batch_size, seq_len, _ = x.shape
        hidden_outputs = [
            torch.zeros(
                batch_size, 1, self.hidden_size, dtype=torch.float32, device=x.device
            )
            for _ in range(self.num_layers)
        ]

        # extract all the output values
        pred = []
        for seq_idx in range(seq_len):
            cur_values = self.gru([x[:, seq_idx : seq_idx + 1, :], *hidden_outputs])
            pred.append(cur_values[0])
            hidden_outputs = cur_values[1:]

        # during training we use the entire sequence
        if self.training:
            return torch.stack(pred, dim=1)
        else:
            return pred[-1]


torch.random.manual_seed(0)
# compile time is quadratic with seq_len
seq_len = 500
input_size = 200
hidden_size = 32
output_dim = 61
batch_size = 5
mdl = CustomGRUAggregationModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=2,
    output_dim=output_dim,
    dropout=0.1,
    apply_softmax=True,
)

mdl = torch.compile(mdl, fullgraph=True)
x = torch.randn(batch_size, seq_len, input_size)
mdl.eval()
output_test = mdl(x)
print(output_test)
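
A rough way to observe the scaling described above is to time the first compiled call for a few values of seq_len, reusing the classes and constants from the repro. This assumes the first call's wall-clock time is dominated by compilation at these sizes; torch._dynamo.reset() is only there to force a fresh compile for each size.

import time

for seq_len in (250, 500, 1000):
    torch._dynamo.reset()  # drop compile caches so each size compiles from scratch
    mdl = CustomGRUAggregationModel(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=2,
        output_dim=output_dim,
        dropout=0.1,
        apply_softmax=True,
    )
    mdl = torch.compile(mdl, fullgraph=True)
    mdl.eval()
    x = torch.randn(batch_size, seq_len, input_size)
    start = time.perf_counter()
    mdl(x)  # first call triggers Dynamo tracing + Inductor compilation (incl. fusion)
    print(f"seq_len={seq_len}: first call took {time.perf_counter() - start:.1f}s")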
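
And a minimal sketch of the regional-compilation workaround mentioned above, again reusing the repro's definitions: compile only the repeated CustomGRU region instead of the fully unrolled model, so Inductor compiles one small per-step graph rather than one graph containing seq_len copies of the cell. This only applies torch.compile to a submodule; it illustrates the workaround, not the fusion fix being requested here.

mdl = CustomGRUAggregationModel(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=2,
    output_dim=output_dim,
    dropout=0.1,
    apply_softmax=True,
)
# Compile only the per-timestep region; the outer Python loop over seq_len stays eager.
mdl.gru = torch.compile(mdl.gru, fullgraph=True)
mdl.eval()
output_test = mdl(torch.randn(batch_size, seq_len, input_size))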

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @eellison @BoyuanFeng

Error logs

No response

Versions

nightly

Metadata

Labels

module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
