diff --git a/tensordict/__init__.py b/tensordict/__init__.py index fa1611ee0..40cb8bdc3 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import tensordict._reductions -from tensordict._lazy import LazyStackedTensorDict +from tensordict._lazy import LazyStackedTensorDict, TensorDictCatView from tensordict._nestedkey import NestedKey from tensordict._td import ( cat, diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 88145d71e..454167a58 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -29,21 +29,10 @@ ) import numpy as np - import orjson as json import torch import torch.distributed as dist -from tensordict.memmap import MemoryMappedTensor - -try: - from functorch import dim as ftdim - - _has_funcdim = True -except ImportError: - from tensordict.utils import _ftdim_mock as ftdim - - _has_funcdim = False from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict from tensordict.base import ( _is_leaf_nontensor, @@ -58,6 +47,8 @@ T, TensorDictBase, ) + +from tensordict.memmap import MemoryMappedTensor from tensordict.utils import ( _as_context_manager, _broadcast_tensors, @@ -82,11 +73,22 @@ KeyedJaggedTensor, lock_blocked, NestedKey, + unravel_key, unravel_key_list, ) from torch import Tensor +from torch.utils._pytree import tree_map +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + from tensordict.utils import _ftdim_mock as ftdim + + _has_funcdim = False + _has_functorch = False try: try: @@ -4130,6 +4132,895 @@ def names(self, value): ) +def _pointwise_op_catview_dispatch(func): + """For a pointwise operation, executes the operation on storage and rebuilds the TD with the new storage.""" + name = func.__name__ + if not hasattr(torch.Tensor, name): + raise RuntimeError(f"Unknown method {name} within torch.Tensor.") + + @wraps(func) + def new_func(self, *args, **kwargs): + # We extract the cat_tensors for things like add_ etc. that could accept several TDs + has_td_input = False + + def extract_cat_tensor(obj): + nonlocal has_td_input + if isinstance(obj, TensorDictCatView): + return obj._cat_tensor + if is_tensor_collection(obj): + has_td_input = True + return obj + + from tensordict.nn.functional_modules import _exclude_td_from_pytree + + with _exclude_td_from_pytree(): + proc_args, proc_kwargs = tree_map(extract_cat_tensor, (args, kwargs)) + if has_td_input: + return getattr(self.to_tensordict(clone=False), name)(*args, **kwargs) + cat_tensor = self._cat_tensor + func = getattr(cat_tensor, name) + cat_tensor = func(*proc_args, **proc_kwargs) + offsets = self._offsets + lengths = self._lengths + keys = self._key_map.keys() + dim = self.dim + if cat_tensor.device != offsets.device: + # we send offsets to the device with non_blocking=False because it's safer (unless non_blocking is in the kwargs) + offsets = offsets.to( + cat_tensor.device, non_blocking=kwargs.get("non_blocking", False) + ) + return self._new_unsafe( + lengths=lengths, + offsets=offsets, + keys=keys, + cat_tensor=cat_tensor, + dim=dim, + names=self.names, + device=cat_tensor.device, + batch_size=self.batch_size, + td_example=self._tensordict, + ) + + return new_func + + +class TensorDictCatView(TensorDictBase): + """A lazy representation of a single-storage tensor where keys point to slices of the common tensor. + + + - To clone the TensorDictCatView but keep the same underlying data use `data.clone(recurse=False)`. 
+ - To clone the TensorDictCatView and clone its content as well while keeping the same type, use `data.clone()` + - To transform the TensorDictCatView in a regular TensorDict instance while keeping a single storage for all + tensors, use `data.to_tensordict(clone=False)`. + - To transform the TensorDictCatView in a regular TensorDict instance with new data allocation, use + `data.to_tensordict()`. + + """ + + _is_memmap: bool = False + _is_shared: bool = False + _tensordict: TensorDict + _is_locked: bool = True + + def __init__( + self, + source=None, + batch_size: torch.Size | int | None = None, + device: DeviceType | None = None, + names: Sequence[str] | None = None, + non_blocking: bool | None = None, + lock: bool | None = None, + dim=0, + **kwargs, + ): + if lock not in (None, True): + raise ValueError("TensorDictCatView must be locked.") + + # Set the batch size to make a quick check on the value shapes + if not isinstance(source, TensorDictBase): + source = TensorDict(source, lock=True, batch_size=batch_size, **kwargs) + elif kwargs: + raise ValueError("Cannot provide a TensorDict source as well as kwargs.") + else: + source = source.copy().lock_() + batch_size = source.batch_size + keys, values = zip(*source.items(True, True)) + lengths = torch.tensor([t.shape[dim] for t in values]) + offsets = torch.cat([torch.zeros_like(lengths[:1]), lengths.cumsum(0)]) + lengths = lengths.tolist() + cat_tensor = torch.cat(values, dim=dim) + + if device is not None: + device = torch.device(device) + if cat_tensor.device != device: + if non_blocking is None: + non_blocking = device.type == "cuda" + cat_tensor = cat_tensor.to(device, non_blocking=non_blocking) + + self._init( + lengths, offsets, keys, cat_tensor, dim, names, device, batch_size, source + ) + + def _build_tree(self): + for key in self._key_map.keys(): + if isinstance(key, str): + continue + key = key[0] + item = self._tensordict.empty() + # for key, item in list(self._tensordict.items()): + keys, offsets, lengths = zip( + *( + (unravel_key(k[1:]), offset, length) + for (k, (offset, length)) in self._key_map.items() + if isinstance(k, tuple) and k[0] == key + ) + ) + offsets = torch.stack(offsets) + new_tree = self._new_unsafe( + lengths=lengths, + offsets=offsets, + keys=keys, + cat_tensor=self._cat_tensor, + dim=self.dim, + names=self.names, + device=self._tensordict.device, + batch_size=item.batch_size, + td_example=item, + ) + new_tree._build_tree() + self._tensordict._set_str(key, new_tree, validated=True, inplace=False) + return self + + def materialize(self): + """Ensures all views of the tensordict are materialized. + + Returns self. 
+ """ + # Just iterate over the values + keys, lengths = zip( + *((key, length) for key, (offs, length) in self._key_map.items()) + ) + tensors = self._cat_tensor.split(lengths, dim=self.dim) + for k, t in zip(keys, tensors): + if isinstance(k, tuple): + self._tensordict._set_tuple( + k, t, validated=True, inplace=False, ignore_lock=True + ) + else: + self._tensordict._set_str( + k, t, validated=True, inplace=False, ignore_lock=True + ) + return self + + @classmethod + def _new_unsafe( + cls, + *, + lengths, + offsets, + keys, + cat_tensor, + dim, + names, + device, + batch_size, + td_example, + ): + self = cls.__new__(cls) + return self._init( + lengths, + offsets, + keys, + cat_tensor, + dim, + names, + device, + batch_size, + td_example, + ) + + def _init( + self, + lengths, + offsets, + keys, + cat_tensor, + dim, + names, + device, + batch_size, + td_example, + ): + self._lengths = lengths + self._offsets = offsets + self._cat_tensor = cat_tensor + self.dim = dim + self._key_map = { + key: (offset, length) + for key, offset, length in zip(keys, self._offsets, self._lengths) + } + self._tensordict = td_example.empty( + recurse=False, device=device, names=names, batch_size=batch_size + ) + self._build_tree() + return self + + def _get_str(self, key, default): + result = self._tensordict._get_str(key, None) + if result is None: + # Lazily allocate values to inner tensordict + offset, length = self._key_map.get(unravel_key(key), (None, None)) + if offset is None: + if default is NO_DEFAULT: + raise KeyError("Key {} not found in tensordict".format(key)) + return default + result = self._cat_tensor.narrow(self.dim, offset, length) + self._tensordict._set_str(key, result, validated=True, inplace=False) + + return result + + def _get_tuple(self, key, default): + result = self._tensordict._get_tuple(key, None) + if result is None: + # Lazily allocate values to inner tensordict + offset, length = self._key_map.get(unravel_key(key), (None, None)) + if offset is None: + if default is NO_DEFAULT: + raise KeyError("Key {} not found in tensordict".format(key)) + return default + result = self._cat_tensor.narrow(self.dim, offset, length) + self._tensordict._set_tuple(key, result, validated=True, inplace=False) + + return result + + def _set_str( + self, + key: str, + value: Any, + *, + inplace: bool, + validated: bool, + ignore_lock: bool = False, + non_blocking: bool = False, + ): + return self._tensordict._set_str( + key, + value, + inplace=inplace, + validated=validated, + ignore_lock=ignore_lock, + non_blocking=non_blocking, + ) + + def _set_tuple( + self, + key: NestedKey, + value: dict[str, CompatibleType] | CompatibleType, + *, + inplace: bool, + validated: bool, + non_blocking: bool = False, + ): + return self._tensordict._set_tuple( + key, + value, + inplace=inplace, + validated=validated, + non_blocking=non_blocking, + ) + + def __setitem__(self, key, value): + return TensorDict.__setitem__(self, key, value) + + @property + def names(self): + return self._tensordict.names + + @property + def device(self): + return self._cat_tensor.device + + @property + def batch_size(self): + return self._tensordict.batch_size + + @batch_size.setter + def batch_size(self, value): + self._tensordict.batch_size = value + + def is_locked(self) -> bool: + return True + + def unlock_(self, key): + raise RuntimeError(f"Cannot unlock a {type(self).__name__} instance.") + + # pointwise ops + @_pointwise_op_catview_dispatch + def to(self, *args, **kwargs): ... 
+ @_pointwise_op_catview_dispatch + def add(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def add_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def div(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def div_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def mul(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def mul_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def neg(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def neg_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def erf(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def erf_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def erfc(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def erfc_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def addcdiv(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def addcdiv_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def addcmul(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def addcmul_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def cos(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def cos_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def cosh(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def cosh_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def tan(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def tan_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def tanh(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def tanh_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def atanh(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def atanh_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def asin(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def asin_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sin(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sin_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def lerp(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def lerp_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def minimum(self, *args, **kwargs): ... + + # @_pointwise_op_catview_dispatch + # def minimum_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def maximum(self, *args, **kwargs): ... + + # @_pointwise_op_catview_dispatch + # def maximum_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sub(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sub_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def trunc(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def trunc_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sqrt(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def sqrt_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __eq__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __neg__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __ne__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __xor__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __or__(self, *args, **kwargs): ... 
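+    # The same dispatch covers the arithmetic and comparison dunders. When the other
+    # operand is a plain tensor collection (not a TensorDictCatView), the call falls
+    # back to `self.to_tensordict(clone=False)` and a regular TensorDict is returned.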
+ + # @_pointwise_op_catview_dispatch + # def __divmod__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __truediv__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __itruediv__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __idiv__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __mul__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __bool__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __add__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __iadd__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __sub__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __isub__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __imul__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __pow__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __ipow__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __abs__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __ge__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __gt__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __le__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def __lt__(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def clamp_min(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def clamp_min_(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def clamp_max(self, *args, **kwargs): ... + @_pointwise_op_catview_dispatch + def clamp_max_(self, *args, **kwargs): ... + + def all(self, dim: int = None) -> bool | TensorDictBase: + if dim is None: + return self._cat_tensor.all() + return super().all(dim=dim) + + def any(self, dim: int = None) -> bool | TensorDictBase: + if dim is None: + return self._cat_tensor.any() + return super().any(dim=dim) + + def share_memory_(self): + self._is_shared = True + self._cat_tensor.share_memory_() + return self + + def _check_is_shared(self): + if self.is_shared() != self._cat_tensor.is_shared(): + raise RuntimeError("is_shared() attributes don't match.") + + def _clone(self, recurse: bool = False): + cat_tensor = self._cat_tensor + if recurse: + cat_tensor = cat_tensor.clone() + return self._clone_given_cat_tensor(cat_tensor) + + def _clone_given_cat_tensor(self, cat_tensor, **kwargs): + new_kwargs = { + 'lengths': self._lengths, + 'offsets': self._offsets, + 'keys': self._key_map.keys(), + 'cat_tensor': cat_tensor, + 'dim': self.dim, + 'names': self.names, + 'device': self.device, + 'batch_size': self.batch_size, + 'td_example': self._tensordict + } + new_kwargs.update(kwargs) + return self._new_unsafe(**new_kwargs) + + # memmap + + def _load_memmap( + cls, + prefix: Path, + metadata: dict, + device: torch.device | None = None, + *, + out=None, + ): ... 
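+    # Memory-mapping is not implemented for this view type: every leaf aliases the
+    # same flat storage, and the memmap construction methods below raise
+    # ``NotImplementedError``.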
+ + def _memmap_( + self, + *, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, + share_non_tensor, + existsok, + ) -> T: + raise NotImplementedError + + def make_memmap( + self, + key: NestedKey, + shape: torch.Size | torch.Tensor, + *, + dtype: torch.dtype | None = None, + ) -> MemoryMappedTensor: + raise NotImplementedError + + def make_memmap_from_storage( + self, + key: NestedKey, + storage: torch.UntypedStorage, + shape: torch.Size | torch.Tensor, + *, + dtype: torch.dtype | None = None, + ) -> MemoryMappedTensor: + raise NotImplementedError + + def make_memmap_from_tensor( + self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True + ) -> MemoryMappedTensor: + raise NotImplementedError + + # apply + _apply_nest = TensorDict._apply_nest + + def _multithread_apply_flat( + self, + fn: Callable, + *others: T, + call_on_nested: bool = False, + default: Any = NO_DEFAULT, + named: bool = False, + nested_keys: bool = False, + prefix: tuple = (), + is_leaf: Callable = None, + executor: ThreadPoolExecutor, + futures: List[Future], + local_futures: List, + ) -> None: + raise NotImplementedError + + def _multithread_apply_nest( + self, + fn: Callable, + *others: T, + batch_size: Sequence[int] | None = None, + device: torch.device | None = NO_DEFAULT, + names: Sequence[str] | None = NO_DEFAULT, + inplace: bool = False, + checked: bool = False, + call_on_nested: bool = False, + default: Any = NO_DEFAULT, + named: bool = False, + nested_keys: bool = False, + prefix: tuple = (), + filter_empty: bool | None = None, + is_leaf: Callable = None, + out: TensorDictBase | None = None, + num_threads: int, + call_when_done: Callable | None = None, + **constructor_kwargs, + ) -> T | None: + raise NotImplementedError + + def _multithread_rebuild( + self, + *, + batch_size: Sequence[int] | None = None, + device: torch.device | None = NO_DEFAULT, + names: Sequence[str] | None = NO_DEFAULT, + inplace: bool = False, + checked: bool = False, + out: TensorDictBase | None = None, + filter_empty: bool = False, + executor: ThreadPoolExecutor, + futures: List[Future], + local_futures: List, + subs_results: Dict[Future, Any] | None = None, + multithread_set: bool = False, # Experimental + **constructor_kwargs, + ) -> None: + raise NotImplementedError + + # vmap + def _add_batch_dim(self, *, in_dim, vmap_level): + raise NotImplementedError + + def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): + raise NotImplementedError + + def _remove_batch_dim(self, vmap_level, batch_size, out_dim): + raise NotImplementedError + + def _cast_reduction( + self, + *, + reduction_name, + dim=NO_DEFAULT, + keepdim=NO_DEFAULT, + dtype, + tuple_ok=True, + further_reduce: bool, + **kwargs, + ): + raise NotImplementedError + + def _change_batch_size(self, new_size: torch.Size) -> None: + raise NotImplementedError + + def _check_device(self, *, raise_exception: bool = True) -> None | bool: + if raise_exception: + if self._cat_tensor.device != self.device: + raise RuntimeError( + f"Device mismatch: tensor device {self._cat_tensor.device} doesn't match TD's {self.device}." 
+ ) + else: + return self._cat_tensor.device == self.device + + def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> T: + raise NotImplementedError + + def _erase_names(self): + return self._tensordict._erase_names() + + # select and exclude + def _exclude( + self, + *keys: NestedKey, + inplace: bool = False, + set_shared: bool = True, + ) -> T: + if not inplace: + return TensorDict._exclude( + self, *keys, set_shared=set_shared, inplace=False + ) + raise RuntimeError( + f"Cannot exclude inplace from a {type(self).__name__} instance." + ) + + def _select( + self, + *keys: NestedKey, + inplace: bool = False, + strict: bool = True, + set_shared: bool = True, + ) -> T: + if not inplace: + return TensorDict._select( + self, *keys, strict=strict, set_shared=set_shared, inplace=False + ) + raise RuntimeError( + f"Cannot select inplace from a {type(self).__name__} instance." + ) + + @property + def _td_dim_names(self): + return self._tensordict._td_dim_names + + def _has_names(self): + return self._tensordict._has_names() + + # shape ops + def _permute( + self, + *args, + **kwargs, + ): + raise NotImplementedError + + def _squeeze(self, dim=None): + raise NotImplementedError + + def _view( + self, + *args, + **kwargs, + ) -> T: + raise NotImplementedError + + def expand(self, *args: int | torch.Size) -> T: + raise NotImplementedError + + def _transpose(self, dim0, dim1) -> T: + raise NotImplementedError + + def _unbind(self, dim: int) -> tuple[T, ...]: + return self.to_tensordict(clone=False)._unbind(dim) + + def _unsqueeze(self, dim) -> T: + raise NotImplementedError + + def reshape( + self, + *args, + **kwargs, + ) -> T: + raise NotImplementedError + + def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]: + return self.to_tensordict(clone=False).split(split_size=split_size, dim=dim) + + def _rename_subtds(self, value): + raise NotImplementedError + + def contiguous(self) -> T: + if self._cat_tensor.is_contiguous(): + return self + cat_tensor = self._cat_tensor.contiguous() + + return self._new_unsafe( + lengths=self._lengths, + offsets=self._offsets, + keys=self._key_map.keys(), + cat_tensor=cat_tensor, + dim=self.dim, + names=self.names, + device=self.device, + batch_size=self.batch_size, + td_example=self._tensordict, + ) + + _set_at_str = TensorDict._set_at_str + _set_at_tuple = TensorDict._set_at_tuple + + _stack_onto_ = TensorDict._stack_onto_ + + _to_module = TensorDict._to_module + + del_ = TensorDict.del_ + + def detach(self) -> T: + if not self._cat_tensor.requires_grad: + return self + cat_tensor = self._cat_tensor.detach() + + return self._new_unsafe( + lengths=self._lengths, + offsets=self._offsets, + keys=self._key_map.keys(), + cat_tensor=cat_tensor, + dim=self.dim, + names=self.names, + device=self.device, + batch_size=self.batch_size, + td_example=self._tensordict, + ) + + def detach_(self) -> T: + self._cat_tensor.detach_() + return self + + @cache # noqa: B019 + def keys( + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, + ): + if is_leaf not in ( + None, + _is_leaf_nontensor, + _NESTED_TENSORS_AS_LISTS, + _NESTED_TENSORS_AS_LISTS_NONTENSOR, + ): + raise ValueError( + f"is_leaf is not supported in {type(self).__name__}.keys()." 
+ ) + keys = list(self._key_map.keys()) + if leaves_only and not include_nested: + return [key for key in self._key_map if isinstance(key, str)] + + if not leaves_only: + keys_set = set(keys) + for key in keys: + if isinstance(key, str): + continue + for i in range(1, len(key)): + keys_set.add(key[:i] if i > 1 else key[0]) + keys = keys_set + + if not include_nested: + keys = [key for key in keys if isinstance(key, str)] + + if sort: + keys = sorted(keys, key=lambda key: ".".join(key)) + + return list(keys) + + def items( + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, + ): + yield from ( + (key, self.get(key)) + for key in self.keys( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, + ) + ) + + def values( + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, + ): + yield from ( + self.get(key) + for key in self.keys( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, + ) + ) + + def entry_class(self, key: NestedKey) -> type: + return self._tensordict.entry_class(key) + + @classmethod + def from_dict( + cls, + input_dict, + batch_size: torch.Size | None = None, + device: torch.device | None = None, + batch_dims: int | None = None, + names: List[str] | None = None, + dim: int | None = 0, + ): + result = cls( + input_dict, batch_size=batch_size, device=device, names=names, dim=dim + ) + if batch_dims is not None and batch_dims < result.ndim: + result.batch_size = result.batch_size[:batch_dims] + return result + + def is_contiguous(self) -> bool: + return self._cat_tensor.is_contiguous() + + def popitem(self) -> Tuple[NestedKey, CompatibleType]: + raise NotImplementedError + + def rename_key_( + self, old_key: NestedKey, new_key: NestedKey, safe: bool = False + ) -> T: + raise NotImplementedError + + masked_fill = TensorDict.masked_fill + masked_fill_ = TensorDict.masked_fill_ + masked_select = TensorDict.masked_select + + def from_dict_instance( + self, + input_dict, + batch_size=None, + device=None, + batch_dims=None, + names: List[str] | None = None, + ): + raise NotImplementedError + + _index_tensordict = TensorDict._index_tensordict + __repr__ = TensorDict.__repr__ + + def _iter_items_lazystack( tensordict: LazyStackedTensorDict, return_none_for_het_values: bool = False ) -> Iterator[tuple[str, CompatibleType]]: diff --git a/tensordict/_td.py b/tensordict/_td.py index f56e25052..85e547ea3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -340,7 +340,7 @@ def _new_unsafe( self._device = device self._tensordict = _tensordict = _StringOnlyDict() self._batch_size = batch_size - if source: # faster than calling items + if source is not None: # faster than calling items for key, value in source.items(): if nested and isinstance(value, dict): value = TensorDict._new_unsafe( @@ -4267,10 +4267,13 @@ def _items( _CustomOpTensorDict, _iter_items_lazystack, LazyStackedTensorDict, + TensorDictCatView, ) if isinstance(tensordict, LazyStackedTensorDict): return _iter_items_lazystack(tensordict, return_none_for_het_values=True) + if isinstance(tensordict, TensorDictCatView): + return tensordict.items() if isinstance(tensordict, _CustomOpTensorDict): # it's possible that a TensorDict contains a nested LazyStackedTensorDict, # or _CustomOpTensorDict, so as we iterate through the contents we need to diff --git a/tensordict/base.py 
b/tensordict/base.py index 1795504b6..7eccf1ddb 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2155,7 +2155,7 @@ def unsqueeze(self, *args, **kwargs): return result @abc.abstractmethod - def _unsqueeze(self, dim): ... + def _unsqueeze(self, dim: int) -> T: ... def _legacy_unsqueeze(self, dim: int) -> T: if dim < 0: @@ -2224,7 +2224,7 @@ def squeeze(self, *args, **kwargs): return result @abc.abstractmethod - def _squeeze(self, dim=None): ... + def _squeeze(self, dim: int | None = None) -> T: ... def _legacy_squeeze(self, dim: int | None = None) -> T: from tensordict._lazy import _SqueezedTensorDict @@ -2677,7 +2677,7 @@ def transpose(self, dim0, dim1): return result @abc.abstractmethod - def _transpose(self, dim0, dim1): ... + def _transpose(self, dim0: int, dim1: int) -> T: ... def _legacy_transpose(self, dim0, dim1): if dim0 < 0: @@ -9214,24 +9214,33 @@ def _maybe_set_shared_attributes(self, result, lock=False): if lock: result.lock_() - def to_tensordict(self) -> T: + def to_tensordict(self, *, clone: bool = True) -> T: """Returns a regular TensorDict instance from the TensorDictBase. + Keyword Args: + clone (bool, optional): if ``True``, the values are cloned. Otherwise, + a :class:`~tensordict.TensorDict` instance is returned with values not cloned. + Defaults to ``True``. + Returns: a new TensorDict object containing the same values. """ from tensordict import TensorDict + def _maybe_clone(value): + if clone and not _is_tensor_collection(type(value)): + return value.clone() + if not clone or is_non_tensor(value): + return value + return value.to_tensordict(clone=clone) + + d = { + key: _maybe_clone(value) + for key, value in self.items(is_leaf=_is_leaf_nontensor) + } return TensorDict( - { - key: ( - value.clone() - if not _is_tensor_collection(type(value)) - else value if is_non_tensor(value) else value.to_tensordict() - ) - for key, value in self.items(is_leaf=_is_leaf_nontensor) - }, + d, device=self.device, batch_size=self.batch_size, names=self._maybe_names(), diff --git a/test/test_tensordict.py b/test/test_tensordict.py index d0f00a738..3cb5785a3 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -42,6 +42,7 @@ PersistentTensorDict, set_get_defaults_to_none, TensorDict, + TensorDictCatView, ) from tensordict._lazy import _CustomOpTensorDict from tensordict._reductions import _reduce_td @@ -10823,6 +10824,231 @@ def _to_float(td, td_name, tmpdir): return td +class TestTensorDictCatView: + @pytest.fixture + def lazy_cat_tensordict(self): + with torch.device("cuda:0" if torch.cuda.is_available() else "cpu"): + return TensorDictCatView( + { + "a": torch.zeros(3, 4), + "b": {"c": torch.ones(3, 4)}, + ("b", "e"): torch.full((3, 4), 2.0), + }, + batch_size=[3], + ) + + def test_keys_values_items(self, lazy_cat_tensordict): + assert set(lazy_cat_tensordict.keys()) == {"a", "b"} + assert set(lazy_cat_tensordict["b"].keys()) == {"c", "e"} + assert set(lazy_cat_tensordict.keys(True, True)) == { + "a", + ("b", "c"), + ("b", "e"), + } + assert set( + lazy_cat_tensordict.keys(include_nested=False, leaves_only=True) + ) == {"a"} + assert set( + lazy_cat_tensordict.keys(include_nested=True, leaves_only=False) + ) == {"a", "b", ("b", "c"), ("b", "e")} + + # items + assert set(list(zip(*lazy_cat_tensordict.items()))[0]) == {"a", "b"} + assert set(list(zip(*lazy_cat_tensordict.items(True, True)))[0]) == { + "a", + ("b", "c"), + ("b", "e"), + } + + def test_get_set_(self, lazy_cat_tensordict): + assert "a" not in lazy_cat_tensordict._tensordict + _ = 
lazy_cat_tensordict.get("a") + assert "a" in lazy_cat_tensordict._tensordict + c = lazy_cat_tensordict.get(("b", "c")) + cc = c.clone() + assert c is lazy_cat_tensordict["b", "c"] + lazy_cat_tensordict.set_(("b", "c"), c + 1) + assert (c == cc + 1).all() + + def test_sub_td(self, lazy_cat_tensordict): + assert isinstance(lazy_cat_tensordict, TensorDictCatView) + assert not isinstance(lazy_cat_tensordict, TensorDict) + b = lazy_cat_tensordict["b"] + assert isinstance(b, TensorDictCatView), type(b) + assert b.batch_size == lazy_cat_tensordict.batch_size + + def test_data_ptr(self, lazy_cat_tensordict): + assert ( + len( + set( + lazy_cat_tensordict.data_ptr(storage=True) + .flatten_keys() + .stack_from_tensordict() + .tolist() + ) + ) + == 1 + ) + + def test_to_tensordict(self, lazy_cat_tensordict): + td = lazy_cat_tensordict.to_tensordict(clone=False) + assert isinstance(td, TensorDict) + assert ( + len( + set( + td.data_ptr(storage=True) + .flatten_keys() + .stack_from_tensordict() + .tolist() + ) + ) + == 1 + ) + td = lazy_cat_tensordict.to_tensordict() + assert isinstance(td, TensorDict) + assert ( + len( + set( + td.data_ptr(storage=True) + .flatten_keys() + .stack_from_tensordict() + .tolist() + ) + ) + > 1 + ) + + def test_materialize(self, lazy_cat_tensordict): + assert "a" not in lazy_cat_tensordict._tensordict + assert "c" not in lazy_cat_tensordict._tensordict["b"]._tensordict + lazy_cat_tensordict.materialize() + assert "a" in lazy_cat_tensordict._tensordict + assert "c" in lazy_cat_tensordict._tensordict["b"]._tensordict + + def test_clone(self, lazy_cat_tensordict): + td = lazy_cat_tensordict.clone(False) + assert td._cat_tensor is lazy_cat_tensordict._cat_tensor + td = lazy_cat_tensordict.clone() + assert isinstance(td, TensorDictCatView) + assert ( + td._cat_tensor.untyped_storage().data_ptr() + != lazy_cat_tensordict._cat_tensor.untyped_storage().data_ptr() + ) + + def test_pointwise(self, lazy_cat_tensordict): + out0 = lazy_cat_tensordict + lazy_cat_tensordict + out1 = lazy_cat_tensordict * 2 + out2 = lazy_cat_tensordict * 4 - out1 + assert (out0 == out1).all() + assert (out0 == out2).all() + assert lazy_cat_tensordict.batch_size == out0.batch_size + assert isinstance(out0, TensorDictCatView) + assert isinstance(out1, TensorDictCatView) + assert isinstance(out2, TensorDictCatView) + ct = lazy_cat_tensordict._cat_tensor.clone() + td = lazy_cat_tensordict.to_tensordict(clone=False) + assert (lazy_cat_tensordict._cat_tensor == ct).all() + assert (td["a"] == 0).all() + assert (lazy_cat_tensordict._cat_tensor == ct).all() + assert isinstance(lazy_cat_tensordict == td, TensorDict) + assert_allclose_td(lazy_cat_tensordict, td) + assert isinstance(td == lazy_cat_tensordict, TensorDict) + assert (lazy_cat_tensordict._cat_tensor == ct).all() + assert (lazy_cat_tensordict._cat_tensor == ct).all() + + def test_apply(self, lazy_cat_tensordict): + ref = (lazy_cat_tensordict + 1).to_tensordict(clone=False) + applied = lazy_cat_tensordict.apply(lambda x: x + 1) + assert set(ref.keys(True, True)) == set(applied.keys(True, True)) + assert_allclose_td(applied, ref) + assert_allclose_td(lazy_cat_tensordict.named_apply(lambda name, x: x + 1), ref) + assert_allclose_td( + lazy_cat_tensordict.named_apply(lambda name, x: x + 1, call_on_nested=True), + ref, + ) + + def test_to_device(self, lazy_cat_tensordict): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + elif torch.backends.mps.is_available(): + device = torch.device("mps:0") + else: + device = torch.device("cpu") + 
lazy_cat_tensordict_device = lazy_cat_tensordict.to(device) + assert lazy_cat_tensordict_device.device == device + assert lazy_cat_tensordict_device["a"].device == device + assert ( + lazy_cat_tensordict_device["a"].untyped_storage().data_ptr() + == lazy_cat_tensordict_device["b", "e"].untyped_storage().data_ptr() + ) + + def test_share_mem(self, lazy_cat_tensordict): + if lazy_cat_tensordict.device.type == "cpu": + assert not lazy_cat_tensordict.is_shared() + elif lazy_cat_tensordict.device.type == "cuda": + assert lazy_cat_tensordict.is_shared() + elif lazy_cat_tensordict.device.type == "mps": + assert not lazy_cat_tensordict.is_shared() + return + lazy_cat_tensordict.share_memory_() + assert lazy_cat_tensordict.is_shared() + + def test_setitem(self, lazy_cat_tensordict): + lazy_cat_tensordict[0] = lazy_cat_tensordict[0].clone() + 5 + assert (lazy_cat_tensordict[0] == lazy_cat_tensordict[1] + 5).all() + assert_allclose_td(lazy_cat_tensordict[0], lazy_cat_tensordict[1] + 5) + + def test_unbind(self, lazy_cat_tensordict): + assert "a" not in lazy_cat_tensordict._tensordict + td0, td1, td2 = lazy_cat_tensordict.unbind(0) + assert set(td0.keys(True, True)) == set(lazy_cat_tensordict.keys(True, True)) + assert isinstance(td0, TensorDict) + assert ( + td0.data_ptr(storage=True) == lazy_cat_tensordict.data_ptr(storage=True) + ).all() + + def test_split(self, lazy_cat_tensordict): + assert "a" not in lazy_cat_tensordict._tensordict + td0, td1 = lazy_cat_tensordict.split([1, 2]) + assert set(td0.keys(True, True)) == set(lazy_cat_tensordict.keys(True, True)) + assert isinstance(td0, TensorDict) + assert ( + td0.data_ptr(storage=True) == lazy_cat_tensordict.data_ptr(storage=True) + ).all() + assert set(td1.keys(True, True)) == set(lazy_cat_tensordict.keys(True, True)) + assert isinstance(td1, TensorDict) + assert ( + td1.data_ptr(storage=True) == lazy_cat_tensordict.data_ptr(storage=True) + ).all() + + def test_stack_onto(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_flatten_unflatten_keys(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_flatten_unflatten(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_reshape(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_view(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_permute(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_transpose(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_squeeze_unsqueeze(self, lazy_cat_tensordict): + raise NotImplementedError + + def test_repr(self, lazy_cat_tensordict): + assert repr(lazy_cat_tensordict) == str(lazy_cat_tensordict) == "" + + _SYNC_COUNTER = 0 @@ -10841,4 +11067,4 @@ def _sync_td(self): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main([__file__, "--capture", "no", "--exitfirst", "--tb", "short"] + unknown)
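
For reference, a minimal usage sketch of the new class. The keys, shapes and calls mirror the `TestTensorDictCatView` fixture above; everything else is illustrative rather than part of the API surface added by this diff.

import torch
from tensordict import TensorDict, TensorDictCatView

view = TensorDictCatView(
    {
        "a": torch.zeros(3, 4),
        "b": {"c": torch.ones(3, 4)},
        ("b", "e"): torch.full((3, 4), 2.0),
    },
    batch_size=[3],
)

# Every leaf is a narrow() view into one concatenated storage.
assert (
    view["a"].untyped_storage().data_ptr()
    == view["b", "e"].untyped_storage().data_ptr()
)

# Pointwise ops run once on the flat storage and return a new TensorDictCatView.
doubled = view * 2
assert isinstance(doubled, TensorDictCatView)

# Escape hatch to a regular TensorDict, with or without copying the storage.
shared = view.to_tensordict(clone=False)  # same underlying buffer
copied = view.to_tensordict()             # values are cloned
assert isinstance(shared, TensorDict) and isinstance(copied, TensorDict)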