[BUG] CUDAgraph policy changes the learning process · Issue #3002 · pytorch/rl · GitHub
[BUG] CUDAgraph policy changes the learning process #3002
Open
@m-krastev

Description


Describe the bug

I have a single collector and I am training a simple path-finding policy with ClipPPO. The agent must find a path in 3D through a simple maze while maximizing coverage of the maze. It receives as input small patches centered on its current position and outputs the parameters of a Beta distribution (see the policy code below). From the distribution we sample a value which is mapped to an action delta ($a_t \in \mathbb{R}^3$) that determines the next position.

To accelerate experience sampling, I opted to use CUDA graphs, but this seems to lead to unstable learning.

A specific pathology I notice is that the standard deviation of the actions produced by the CUDAgraph'd policy tends to zero.
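
To quantify that collapse, this is roughly the diagnostic I track; a minimal sketch, assuming the actor's "alpha"/"beta" output keys and the affine rescaling used in IndependentBeta below (scale = max - min):

import torch

def scaled_beta_std(alpha: torch.Tensor, beta: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Var[Beta(a, b)] = a*b / ((a+b)^2 * (a+b+1)); rescaling to [min, max]
    # multiplies the standard deviation by scale = max - min.
    var = alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1))
    return scale * var.sqrt()

# Logged once per collected batch, e.g.:
# action_std = scaled_beta_std(batch_data["alpha"], batch_data["beta"], scale).mean()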

To Reproduce

I can't provide my full code, but the only difference between the two runs is setting cudagraph_policy=True in the collector. It may be relevant that the policy module is torch.compile'd BEFORE it is passed to the collector.

Policy (CNN extraction head + MLP, Beta distribution):

from typing import Union

import torch
from torch.distributions import Beta, Independent


class IndependentBeta(Independent):
    def __init__(
        self,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        min: Union[float, torch.Tensor] = 0.0,
        max: Union[float, torch.Tensor] = 1.0,
        event_dims: int = 1,
    ):
        self.min = torch.as_tensor(min, device=alpha.device).broadcast_to(alpha.shape)
        self.max = torch.as_tensor(max, device=alpha.device).broadcast_to(alpha.shape)
        self.scale = self.max - self.min
        self.eps = torch.finfo(alpha.dtype).eps
        base_dist = Beta(alpha, beta)
        super().__init__(base_dist, event_dims)

    def sample(self, sample_shape: torch.Size = torch.Size()):
        return super().sample(sample_shape) * self.scale + self.min

    def rsample(self, sample_shape: torch.Size = torch.Size()):
        return super().rsample(sample_shape) * self.scale + self.min

    def log_prob(self, value: torch.Tensor):
        return super().log_prob(((value - self.min) / self.scale).clamp(self.eps, 1.0 - self.eps))
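
As a standalone sanity check of the distribution (hypothetical alpha/beta values and bounds, batch of 4 with 3 action dimensions):

alpha = torch.full((4, 3), 2.0)
beta = torch.full((4, 3), 5.0)
dist = IndependentBeta(alpha, beta, min=-1.0, max=1.0)
action = dist.rsample()       # shape (4, 3), values in [-1, 1]
logp = dist.log_prob(action)  # shape (4,), summed over the event (last) dim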

from tensordict.nn import InteractionType, TensorDictSequential
from torchrl.modules import ProbabilisticActor

policy_module = ProbabilisticActor(
        module=TensorDictSequential(
            actor_cnn_module,  # Outputs TD with "alpha, beta"
        ),
        spec=action_spec,
        in_keys=["alpha", "beta"],
        out_keys=["action"],
        distribution_class=IndependentBeta,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
).to(device)

policy_module.compile(fullgraph=True, dynamic=False)

Collector:

from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    create_env_fn=env_maker,  # Function to create environments
    policy=policy_module,
    # Total frames (steps) to collect in training
    total_frames=total_timesteps - collected_frames,
    frames_per_batch=config.frames_per_batch,
    # No initial random exploration phase needed if policy handles exploration
    init_random_frames=-1,
    split_trajs=False,  # Process rollouts as a single batch
    device=device,  # Device for collector ops (usually same as models/env)
    # Device where data is stored (can be CPU if memory is tight)
    storing_device=device,
    max_frames_per_traj=config.max_episode_steps,
    # num_threads=8
    # cudagraph_policy=True
)
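
For reference, my understanding is that cudagraph_policy=True makes the collector wrap the policy in tensordict.nn.CudaGraphModule; a rough manual equivalent (sketch only, the warmup value is arbitrary) would be:

from tensordict.nn import CudaGraphModule

# Capture the (already compiled) policy's forward into a CUDA graph,
# which is what I believe the collector does under the hood.
wrapped_policy = CudaGraphModule(policy_module, warmup=5)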

Training code (AMP+BF16):

for i, batch_data in enumerate(collector, start=collected_frames):
    current_frames = batch_data.numel()  # Number of steps collected in this batch
    collected_frames += current_frames

    for _ in range(config.update_epochs):
        batch_data = batch_data.reshape(-1)

        # Recompute advantages without tracking gradients
        with (
            torch.no_grad(),
            torch.autocast(device.type, amp_dtype, enabled=config.amp),
        ):
            adv_module(batch_data)

        for j in range(0, config.frames_per_batch, batch_size):
            minibatch = batch_data[j : j + batch_size]
            with torch.autocast(device.type, amp_dtype, enabled=config.amp):
                loss_dict = loss_module(minibatch)

                actor_loss = loss_dict["loss_objective"] + loss_dict["loss_entropy"]
                critic_loss = loss_dict["loss_critic"]

            optimizer.zero_grad()
            scaler.scale(actor_loss).backward()
            scaler.scale(critic_loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                loss_module.parameters(), config.max_grad_norm
            )
            scaler.step(optimizer)
            scaler.update()

Expected behavior

The loss/training curves should be (near) identical with and without cudagraph_policy=True.

Screenshots

The environment is a custom one. Red is cudagraph_policy=False, gray is cudagraph_policy=True.

[Image: training/loss curves, page 1]

[Image: training/loss curves, page 2]

System info

  • torchrl==0.8.0 (pip, uv)
  • Python 3.12.3
  • Linux


Reason and Possible fixes

I have no concrete idea what causes this.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
