Fix end of epoch StatefulDataLoader restart #1439
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0bdd8c2 with merge base fe6b405.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This does not solve the problem, as it just restarts the dataloader and produces the same batches again.
If it's any help, while creating the issue I noticed that after loading the state dict, the resulting state dict in the dataloader is different from the one that was loaded - for example, "samples_yielded" is set to 0 when the loaded one had 100 (see the prints in #1437), and possibly more differences - I haven't checked. Looking at the code in this commit, it seems that samples_yielded is being set manually - maybe that is the root of the problem?
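A quick way to see the mismatch described above is to print the loaded state dict next to what the dataloader reports after loading it. This is a minimal sketch of that check, assuming a simple map-style dataset; the sizes and the break point are illustrative, not taken from #1437:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Consume part of an epoch, save the state, load it into a fresh dataloader,
# then compare the dict that was loaded with the dict the new loader reports.
dl1 = StatefulDataLoader(list(range(200)), batch_size=1, num_workers=0)
for i, _ in enumerate(dl1):
    if i == 99:  # stop after 100 samples
        break
saved = dl1.state_dict()

dl2 = StatefulDataLoader(list(range(200)), batch_size=1, num_workers=0)
dl2.load_state_dict(saved)
reloaded = dl2.state_dict()

# If the two prints disagree (e.g. a samples-yielded counter reset to 0),
# the loaded state is not being preserved faithfully.
print("saved   :", saved)
print("reloaded:", reloaded)
```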
Force-pushed from 8136e63 to a074b50. Commits: update stateful_dataloader, run precommit, local changes, update test to test the order of batches, update test, update tests, revert changes in SDL, revert changes in SDL, update tests, run precommit.
@andrewkho
Force-pushed from 4de1bb4 to 6d49b4f. Commits: update state dict if the iterator has finished, add comment about why we're updating state dict, run precommit.
TLDR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload because _iterator_finished is True. The fix is to update the generator in the state_dict after each iteration so the latest RNG state is cached, ensuring the RNG resumes correctly even if next_yielded is reset to 0.
For future ref: in torch.utils.data RandomSampler, we are changing the state of the generator even if …
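To make the generator point concrete, here is a rough sketch of the idea (not the torchdata implementation; the class, the key names, and the resume logic are assumptions): a sampler that stores its torch.Generator state in state_dict so the same permutation can be reproduced after a reload, even when the yielded counter is reset.

```python
import torch


class ShuffleSampler:
    """Illustrative stateful sampler: draws a permutation with a torch.Generator
    and caches the generator state so a reload reproduces the same order."""

    def __init__(self, data_len, seed=0):
        self.data_len = data_len
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.epoch_rng_state = self.generator.get_state()
        self.yielded = 0

    def __iter__(self):
        # Remember the RNG state used for this epoch's permutation, so a
        # state_dict taken mid-epoch can reproduce the identical permutation.
        self.epoch_rng_state = self.generator.get_state()
        perm = torch.randperm(self.data_len, generator=self.generator)
        for idx in perm[self.yielded:].tolist():
            yield idx
            self.yielded += 1
        self.yielded = 0  # epoch finished; the next epoch starts fresh

    def state_dict(self):
        return {"epoch_rng_state": self.epoch_rng_state, "yielded": self.yielded}

    def load_state_dict(self, sd):
        # Restoring the pre-permutation state means the next __iter__ redraws
        # the same permutation and skips the already-yielded prefix.
        self.generator.set_state(sd["epoch_rng_state"])
        self.yielded = sd["yielded"]
```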
Force-pushed from 0a90c04 to 39995a3. Commits: reverse changes to sdl.py, generator to iterator, run precommit, update generator usage.
Hi, I tried to be clever and implement a temporary workaround myself, but no luck :) Code:
Output:
@gailweiss can you try these examples using the code on this branch? I guess the code in this branch should be working fine (except for the random generator thing, which is taken care of in #1441).
@ramanishsingh sure thing! It seems better, but not all fixed, and sadly it still breaks if the dataloaders are initialized with random generators. New output for the same code:
@gailweiss
However, if you want the continuation behavior, you'd need to create an iterator from the dataloader.

```python
from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl(num_workers=0):
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True, num_workers=num_workers)


num_workers = 0
dataloader = get_dl(num_workers)
res_before_interruption = []
dl_iter = iter(dataloader)
for i, b in enumerate(dl_iter):
    res_before_interruption.append(b)
    if i == 2:
        print("Interrupting")
        break
res_after_interruption = []
for i, b in enumerate(dl_iter):
    res_after_interruption.append(b)
print("len(res_before_interruption)=", len(res_before_interruption), "; len(res_after_interruption)=", len(res_after_interruption))
print("res_before_interruption", res_before_interruption)
print("res_after_interruption", res_after_interruption)

print("===", "doing exp again, this time without interruption", "====")
dataloader = get_dl(num_workers)
res_wo_interruption = []
for i, b in enumerate(dataloader):
    res_wo_interruption.append(b)
print("res_wo_interruption", res_wo_interruption)
```

Output:

I hope this helps.
I have assumed orig_dl conveys the intended behaviour - if you look at the prints from resume_after_partial, you will see that orig_dl yields 10 batches from this state. One way or the other the two have to align though - I am not manipulating orig_dl in any way between the state transfer and the test that it is equivalent to loaded_dl.
Hi @gailweiss

```python
# get_dl, same_continuation, and n_samples are defined in the reproduction
# script from #1437.
def resume_after_partial():
    dl1, dl2 = get_dl(), get_dl()
    a = []
    for i, b in enumerate(dl1):
        a.append(b.item())
        if i > 3:
            break
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    return same_continuation(dl1, dl2, n_samples)
```

If so, this test should return that the continuations are not the same. This behavior is in line with how the normal DataLoader behaves.

Behavior of DataLoader

```python
from torch.utils.data import DataLoader

d = list(range(5))
dl = DataLoader(d, batch_size=1)
for i, b in enumerate(dl):
    print(i, b)
    if i == 2:
        break
print("breaking and restarting")
for i, b in enumerate(dl):
    print(i, b)
```

Output
If you want statefulness without loading state dict, you can make an iterator (like I specified in my previous comment and example) and iterate over it. Other than that, I believe this PR solves #1437. I have added tests for all those cases. Thanks!
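For reference, the save-and-resume pattern being tested looks roughly like this. A minimal sketch only, assuming the mid-epoch resume behavior this PR targets; the dataset, batch size, and break point are arbitrary:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

dl1 = StatefulDataLoader(list(range(10)), batch_size=1, num_workers=0)

seen = []
for i, b in enumerate(dl1):
    seen.append(b.item())
    if i == 4:  # simulate an interruption after 5 batches
        state = dl1.state_dict()
        break

# A fresh dataloader that loads the saved state should resume the interrupted
# epoch and yield only the remaining batches, not restart from scratch.
dl2 = StatefulDataLoader(list(range(10)), batch_size=1, num_workers=0)
dl2.load_state_dict(state)
resumed = [b.item() for b in dl2]
print(len(seen), len(resumed))  # expected: 5 and 5
```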
Hi, I think we're talking about the same test, but to me this is unexpected behaviour. In particular, to me the line `dl2.load_state_dict(dl1.state_dict())` (which is equivalent to the lines `sd = dl1.state_dict()` and `dl2.load_state_dict(sd)` in this test, as there is no intervention between them, not even a loop break/exit) should always lead to a dl2 that is identical to dl1. To have a case where this is not true - regardless of what was done with dl1 beforehand - seems to me counter to the goal of a state. Obviously I do not understand the implementation behind this, but given that the tests resume_from_end and resume_between are passing now, I imagine the dataloaders do have a mechanism to differentiate whether the state was taken from inside or outside a loop, so hopefully this should be possible to manage?
If it helps, I think giving: 1) the iterator access to the dataloader, 2) wrapping the iterator in a yield loop, 3) listening in that loop for a GeneratorExit, and finally 4) updating the dataloader as necessary on loop entry/exit will allow maintaining this inside/outside loop state, which should in turn allow successful state transfer to other dataloaders. Example on a dummy iterator:
with these classes, the code:
yields:
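As a rough illustration of the loop-entry/loop-exit idea above (an editorial sketch, not the original snippet and not torchdata's implementation; the class names, the in_loop flag, and the position counter are assumptions):

```python
class DummyDataLoader:
    """Stands in for a dataloader that tracks whether a consumer is currently
    iterating over it and where the current epoch was left off."""

    def __init__(self, data):
        self.data = data
        self.in_loop = False   # updated by the wrapping iterator
        self.position = 0      # resume point within the current epoch

    def __iter__(self):
        return _LoopAwareIterator(self)


class _LoopAwareIterator:
    """Wraps iteration in a generator so GeneratorExit (delivered when the
    abandoned generator is closed, e.g. after the consumer breaks out of its
    for-loop) can update the dataloader on loop exit."""

    def __init__(self, loader):
        self.loader = loader
        self._gen = self._run()

    def _run(self):
        self.loader.in_loop = True  # loop entry (runs on the first next())
        try:
            while self.loader.position < len(self.loader.data):
                item = self.loader.data[self.loader.position]
                self.loader.position += 1
                yield item
        except GeneratorExit:
            # Consumer abandoned the loop: keep position so transferred state
            # reflects an interrupted epoch.
            raise
        else:
            self.loader.position = 0  # epoch finished cleanly
        finally:
            self.loader.in_loop = False  # loop exit either way

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._gen)


loader = DummyDataLoader(list(range(5)))
for x in loader:
    if x == 2:
        break
# On CPython the abandoned generator is closed promptly, so the loader now
# records the interruption: position == 3, in_loop == False.
print(loader.position, loader.in_loop)
```

A state_dict built from position (plus whether the epoch completed) would then let a second dataloader distinguish a mid-loop snapshot from an end-of-epoch one, which is the distinction discussed above.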
Hi @gailweiss
Hi, sure, I understand how this may make the work bigger, and that you may want to turn it into a different PR. I personally do hold firm that the line `dl2.load_state_dict(dl1.state_dict())` should always lead to a dl2 identical to dl1. Because all these tests relate to the transfer of state from dl1 to dl2, …
Add tests to reproduce and fix #1437