Describe the bug
When creating specs for an environment, specs expanded to the environment's batch size take up memory proportional to the batch size.
For example, the space attribute of a Bounded spec ends up with:
2 (low/high) * spec size * batch_size entries, regardless of whether the spec is created via expand(), which is expected to return a view of the underlying tensor.
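To illustrate the underlying mechanism with plain PyTorch (a minimal sketch, not torchrl code): expand() alone returns a view backed by the original storage, while a subsequent clone() materializes one copy per batch element.

```python
import torch

low = torch.full((3,), -1.0)             # 3-element bound, as in the spec below
view = low.expand(512, 3)                # view: still backed by 3 floats
copy = view.clone()                      # copy: materializes 512 * 3 floats

print(view.untyped_storage().nbytes())   # 12 bytes
print(copy.untyped_storage().nbytes())   # 6144 bytes
```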
To Reproduce
```python
import torch
from torchrl.envs import EnvBase
from torchrl.data import Bounded, Composite


class MyMinimalEnv(EnvBase):
    def __init__(self, batch_size=None):
        self.batch_size = [batch_size if batch_size is not None else 1]
        super().__init__(batch_size=self.batch_size)
        # Define specs
        obs_spec_shape = torch.Size([3])
        self.observation_spec = Composite(
            observation=Bounded(low=-1, high=1, shape=obs_spec_shape),
        ).expand(self.batch_size)

    def _reset(self, tensordict=None):
        return

    def _step(self, tensordict):
        return

    def _set_seed(self, seed):
        return


# Instantiate and test
batch_size = 512
env = MyMinimalEnv(batch_size=batch_size)
print("Observation spec:", env.observation_spec)
print(env.observation_spec['observation'].space.high.storage())
```
This produces the following output:
> [torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 1536]
Expected behavior
The spec should only store a batched view of the underlying tensors, i.e. the storage should keep its original (unbatched) size.
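Continuing the repro script above, the desired output (a sketch of the expected behavior, not what currently happens) would be a storage of the original, unbatched size:

```python
# Expected: the batched spec keeps a view over the original 3-element tensor.
print(env.observation_spec['observation'].space.high.storage())
# [torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]
```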
System info
torchrl==0.8.0
Reason and Possible fixes
I am not sure why .clone() is called on all the tensors in the expand() function of, e.g., Bounded:
```python
def expand(self, *shape):
    if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
        shape = shape[0]
    if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
        raise ValueError(
            f"The last {self.ndim} of the expanded shape {shape} must match the"
            f"shape of the {self.__class__.__name__} spec in expand()."
        )
    return self.__class__(
        low=self.space.low.expand(_remove_neg_shapes(shape)).clone(),
        high=self.space.high.expand(_remove_neg_shapes(shape)).clone(),
        shape=shape,
        device=self.device,
        dtype=self.dtype,
    )
```
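A possible fix (an untested sketch on my side, assuming nothing downstream relies on the copies) would be to drop the final .clone() calls so the returned spec holds views:

```python
    # Untested sketch: same return as above, but without .clone(), so that
    # low/high stay views over the original storage. Note that the __init__
    # shown below would still clone them.
    return self.__class__(
        low=self.space.low.expand(_remove_neg_shapes(shape)),
        high=self.space.high.expand(_remove_neg_shapes(shape)),
        shape=shape,
        device=self.device,
        dtype=self.dtype,
    )
```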
But it seems like this is done on purpose in several locations, such as the __init__ function:
```python
if high.ndimension():
    if shape_corr is not None and shape_corr != high.shape:
        raise RuntimeError(err_msg)
    if shape is None:
        shape = high.shape
    if shape_corr is not None:
        low = low.expand(shape_corr).clone()
elif low.ndimension():
    if shape_corr is not None and shape_corr != low.shape:
        raise RuntimeError(err_msg)
    if shape is None:
        shape = low.shape
    if shape_corr is not None:
        high = high.expand(shape_corr).clone()
elif shape_corr is None:
    raise RuntimeError(err_msg)
else:
    low = low.expand(shape_corr).clone()
    high = high.expand(shape_corr).clone()

if low.numel() > high.numel():
    high = high.expand_as(low).clone()
elif high.numel() > low.numel():
    low = low.expand_as(high).clone()
if shape_corr is None:
    shape = low.shape
else:
    if isinstance(shape_corr, float):
        shape_corr = _size([shape_corr])
    elif not isinstance(shape_corr, torch.Size):
        shape_corr = _size(shape_corr)
    shape_corr_err_msg = (
        f"low and shape_corr mismatch, got {low.shape} and {shape_corr}"
    )
    if len(low.shape) != len(shape_corr):
        raise RuntimeError(shape_corr_err_msg)
    if not all(_s == _sa for _s, _sa in zip(shape_corr, low.shape)):
        raise RuntimeError(shape_corr_err_msg)
self.shape = shape
```
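My guess (an assumption, not something I verified in the codebase) is that the clones are there because expanded views cannot be written to in place:

```python
import torch

low = torch.zeros(3)
view = low.expand(4, 3)
view[0, 0] = 1.0
# RuntimeError: unsupported operation: more than one element of the written-to
# tensor refers to a single memory location. Please clone() the tensor before
# performing the operation.
```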
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)