From 1dacbad67161bd13807aa1b24f142fe2185c9cbf Mon Sep 17 00:00:00 2001 From: hpp Date: Mon, 5 Aug 2024 11:31:05 +0800 Subject: [PATCH] fix MPRS request index cycle --- torchdata/dataloader2/communication/iter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 73a54d1c9..95cc8ca20 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -431,17 +431,16 @@ def __iter__(self): disabled_pipe[res_idx] = True cnt_disabled_pipes += 1 disabled = True - req_idx = next(req_idx_cycle) else: # Only request if buffer is empty and has not reached the limit if len(self.res_buffers[res_idx]) == 0 and ( self._limit is None or self._request_cnt < self._limit ): self.datapipes[req_idx].protocol.request_next() - req_idx = next(req_idx_cycle) self._request_cnt += 1 total_req_cnt += 1 total_res_cnt += 1 + req_idx = next(req_idx_cycle) res_idx = next(res_idx_cycle) if not disabled: yield response.value