Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix end of epoch StatefulDataLoader restart #1439

Open
wants to merge 24 commits into
base: main
Choose a base branch
from

Conversation

ramanishsingh
Copy link
Contributor

@ramanishsingh ramanishsingh commented Feb 3, 2025

Add tests to reproduce and fix #1437

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 3, 2025
@ramanishsingh ramanishsingh marked this pull request as draft February 3, 2025 23:36
Copy link

pytorch-bot bot commented Feb 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1439

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0bdd8c2 with merge base fe6b405 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ramanishsingh
Copy link
Contributor Author

This does not solve the problem as it just restarts the dataloader and produces the same batches again.

@gailweiss
Copy link

if its any help, while creating the issue i noticed that after loading the state dict, the resulting state dict in the dataloader is different from the one that was loaded - for example, by setting "samples_yielded" to 0 when the loaded one had 100 (see the prints in #1437 ), (and possibly more differences - I haven't checked). looking at the code in this commit, it seems that samples_yielded is being set manually - maybe that is the root of the problem?

update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 8136e63 to a074b50 Compare February 4, 2025 22:48
@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 06:37
@ramanishsingh ramanishsingh marked this pull request as draft February 5, 2025 14:02
@ramanishsingh ramanishsingh marked this pull request as ready for review February 5, 2025 22:56
@ramanishsingh
Copy link
Contributor Author

ramanishsingh commented Feb 5, 2025

@andrewkho
Thanks.
I took your implementation of BatchSamplerIterator from here.
I find that during the loading of the state dict, if the _StatefulRandomSamplerIterator is at its end, its self.next_yielded value is becoming None due to iter re-init from somewhere.
To tackle that, I am artificially making it 0 by checking if we are at the end of an epoch and exhausting the iterator (Line 534 stateful_dataloader.py) .
I think it is less brittle than checking the length of the sampler and skipping one whole epoch. Please lmk your thoughts.

update state dict if the iterator has finished

add comment about why were updating state dict

run precommit
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 4de1bb4 to 6d49b4f Compare February 7, 2025 07:05
@ramanishsingh
Copy link
Contributor Author

TLDR: After refactoring BatchSampler, the same batch sequence is repeated in the epoch following a reload due to _iterator_finished being True. Update the generator in the state_dict after each iteration to cache the latest state, ensuring RNG resumes correctly even if next_yielded is reset to 0.

Problem:
After breaking the BatchSampler into BatchSampler and _BatchSamplerIterator, we encountered an issue where the same sequence of batches is produced in the epoch immediately following a reload, mirroring the last epoch before saving the state_dict.

Root Cause:
This issue arises because the dl state_dict is saved after the epoch completes, resulting in _iterator_finished being set to True. To illustrate, consider the epoch after reloading as epoch 3. In the state_dict of the RandomSampler (a subset of the dl state_dict), key items include self.next_yielded and the state of the generator. When a StatefulDataLoader (SDL) is instantiated with num_workers = 0 and batches are retrieved, the iter method in SDL is invoked. This method utilizes next_iter_state (or the loaded_state_dict) to obtain an iterator. During this process, the generator, sampler_iter, etc., are reloaded. However, since _iterator_finished is True, the _StatefulSingleProcessDataLoaderIter that was generated is discarded, and a new one is created with state_dict=None. Consequently, we lose the RandomSampler state information because next_yielded is reset to 0, and the generator state remains at the start of epoch 2.

Proposed Solution:
While there may be more efficient solutions, one potential approach (that I have implemented) is to update the generator in the state_dict upon completing an iteration. By doing so, we cache the latest generator state, allowing us to resume RNG production from the correct point even when the RandomSampler is reset with next_yielded = 0.

@ramanishsingh
Copy link
Contributor Author

ramanishsingh commented Feb 10, 2025

For future ref:
https://github.com/pytorch/pytorch/blob/652880e8403b58ca44d482f200a8991b8b326e88/torch/utils/data/sampler.py#L190

In torch.utils.data RandomSampler, we are changing the state of the generator even if self.num_samples % n==0 and we dont even use any samples from that permutation. A more efficient (so that we don't generate a randperm if we dont need one) and simpler solution would be to add a check self.num_samples % n>0 and then generate a random permutation.

reverse changes to sdl.py

generator to iterator

run precommit

update generator usage
@ramanishsingh ramanishsingh force-pushed the fix_EndOfEpoch_sdl_restart branch from 0a90c04 to 39995a3 Compare February 11, 2025 06:09
@gailweiss
Copy link

gailweiss commented Feb 11, 2025

Hi, I tried to be clever and implement a temporary workaround myself, but no luck :)
In the process however, I wrote a bunch of tests that I think reflect how the stateful dataloader should behave, it currently fails several. Maybe they will be helpful? They also surface the fact that once the stateful_dataloaders do get randomly initialised (see #1440 ), even if by a user just making them with different generators at first, even more things break - so it seems like loading the state dict doesnt completely wipe the dataloader's state...?

Code:

import torch
from copy import deepcopy


# test support functions: printing with notes, comparing dictionaries, comparing epochs,
# getting dataloaders, running tests, etc

class PrefPrint:
    def __init__(self):
        self.pref = ""
    def print(self, *a, **kw):
        print(self.pref, *a, **kw)
    
printer = PrefPrint()

def pprint(*a, **kw):
    printer.print(*a, **kw)

def same_continuation(orig, loaded, el, expect_false=False):
    l1 = [b.item() for b in orig]
    l2 = [b.item() for b in loaded]
    # print(l1, l2)
    if len(l1) != el:
        pprint(f"orig dl's epoch wrong length: {l1} (expected {el})")
        return False
    if len(l2) != el:
        pprint(f"loaded dl's epoch wrong length: {l2} (expected {el}) (orig dl's epoch was correct length)")
        return False
    res = (l1 == l2)
    if (not res) and (not expect_false):
        pprint("orig vs loaded dl epochs:", l1, l2)
    return res

def equal_state_dicts(d1, d2):
    def _comp_dicts(d1, d2, pref=">>\t"):
        if sorted(list(d1.keys())) != sorted(list(d2.keys())):
            return f"{pref} diff keys: {list(sd1.keys())}, {list(sd2.keys())}"
        for k in d1.keys():
            if isinstance(d1[k], dict):
                return _comp_dicts(d1[k], d2[k], pref=f"{pref}in {k}: ")
            elif isinstance(d1[k], torch.Tensor):
                if False in (d1[k] == d2[k]):
                    return f"{pref} diff on {k}: {d1[k].tolist()} vs {d2[k].tolist()}"
            elif d1[k] != d2[k]:
                return f"{pref} diff on {k}: {d1[k]} vs {d2[k]}"
        return ""
    res = _comp_dicts(d1, d2)
    if res:
        pprint(res)
        return False
    return True

def get_dl(from_seed=None, gen=None):
    assert None in [from_seed, gen]
    d = list(range(n_samples))
    if None is not from_seed:
        gen = torch.Generator()
        gen.manual_seed(from_seed)
    return DataLoader(d, generator=gen, batch_size=1, shuffle=True)

def run_test(f, *a, **kw):
    printer.pref = f"in {f.__name__}:"
    res = f(*a, **kw)
    printer.pref = ""
    return res

def test(f):
    tests.append(f)
    return f

tests = []
n_samples = 10 # length all the dataloaders will be in the tests

# tests themselves


# 1. 2 different inits create different shuffles
@test
def diffshuff_as_standard():
    # n = 10 -> 10! shuffles -> highly unlikely to accidentally get same epoch
    return not same_continuation(get_dl(), get_dl(), n_samples, expect_false=True)

# 2. 2 different inits from the same generator create the same shuffles
@test
def sameshuff_when_asked():
    if not same_continuation(get_dl(from_seed=1), get_dl(from_seed=1), n_samples):
        pprint("mismatch on seed 1")
        return False
    gen1, gen2 = torch.Generator(), torch.Generator()
    gen1.seed()
    gen2.set_state(gen1.get_state())
    if not same_continuation(get_dl(gen=gen1), get_dl(gen=gen2), n_samples):
        pprint("mismatch on random gen")
        return False
    return True

# 3. getting state dict and loading it after own state has changed recovers 
# previous state (ie state dict properly detached once taken)
@test
def go_back():
    dl1 = get_dl()
    sd = dl1.state_dict()
    a1 = [b.item() for b in dl1]
    dl1.load_state_dict(sd)
    a2 = [b.item() for b in dl1]
    if not a1 == a2:
        pprint("doesnt walk back steps")
        return False
    return True

# 4. loading a state dict taken from the middle of an epoch continues that epoch
@test
def resume_from_partial():
    dl1, dl2 = get_dl(), get_dl()
    a1 = []
    p = 3
    for i, b in enumerate(dl1):
        if i == p:
            sd = dl1.state_dict()
        if i > p:
            a1.append(b.item())
    dl2.load_state_dict(sd)
    a2 = [b.item() for b in dl2]
    res = a1 == a2
    if not res:
        pprint("diff continuation from partial sd")
        pprint(a1, a2)
    if not same_continuation(dl1, dl2, n_samples):
        pprint("next epoch after partial sd not aligned")
        res = False
    return res

# 5. loading a state dict taken after a dataloader has left a partial epoch resumes (like that dataloader) from that dataloader's next full epoch
@test
def resume_after_partial():
    dl1, dl2 = get_dl(), get_dl()
    a = []
    for i, b in enumerate(dl1):
        a.append(b.item())
        if i > 3:
            break
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    return same_continuation(dl1, dl2, n_samples)


# 6. loading a state_dict taken after a full epoch has completed continues smoothly from that dataloader's next epoch 
# this highlights the main issue raised in https://github.com/pytorch/data/issues/1437
@test
def resume_between():
    g1, g2 = 2, 2
    dl1, dl2 = get_dl(), get_dl()
    _ = [[b.item() for b in dl1] for _ in range(g1)]
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    for i in range(g2):
        if not same_continuation(dl1, dl2, n_samples):  # should be full epochs here
            pprint(f"epoch {i} after load is broken")
            return False
    return True


# 7. loading a state_dict taken partway through an epoch, specifically at the last batch (but before leaving the loop) resumes at same point (i.e., inside epoch, with 0 batches left), and then continues to the same next epoch as the original dataloader
@test
def resume_end():
    dl1, dl2 = get_dl(), get_dl()
    l1 = []
    for i, b in enumerate(dl1):
        l1.append(b.item())
        sd = dl1.state_dict()
        if i == n_samples - 1:
            break  # dl1 not finished
    pprint("dl1 first epoch is:", l1)

    dl2.load_state_dict(sd)
    l = [b.item() for b in dl2]
    if len(l) > 0:
        print("resuming finishing state does not lead to empty epoch")
        return False
    if not same_continuation(dl1, dl2, n_samples):
        pprint("resuming finishing state does not move (after finishing) to same next epoch")
        pprint("is this an off-by-one? dl1 next next epoch would be:", [b.item() for b in dl1])
        return False
    return True

# dataloader variants
from torchdata.stateful_dataloader import StatefulDataLoader as _DataLoader
import torch
from copy import deepcopy

class LoudDataLoader(_DataLoader):
    def load_state_dict(self, state_dict):
        sd1 = deepcopy(state_dict)
        super(LoudDataLoader, self).load_state_dict(sd1)
        equal_state_dicts(state_dict, self.state_dict())  # will print a difference if it finds one

# DataLoader to fix https://github.com/pytorch/data/issues/1440
class LoudDataLoader1440(LoudDataLoader):
    def __init__(self, *a, **kw):
        if None is kw.get("generator", None):
            kw["generator"] = torch.Generator()
            kw["generator"].seed()  # for some reason important for getting it going
        super().__init__(*a, **kw)



dlclasses = {"base":_DataLoader, "loud":LoudDataLoader, "loud1440":LoudDataLoader1440}
results = {}
for n, DLC in dlclasses.items():
    DataLoader = DLC
    print(f"\n================\nrunning {n} stateful dataloader tests")
    results[n] = {f.__name__ : run_test(f) for f in tests}

print("\n\n=======")
for n, r in results.items():
    print(f"\n=======\n {n} stateful dataloader test results:")
    names = [f.__name__ for f in tests]
    print("\n".join(f"{n}: \t\t[{r[n]}]" for n in names))

output:

================
running base stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]
in resume_end: orig vs loaded dl epochs: [2, 3, 1, 8, 9, 0, 6, 7, 4, 5] [2, 8, 1, 5, 6, 9, 3, 7, 0, 4]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [2, 9, 8, 3, 6, 7, 1, 0, 4, 5]

================
running loud stateful dataloader tests
in go_back: >>	in _index_sampler_state:  diff on samples_yielded: 0 vs 10
in resume_from_partial: >>	in _index_sampler_state:  diff on samples_yielded: 4 vs 0
in resume_after_partial: >>	in _index_sampler_state:  diff on samples_yielded: 5 vs 0
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_between: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]
in resume_end: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_end: orig vs loaded dl epochs: [2, 3, 1, 8, 9, 0, 6, 7, 4, 5] [2, 8, 1, 5, 6, 9, 3, 7, 0, 4]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [2, 9, 8, 3, 6, 7, 1, 0, 4, 5]

================
running loud1440 stateful dataloader tests
in go_back: >>	in _index_sampler_state:  diff on samples_yielded: 0 vs 10
in go_back: doesnt walk back steps
in resume_from_partial: >>	in _index_sampler_state:  diff on samples_yielded: 4 vs 0
in resume_from_partial: diff continuation from partial sd
in resume_from_partial: [8, 9, 4, 1, 2, 3] [6, 8, 4, 2, 1, 7]
in resume_from_partial: orig vs loaded dl epochs: [2, 9, 6, 1, 5, 8, 3, 0, 4, 7] [8, 9, 0, 6, 3, 4, 7, 5, 1, 2]
in resume_from_partial: next epoch after partial sd not aligned
in resume_after_partial: >>	in _index_sampler_state:  diff on samples_yielded: 5 vs 0
in resume_after_partial: loaded dl's epoch wrong length: [4, 3, 2, 5, 1] (expected 10) (orig dl's epoch was correct length)
in resume_between: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_between: loaded dl's epoch wrong length: [] (expected 10) (orig dl's epoch was correct length)
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [4, 6, 7, 9, 2, 3, 0, 5, 1, 8]
in resume_end: >>	in _index_sampler_state:  diff on samples_yielded: 10 vs 0
in resume_end: orig vs loaded dl epochs: [5, 3, 1, 6, 7, 9, 2, 8, 4, 0] [7, 8, 3, 9, 4, 2, 1, 0, 5, 6]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [0, 5, 6, 3, 4, 8, 2, 1, 7, 9]


=======

=======
 base stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

=======
 loud stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

=======
 loud1440 stateful dataloader test results:
diffshuff_as_standard: 		[True]
sameshuff_when_asked: 		[True]
go_back: 		[False]
resume_from_partial: 		[False]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

@ramanishsingh
Copy link
Contributor Author

@gailweiss can you try these examples using the code on this branch? I guess the code in this branch should be working fine (except for the random generator thing, which is taken care of in #1441).

@gailweiss
Copy link

@ramanishsingh sure thing!

It seems better, but not all fixed. And sadly still breaks if the dataloaders are initiated with random generators. New output for same code:

================
running base stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]

================
running loud stateful dataloader tests
in resume_after_partial: loaded dl's epoch wrong length: [8, 9, 3, 7, 4] (expected 10) (orig dl's epoch was correct length)
in resume_end: dl1 first epoch is: [5, 6, 1, 2, 0, 8, 9, 3, 7, 4]

================
running loud1440 stateful dataloader tests
in go_back: doesnt walk back steps
in resume_from_partial: diff continuation from partial sd
in resume_from_partial: [5, 2, 6, 7, 9, 1] [9, 2, 3, 7, 8, 6]
in resume_from_partial: orig vs loaded dl epochs: [3, 5, 2, 6, 1, 4, 0, 7, 9, 8] [2, 9, 5, 6, 4, 7, 0, 1, 8, 3]
in resume_from_partial: next epoch after partial sd not aligned
in resume_after_partial: loaded dl's epoch wrong length: [8, 5, 7, 0, 9] (expected 10) (orig dl's epoch was correct length)
in resume_between: orig vs loaded dl epochs: [8, 2, 9, 3, 0, 5, 4, 7, 1, 6] [0, 5, 6, 1, 3, 4, 8, 7, 2, 9]
in resume_between: epoch 0 after load is broken
in resume_end: dl1 first epoch is: [7, 3, 8, 4, 0, 2, 6, 1, 9, 5]
in resume_end: orig vs loaded dl epochs: [6, 9, 1, 0, 3, 4, 7, 5, 8, 2] [7, 2, 4, 9, 5, 1, 6, 3, 0, 8]
in resume_end: resuming finishing state does not move (after finishing) to same next epoch
in resume_end: is this an off-by-one? dl1 next next epoch would be: [3, 6, 7, 8, 2, 1, 9, 5, 0, 4]


=======

=======
 base stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[True]
resume_end: 		[True]

=======
 loud stateful dataloader test results:
diffshuff_as_standard: 		[False]
sameshuff_when_asked: 		[True]
go_back: 		[True]
resume_from_partial: 		[True]
resume_after_partial: 		[False]
resume_between: 		[True]
resume_end: 		[True]

=======
 loud1440 stateful dataloader test results:
diffshuff_as_standard: 		[True]
sameshuff_when_asked: 		[True]
go_back: 		[False]
resume_from_partial: 		[False]
resume_after_partial: 		[False]
resume_between: 		[False]
resume_end: 		[False]

@ramanishsingh
Copy link
Contributor Author

@gailweiss
Focusing on the base_dataloder (the randomness issue is fixed in #1441 ) I think according to you it is failing in this test

def resume_after_partial():
    dl1, dl2 = get_dl(), get_dl()
    a = []
    for i, b in enumerate(dl1):
        a.append(b.item())
        if i > 3:
            break
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    return same_continuation(dl1, dl2, n_samples)

However, if you want the continuation behavior, you'd need to create an iterator from the dataloader.
Here's an example and output

from torchdata.stateful_dataloader import StatefulDataLoader


def get_dl(num_workers=0):
    d = list(range(10))
    return StatefulDataLoader(d, batch_size=1, shuffle=True, num_workers=num_workers)


num_workers = 0
dataloader = get_dl(num_workers)
res_before_interruption = []
dl_iter = iter(dataloader)
for i, b in enumerate(dl_iter):
    res_before_interruption.append(b)
    if i == 2:
        print("Interrupting")
        break
    

res_after_interruption = []
for i, b in enumerate(dl_iter):
    res_after_interruption.append(b)
print("len(res_before_interruption)=",len(res_before_interruption),";    len(res_after_interruption)=",len(res_after_interruption))

print("res_before_interruption", res_before_interruption)
print("res_after_interruption", res_after_interruption)

print("===", "doing exp again, this time without interruption", "====")

dataloader = get_dl(num_workers)
res_wo_interruption = []
for i, b in enumerate(dataloader):
    res_wo_interruption.append(b)
print("res_wo_interruption", res_wo_interruption)

Output:

Interrupting
len(res_before_interruption)= 3 ;    len(res_after_interruption)= 7
res_before_interruption [tensor([5]), tensor([6]), tensor([1])]
res_after_interruption [tensor([2]), tensor([0]), tensor([8]), tensor([9]), tensor([3]), tensor([7]), tensor([4])]
=== doing exp again, this time without interruption ====
res_wo_interruption [tensor([5]), tensor([6]), tensor([1]), tensor([2]), tensor([0]), tensor([8]), tensor([9]), tensor([3]), tensor([7]), tensor([4])]


I hope this helps.

@gailweiss
Copy link

I have assumed orig_dl conveys the intended behaviour - if you look at the prints from resume_after_partial, you will see that orig_dl yields 10 batches from this state. One way or the other the two have to align though - I am not manipulating orig_dl in any way between the state transfer and the test that it is equivalent to loaded_dl

@ramanishsingh
Copy link
Contributor Author

ramanishsingh commented Feb 14, 2025

Hi @gailweiss
I just want to confirm that we are talking about the same test.
Is this the test?

def resume_after_partial():
    dl1, dl2 = get_dl(), get_dl()
    a = []
    for i, b in enumerate(dl1):
        a.append(b.item())
        if i > 3:
            break
    sd = dl1.state_dict()
    dl2.load_state_dict(sd)
    return same_continuation(dl1, dl2, n_samples)

If so, this test should return that dl1 and dl2 are not same.
While the dl1.state_dict() method returns the latest state of the dataloader, it (the state) needs to be applied to the dataloader so that when the next time its iterator is called, it moves to the right pointer.
This is because, when we do for i, b in enumerate(dl1): it creates the __iter__ method which in turn calls the [get_iterator]
(https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/stateful_dataloader.py#L381) and it gets the iterator here which can use the state stored in next_iter_state.
Now even though dl1 has yielded three batches, its next_iter_state is None because we haven't updated it yet. It stays as None.
But for dl2 we do dl2.load_state_dict(sd) and thus its next_iter_state is populated and gets the information that it has yielded 3 batches/samples.
Thus when we call same_continuation(dl1, dl2, n_samples) it returns that these two dataloaders return different batches. Which should be the expected behavior. dl1's internal state will need to be updated if we are creating a new iterator using it.

This behavior is inline with how the normal Dataloader in torch.utils.data and normal lists behave.
I understand it is not exactly an apples to apples comparison as the normal Dataloader isn't stateful, but it's just how the Design of the StatefulDataloader is, you need to load the state_dict such the iterator gets the right state to start from.

Behavior of DataLoader

from torch.utils.data import DataLoader
d = list(range(5))
dl = DataLoader(d, batch_size=1)
for i, b in enumerate(dl):
    print(i, b)
    if i == 2:
      break
print("breaking and restarting")
for i, b in enumerate(dl):
    print(i, b)

Output

0 tensor([0])
1 tensor([1])
2 tensor([2])
breaking and restarting
0 tensor([0])
1 tensor([1])
2 tensor([2])
3 tensor([3])
4 tensor([4])

If you want statefulness without loading state dict, you can make an iterator (like I specified in my previous comment and example) and iterate over it.
Please let me know if this makes sense. :)

Other than that, I believe this PR solves #1437 . I have added tests for all those cases.

Thanks!

@gailweiss
Copy link

gailweiss commented Feb 14, 2025

Hi, I think we’re talking about the same test, but to me this is unexpected behaviour. In particular, to me the line:

dl2.load_state_dict(dl1.state_dict())

(which is equivalent to the lines

sd = dl1.state_dict()
dl2.load_state_dict(sd)

in this test, as there is no intervention between them, not even a loop break/exit)

should always lead to a dl2 that is identical to dl1. To have a case where this is not true - regardless of what was done with dl1 beforehand - seems to me counter to the goal of a state.

Obviously I do not understand the implementation behind this, but given the tests resume_from_end and resume_between are passing now, I imagine the dataloaders do have the mechanism to differentiate between whether the state was taken from within or without a loop, so hopefully this should be possible to manage?

@gailweiss
Copy link

If it helps, I think giving: 1) the iterator access to the dataloader, 2) wrapping the iterator in a yield loop, 3) listening in that loop for a GeneratorExit, and finally 4) updating the dataloader as necessary on loop entry/exit will allow maintaining this inside/outside loop state, which should in turn allow successful state transfer to other dataloaders. Example on a dummy iterator:

class MyIterator:
    def __init__(self, source):
        self.source = source
        self.i = -1
        self.n = len(source.d)
        
    def __next__(self):
        self.source.in_loop = True
        self.i += 1
        if self.i >= self.n:
            raise StopIteration
        return self.source.d[self.i]

    def close(self):
        self.source.in_loop = False
    

def wrapiterator(it):
    while True:
        try:
            yield next(it)
        except GeneratorExit:
            it.close()
            break
        except StopIteration:
            it.close()
            break

class MyRange:
    def __init__(self, n):
        self.d = list(range(1,n+1))
        self.in_loop = False
    def print_state(self, expected):
        print("loop state aligns with expected:", expected == self.in_loop)
    def __iter__(self):
        return wrapiterator(MyIterator(self))

with these classes, the code:


a = MyRange(3)
a.print_state(False)
for i in a:
    print(i)
    a.print_state(True)
    break
a.print_state(False)

for i in a:
    print(i)
    a.print_state(True)
a.print_state(False)

yields:

loop state aligns with expected: True
1
loop state aligns with expected: True
loop state aligns with expected: True
1
loop state aligns with expected: True
2
loop state aligns with expected: True
3
loop state aligns with expected: True
loop state aligns with expected: True

@ramanishsingh
Copy link
Contributor Author

ramanishsingh commented Feb 14, 2025

Hi @gailweiss
I understand your argument. The current behavior is just how the StatefulDataloader is designed and it is in parity with how the normal Dataloader works, and the expected behavior regarding the need to (re)load the state_dict might be different for different people. I will surely bring it up with the team, and if any changes are required, I suspect that it will be a slightly bigger refactor. Thus, it is out of scope of this current PR as this PR is aimed for #1437 , which I think it solves (please lmk if you see any issues with that). For your recent comments, after discussion with the team, I will start a new issue and tag you there. Please let me know if the current behavior is blocking you in any way (I hope the extra line of loading state dict should unblock you).
Thanks for your hard work!

@gailweiss
Copy link

Hi, sure, I understand how this may make the work bigger, and that you may want to turn it into a different PR. I personally do hold firm that the line dl2.load_state_dict(dl1.state_dict()) should always yield dl2 that continues identically to dl1 at that point. Looking at the message above, I think there has been a miscommunication at some point, so I want to be clear: for the purposes of this conversation, I don't really care what dl1 does one way or the other, and I am sure that it does follow the regular DataLoader as desired. My only goal here is that dl2 mimicks it correctly after receiving its state, and resume_after_partial surfaces a case where it does not.

Because all these tests relate to the transfer of state from dl1 to dl2, I suspect that whatever change will fix resume_after_partial (I suggest adding an in_loop flag to the state, and controlling this flag through the iterator, which can recognise both finished and broken loops by listening for GeneratorExit and StopIteration) will also be useful for fixing the original issue (end of epoch state transfer), and I think solving them all together will likely yield cleaner code than handling them independently. I note my suggestion only works if you assume no nesting of loops over the StatefulDataLoader! But my impression is that this assumption is already made for the StatefulDataLoader anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
4 participants