diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index ce1a9b379903e..7d26220cdf7e8 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -70,7 +70,7 @@ class ForkerIterDataPipe(IterDataPipe):
         num_instances: number of instances of the datapipe to create
         buffer_size: this restricts how far ahead the leading child DataPipe
            can read relative to the slowest child DataPipe.
-            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+            Defaults to ``10000``. Use ``None`` for the unlimited buffer.
         copy: copy strategy to use for items yielded by each branch. Supported
             options are ``None`` for no copying, ``"shallow"`` for shallow object
             copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
@@ -94,7 +94,7 @@ def __new__(
         cls,
         datapipe: IterDataPipe,
         num_instances: int,
-        buffer_size: int = 1000,
+        buffer_size: Optional[int] = 10000,
         copy: Optional[Literal["shallow", "deep"]] = None
     ):
         if num_instances < 1:
@@ -143,7 +143,7 @@ def __init__(
         self,
         datapipe: IterDataPipe,
         num_instances: int,
-        buffer_size: int = 1000,
+        buffer_size: Optional[int] = 10000,
         copy: Optional[Literal["shallow", "deep"]] = None
     ):
         self.main_datapipe = datapipe
@@ -152,11 +152,7 @@ def __init__(
         self.buffer: Deque = deque()
         self.buffer_size = buffer_size
-        if self.buffer_size < 0:
-            warnings.warn(
-                "Unlimited buffer size is set for `fork`, "
-                "please be aware of OOM at random places",
-                UserWarning
-            )
+        if self.buffer_size is not None and self.buffer_size <= 0:
+            raise ValueError("Buffer size needs to be positive or None")
         if copy is None:
             self.copy_fn = _no_op
         elif copy == "shallow":
@@ -206,7 +202,7 @@ def get_next_element_by_instance(self, instance_id: int):
                     if self.slowest_ptr < new_min:
                         self.slowest_ptr = new_min
                         self.buffer.popleft()
-                if self.buffer_size >= 0 and self.leading_ptr > self.buffer_size + self.slowest_ptr:
+                if self.buffer_size is not None and self.leading_ptr > self.buffer_size + self.slowest_ptr:
                     raise BufferError("ForkerIterDataPipe buffer overflow," +
                                       f"buffer size {self.buffer_size} is insufficient.")
@@ -360,7 +356,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
         buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
            DataPipes while waiting for their values to be yielded.
-            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+            Defaults to ``10000``. Use ``None`` for the unlimited buffer.

     Examples:
         >>> # xdoctest: +REQUIRES(module:torchdata)
@@ -383,7 +379,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
         [1, 3]
     """
     def __new__(cls, datapipe: IterDataPipe, num_instances: int,
-                classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
+                classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: Optional[int] = 10000):
         if num_instances < 1:
             raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
@@ -410,11 +406,7 @@ def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int,
         self.num_instances = num_instances
         self.buffer_size = buffer_size
-        if self.buffer_size < 0:
-            warnings.warn(
-                "Unlimited buffer size is set for `demux`, "
-                "please be aware of OOM at random places",
-                UserWarning
-            )
+        if self.buffer_size is not None and self.buffer_size <= 0:
+            raise ValueError("Buffer size needs to be positive or None")
         self.current_buffer_usage = 0
         self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
         self.classifier_fn = classifier_fn
@@ -442,7 +434,7 @@ def _find_next(self, instance_id: int) -> T_co:
                 return value
             self.child_buffers[classification].append(value)
             self.current_buffer_usage += 1
-            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
+            if self.buffer_size is not None and self.current_buffer_usage > self.buffer_size:
                 raise BufferError(
                     f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")
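
Below is a minimal usage sketch of the behaviour this patch documents: ``None`` (rather than ``-1``) requests an unlimited buffer, the default grows to ``10000``, and a non-positive ``buffer_size`` is rejected with a ``ValueError`` instead of a warning. It assumes the patch above is applied; the pipeline names and the ``odd_or_even`` classifier below are illustrative only, mirroring the existing docstring examples.

    from torch.utils.data.datapipes.iter import IterableWrapper

    def odd_or_even(n: int) -> int:
        # Classifier for demux: route evens to child 0, odds to child 1.
        return n % 2

    # fork: both children see every element; buffer_size=None lifts the read-ahead limit.
    fork_source = IterableWrapper(range(10))
    child_a, child_b = fork_source.fork(num_instances=2, buffer_size=None)
    assert list(child_a) == list(range(10))
    assert list(child_b) == list(range(10))

    # demux: the shared buffer follows the same convention.
    demux_source = IterableWrapper(range(10))
    evens, odds = demux_source.demux(num_instances=2, classifier_fn=odd_or_even, buffer_size=None)
    assert list(evens) == [0, 2, 4, 6, 8]
    assert list(odds) == [1, 3, 5, 7, 9]

    # A non-positive size now fails fast at construction instead of warning about possible OOM.
    try:
        IterableWrapper(range(10)).fork(num_instances=2, buffer_size=-1)
    except ValueError as err:
        print(err)  # Buffer size needs to be positive or None

Using ``None`` as the sentinel keeps negative sizes free to be treated as plain errors, which is why the overflow checks above test ``is not None`` rather than ``>= 0``.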