[FSDP2] for mixed precision, input casting can get blocked when cuda streams are full #154272
Open
@weifengpy

Description

🚀 The feature, motivation and pitch

This was raised by @kwen2501: for mixed precision, we cast forward inputs per tensor. Kernel launching (see the CPU thread in the snapshot) can get blocked when there are many input tensors. Once roughly 1024 kernels have been launched, we start seeing "command buffer full" and cudaLaunchKernel blocks.

We probably need a best practice to avoid this situation, for example recommending that users pack/unpack inputs in their own code (see the sketch below).

Not sure how often this happens in practice, though.
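
As a rough sketch of the pack/unpack idea (hypothetical helpers, not an FSDP API): concatenating the many small inputs into one buffer before forward means the per-tensor cast launches a single kernel, and the model unpacks views inside its own code.

import torch

def pack_inputs(tensors):
    # hypothetical helper: flatten and concatenate into one buffer,
    # remembering shapes so the model can unpack inside forward
    shapes = [t.shape for t in tensors]
    packed = torch.cat([t.reshape(-1) for t in tensors])
    return packed, shapes

def unpack_inputs(packed, shapes):
    # hypothetical helper: slice the buffer back into views with the original shapes
    out, offset = [], 0
    for shape in shapes:
        n = shape.numel()
        out.append(packed[offset:offset + n].view(shape))
        offset += n
    return out

# usage idea: pass (packed, shapes) into forward and call unpack_inputs inside
# the model, so FSDP::cast_forward_inputs sees one tensor instead of thousands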

[Snapshot: profiler trace showing the CPU thread blocked in cudaLaunchKernel while launching the per-tensor cast kernels]

if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype:
    with torch.profiler.record_function("FSDP::cast_forward_inputs"):
        cast_fn = functools.partial(
            _cast_fp_tensor, self._mp_policy.param_dtype
        )
        args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs)
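
For reference, a minimal sketch of the per-tensor behavior in the excerpt above (using torch.utils._pytree.tree_map directly; cast_fp_tensor below is a stand-in for _cast_fp_tensor, not the actual implementation): every floating-point tensor leaf is cast on its own, so 10,000 list elements mean roughly 10,000 cast kernel launches.

import torch
from torch.utils._pytree import tree_map

def cast_fp_tensor(dtype, x):
    # stand-in for _cast_fp_tensor: cast only floating-point tensors
    if torch.is_tensor(x) and x.is_floating_point():
        return x.to(dtype)
    return x

args = ([torch.randn(4) for _ in range(10_000)],)
# tree_map visits every tensor leaf, launching one cast kernel per tensor
args = tree_map(lambda t: cast_fp_tensor(torch.bfloat16, t), args)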

repro

# torchrun --standalone --nproc_per_node=2 run_fsdp2.py

import contextlib
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.fsdp import MixedPrecisionPolicy

torch.cuda.memory._record_memory_history(max_entries=100000)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 4, bias=False)

    def forward(self, x):
        # x is a list of many tensors; each one is cast individually by
        # FSDP::cast_forward_inputs before this forward runs
        return x[0].sum() + self.l1.weight.sum()


@contextlib.contextmanager
def enable_profiling(enable=False):
    if not enable:
        yield None
    else:
        trace_dir = "./profilers"
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")

        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)
        warmup, active = 1, 2
        wait = 1
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler


def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    model = MyModel()
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    fully_shard(model, mp_policy=mp_policy)
    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
    # 10,000 forward inputs -> FSDP2 casts each tensor separately, so every
    # iteration launches ~10,000 small cast kernels
    x = [torch.randn((1024, 1024), device=device, requires_grad=True)] * 10000
    # a large matmul on a side stream keeps the GPU busy so the command buffer
    # fills up and cudaLaunchKernel blocks on the CPU thread
    stream = torch.cuda.Stream()
    m = torch.randn((16384, 16384), device=device, requires_grad=True)

    with enable_profiling(True) as prof:
        for _ in range(10):
            with torch.cuda.stream(stream):
                torch.nn.functional.linear(m, m)
            model(x).sum().backward()
            optim.step()
            prof.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
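
A related workaround sketch (an assumption on my part, not a confirmed recommendation): since MixedPrecisionPolicy exposes cast_forward_inputs, user code could disable the automatic per-tensor cast and do a single cast of a packed input itself.

# sketch only: turn off FSDP2's automatic input casting and cast once in user code
from torch.distributed.fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    cast_forward_inputs=False,  # user code takes over input casting
)
# per iteration, one cast of a packed input instead of 10,000 per-tensor casts:
# packed = torch.cat([t.reshape(-1) for t in x]).to(torch.bfloat16)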

Alternatives

No response

Additional context

No response

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360

Metadata

Assignees

No one assigned

    Labels

    module: fsdp, oncall: distributed, triaged
