Commit

Unify buffer_size, following the convention suggested in pytorch/data…
alexanderbattig committed Mar 11, 2023
1 parent ab148da commit c3c7f5e
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions torch/utils/data/datapipes/iter/combining.py
@@ -70,7 +70,7 @@ class ForkerIterDataPipe(IterDataPipe):
num_instances: number of instances of the datapipe to create
buffer_size: this restricts how far ahead the leading child DataPipe
can read relative to the slowest child DataPipe.
- Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+ Defaults to ``10000``. Use ``None`` for the unlimited buffer.
copy: copy strategy to use for items yielded by each branch. Supported
options are ``None`` for no copying, ``"shallow"`` for shallow object
copies, and ``"deep"`` for deep object copies. Defaults to ``None``.
@@ -94,7 +94,7 @@ def __new__(
cls,
datapipe: IterDataPipe,
num_instances: int,
- buffer_size: int = 1000,
+ buffer_size: int = 10000,
copy: Optional[Literal["shallow", "deep"]] = None
):
if num_instances < 1:
@@ -143,7 +143,7 @@ def __init__(
self,
datapipe: IterDataPipe,
num_instances: int,
- buffer_size: int = 1000,
+ buffer_size: int = 10000,
copy: Optional[Literal["shallow", "deep"]] = None
):
self.main_datapipe = datapipe
@@ -152,11 +152,7 @@ def __init__(
self.buffer: Deque = deque()
self.buffer_size = buffer_size
if self.buffer_size < 0:
- warnings.warn(
-     "Unlimited buffer size is set for `fork`, "
-     "please be aware of OOM at random places",
-     UserWarning
- )
+ raise ValueError("Buffer size needs to be positive or None")
if copy is None:
self.copy_fn = _no_op
elif copy == "shallow":
@@ -206,7 +202,7 @@ def get_next_element_by_instance(self, instance_id: int):
if self.slowest_ptr < new_min:
self.slowest_ptr = new_min
self.buffer.popleft()
- if self.buffer_size >= 0 and self.leading_ptr > self.buffer_size + self.slowest_ptr:
+ if self.buffer_size is not None and self.leading_ptr > self.buffer_size + self.slowest_ptr:
raise BufferError("ForkerIterDataPipe buffer overflow," +
f"buffer size {self.buffer_size} is insufficient.")

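For context, a minimal sketch of how `fork` reads from user code under the convention this commit documents. `IterableWrapper` and the concrete sizes here are illustrative assumptions; only the ``10000`` default, the ``None``-for-unlimited convention, and the overflow ``BufferError`` follow the diff above.

from torch.utils.data.datapipes.iter import IterableWrapper

source = IterableWrapper(range(100_000))

# The default buffer_size is now 10000; buffer_size=None requests an
# unlimited buffer, and the old -1 sentinel is rejected with ValueError.
fast, slow = source.fork(num_instances=2, buffer_size=10_000)

try:
    # Draining one child while the other consumes nothing pushes the
    # leading pointer more than buffer_size ahead of the slowest child.
    list(fast)
except BufferError as err:
    print(err)  # reports that buffer size 10000 is insufficient
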
@@ -360,7 +356,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
DataPipes while waiting for their values to be yielded.
- Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
+ Defaults to ``10000``. Use ``None`` for the unlimited buffer.
Examples:
>>> # xdoctest: +REQUIRES(module:torchdata)
@@ -383,7 +379,7 @@ class DemultiplexerIterDataPipe(IterDataPipe):
[1, 3]
"""
def __new__(cls, datapipe: IterDataPipe, num_instances: int,
- classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
+ classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 10000):
if num_instances < 1:
raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")

@@ -410,11 +406,7 @@ def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int,
self.num_instances = num_instances
self.buffer_size = buffer_size
if self.buffer_size < 0:
- warnings.warn(
-     "Unlimited buffer size is set for `demux`, "
-     "please be aware of OOM at random places",
-     UserWarning
- )
+ raise ValueError("Buffer size needs to be positive or None")
self.current_buffer_usage = 0
self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
self.classifier_fn = classifier_fn
@@ -442,7 +434,7 @@ def _find_next(self, instance_id: int) -> T_co:
return value
self.child_buffers[classification].append(value)
self.current_buffer_usage += 1
- if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
+ if self.buffer_size is not None and self.current_buffer_usage > self.buffer_size:
raise BufferError(
f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")

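Similarly, a hedged sketch of `demux` under the same convention. The even/odd classifier and the input range are made up for illustration, while the `buffer_size` semantics mirror the diff:

from torch.utils.data.datapipes.iter import IterableWrapper

source = IterableWrapper(range(20))

# Route even numbers to the first child and odd numbers to the second.
# buffer_size=10000 is the new default; None would mean an unlimited buffer.
evens, odds = source.demux(
    num_instances=2,
    classifier_fn=lambda x: x % 2,
    buffer_size=10_000,
)

print(list(evens))  # [0, 2, 4, ..., 18]
print(list(odds))   # [1, 3, 5, ..., 19]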