8000 [BUG] TensorSpecs require memory proportionally to batch size · Issue #2974 · pytorch/rl · GitHub
[BUG] TensorSpecs require memory proportionally to batch size #2974
Open
@NTR0314

Description


Describe the bug

When specs are created for an environment, specs expanded to the batch size consume memory proportional to the batch size.

For example, the space attribute of Bounded ends up holding:

2 (low/high) * spec size * batch_size entries, even when the spec is created via expand(), which is expected to return a view of the original tensor.
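
For reference, a minimal sketch (using plain torch tensors, not TorchRL specs) of the difference between keeping the expanded view and cloning it, which is where the batch-size-proportional memory comes from:

import torch

low = torch.full((3,), -1.0)        # per-element bound, 3 floats
view = low.expand(512, 3)           # expand() returns a view backed by the original 3-element storage
copy = low.expand(512, 3).clone()   # clone() materializes a full 512 x 3 buffer

print(view.storage().size())   # 3
print(copy.storage().size())   # 1536 -- grows linearly with batch_size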

To Reproduce

import torch
from torchrl.envs import EnvBase
from torchrl.data import Bounded, Composite

class MyMinimalEnv(EnvBase):
    def __init__(self, batch_size=None):
        self.batch_size = [batch_size if batch_size is not None else 1]
        super().__init__(batch_size=self.batch_size)

        # Define specs
        obs_spec_shape = torch.Size([3])
        self.observation_spec = Composite(
            observation=Bounded(low=-1, high=1, shape=obs_spec_shape)
        ).expand(self.batch_size)

    def _reset(self, tensordict=None):
        return

    def _step(self, tensordict):
        return

    def _set_seed(self, seed=None):
        return

# Instantiate and test
batch_size = 512
env = MyMinimalEnv(batch_size=batch_size)
print("Observation spec:", env.observation_spec)
print(env.observation_spec['observation'].space.high.storage())

This creates the output:

> [torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 1536]

Expected behavior

The spec should store only a batched view of the underlying tensors, so that memory does not grow with the batch size.

System info

torchrl==0.8.0

Reason and Possible fixes

I am not sure why .clone() is called on all the tensors in the expand() method of, e.g., Bounded:

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the"
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        return self.__class__(
            low=self.space.low.expand(_remove_neg_shapes(shape)).clone(),
            high=self.space.high.expand(_remove_neg_shapes(shape)).clone(),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )

But it seems like this is done on purpose in several locations, such as the __init__ function:

        if high.ndimension():
            if shape_corr is not None and shape_corr != high.shape:
                raise RuntimeError(err_msg)
            if shape is None:
                shape = high.shape
            if shape_corr is not None:
                low = low.expand(shape_corr).clone()
        elif low.ndimension():
            if shape_corr is not None and shape_corr != low.shape:
                raise RuntimeError(err_msg)
            if shape is None:
                shape = low.shape
            if shape_corr is not None:
                high = high.expand(shape_corr).clone()
        elif shape_corr is None:
            raise RuntimeError(err_msg)
        else:
            low = low.expand(shape_corr).clone()
            high = high.expand(shape_corr).clone()

        if low.numel() > high.numel():
            high = high.expand_as(low).clone()
        elif high.numel() > low.numel():
            low = low.expand_as(high).clone()
        if shape_corr is None:
            shape = low.shape
        else:
            if isinstance(shape_corr, float):
                shape_corr = _size([shape_corr])
            elif not isinstance(shape_corr, torch.Size):
                shape_corr = _size(shape_corr)
            shape_corr_err_msg = (
                f"low and shape_corr mismatch, got {low.shape} and {shape_corr}"
            )
            if len(low.shape) != len(shape_corr):
                raise RuntimeError(shape_corr_err_msg)
            if not all(_s == _sa for _s, _sa in zip(shape_corr, low.shape)):
                raise RuntimeError(shape_corr_err_msg)
        self.shape = shape
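
As a possible direction (a hypothetical, untested sketch rather than a proposed patch), expand() could return the expanded views directly and leave any copying to the caller. The deliberate clones in __init__ above would presumably also need to be relaxed for this to take effect, and there may be a reason for the copies (e.g. avoiding shared storage between specs) that I am missing:

    def expand(self, *shape):
        if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
            shape = shape[0]
        if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
            raise ValueError(
                f"The last {self.ndim} of the expanded shape {shape} must match the"
                f"shape of the {self.__class__.__name__} spec in expand()."
            )
        # Hypothetical change: keep the expanded views instead of materializing
        # copies, so memory stays proportional to the un-batched spec size.
        return self.__class__(
            low=self.space.low.expand(_remove_neg_shapes(shape)),
            high=self.space.high.expand(_remove_neg_shapes(shape)),
            shape=shape,
            device=self.device,
            dtype=self.dtype,
        )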

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

Metadata

Labels

bug (Something isn't working)