Mixed precision causes NaN loss · Issue #40497 · pytorch/pytorch · GitHub
Mixed precision causes NaN loss #40497
Open
@ruathudo

Description

🐛 Bug

I'm using autocast with GradScaler to train with mixed precision. On a small dataset it works fine, but when I train on a bigger dataset, the loss turns to NaN after a few epochs (3-4).
The model is a seq2seq transformer, trained with the Adam optimizer and a cross-entropy criterion.
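For context on why half precision needs loss scaling at all, here is a minimal sketch of the float16 value range using only Python's standard library. It is not a diagnosis of this issue. The helper name `to_half` is hypothetical, and mapping overflow to infinity is an assumption made to mimic how float16 arithmetic saturates.

```python
import math
import struct

def to_half(x: float) -> float:
    """Round-trip x through IEEE 754 binary16 (the float16 storage format).

    struct raises OverflowError for values too large for binary16; this
    sketch maps that case to +/-inf to mimic saturating float16 arithmetic.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

FP16_MAX = 65504.0                      # largest finite float16 value

largest = to_half(FP16_MAX)             # still representable
overflowed = to_half(1e5)               # too large: becomes inf
underflowed = to_half(1e-8)             # too small: flushes to 0.0

# Once an inf appears, ordinary arithmetic can turn it into NaN.
nan_from_inf = overflowed - overflowed

# Loss scaling in a nutshell: multiply before casting, divide afterwards.
scale = 65536.0
recovered = to_half(1e-8 * scale) / scale

print(largest, overflowed, underflowed, nan_from_inf, recovered)
```

The last two lines show the idea behind gradient scaling: a value that would vanish in float16 survives when it is scaled up before the cast and scaled back down in higher precision.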

Here is the training code:

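import torch
from torch.cuda.amp import autocast

def get_correction(output, target):
    # Count sequences (rows) in which every predicted token matches the target.
    diff = torch.sum((output != target), dim=1)
    acc = torch.sum(diff == 0)
    return acc.item()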

def train(model, data_loader, optimizer, criterion, device, scaler):
    clip = 1  # max gradient norm used for clipping
    model.train()
    epoch_loss = 0
    total_correct = 0
    total_sample = 0

    for i, batch in enumerate(data_loader):
        optimizer.zero_grad()
        src, trg = batch
        src = src.to(device, non_blocking=True)
        trg = trg.to(device, non_blocking=True)

        # Forward pass and loss computation run under autocast (mixed precision).
        with autocast():
            output, _ = model(src, trg[:, :-1])

            y_pred = torch.argmax(output, 2)
            y_true = trg[:, 1:]

            # Accuracy bookkeeping: one sample per target sequence.
            total_sample += y_true.shape[0]
            total_correct += get_correction(y_pred, y_true)

            # Flatten to (N, output_dim) logits and (N,) targets for the criterion.
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

        # Scale the loss for backward, then unscale so clipping sees the true gradient norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(optimizer)
        scaler.update()

    epoch_loss = epoch_loss / len(data_loader)
    acc = total_correct / total_sample

    return epoch_loss, acc
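One way to narrow down where non-finite values first appear in a loop like the one above is to list the parameters whose gradients contain NaN or Inf. The sketch below is only a debugging aid, not part of the original report; the helper name `nonfinite_grad_names` is hypothetical. In the loop above it could be called right after `scaler.unscale_(optimizer)`. Note that `GradScaler` is designed to tolerate occasional Inf/NaN gradients by skipping that optimizer step, so a NaN in the loss itself points at the forward pass rather than at scaling.

```python
import torch
from torch import nn

def nonfinite_grad_names(model: nn.Module) -> list:
    """Return the names of parameters whose .grad contains NaN or Inf."""
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad

# Tiny CPU demo on a stand-in model (not the transformer from the report).
model = nn.Linear(2, 1)
loss = model(torch.ones(1, 2)).sum()
loss.backward()
clean = nonfinite_grad_names(model)       # gradients are finite here

model.weight.grad[0, 0] = float("nan")    # poison one gradient entry by hand
poisoned = nonfinite_grad_names(model)

print(clean, poisoned)
```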

Note that the get_correction function only calculates accuracy at the word level instead of the character level.
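To make the word-level versus character-level distinction concrete, here is a small worked example on made-up token ids (the tensors are illustrative, not data from the issue). It repeats `get_correction` so that it runs on its own, and compares it with a plain per-token accuracy.

```python
import torch

def get_correction(output, target):
    # A row counts as correct only if it has zero mismatching positions.
    diff = torch.sum((output != target), dim=1)
    acc = torch.sum(diff == 0)
    return acc.item()

# Three sequences of three tokens each; the second has one wrong token.
y_pred = torch.tensor([[4, 7, 9],
                       [4, 7, 2],
                       [1, 1, 1]])
y_true = torch.tensor([[4, 7, 9],
                       [4, 7, 9],
                       [1, 1, 1]])

exact = get_correction(y_pred, y_true)                  # whole-sequence matches
token_acc = (y_pred == y_true).float().mean().item()    # per-token accuracy

print(exact, y_true.shape[0], token_acc)
```

Here two of the three sequences match exactly, while eight of the nine individual tokens match, so the word-level score is the stricter of the two.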

Environment

  • PyTorch Version: 1.6.0.dev20200623
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.7.5
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: RTX 2060 super

cc @mcarilli @ptrblck

Labels

  • module: NaNs and Infs (problems related to NaN and Inf handling in floating point)
  • module: amp (automated mixed precision) (autocast)
  • triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
