Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding state parser utility that can be used for retrieving and modifying worker states #1278

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions test/stateful_dataloader/test_state_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torch.testing._internal.common_utils import TestCase

from torch.utils.data import Dataset, IterableDataset
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader, StateParserUtil


class StatefulIterableDataset(IterableDataset, Stateful):
def __init__(self):
self.num_calls = 0

def __iter__(self):
return self

def __next__(self):
self.num_calls += 1
return self.num_calls

def load_state_dict(self, state_dict):
self.num_calls = state_dict["num_calls"]

def state_dict(self):
return {"num_calls": self.num_calls}


def identity(x):
return x


class TestIteratorDataset(TestCase):
def test_increasing_worker(self):
ds = StatefulIterableDataset()
dl = StatefulDataLoader(ds, num_workers=2, collate_fn=identity)
it = iter(dl)
next(it)
sd = dl.state_dict()
print(sd)
del dl

parser = StateParserUtil(sd)
worker_states = parser.fetch_dataset_state()
worker_states[2] = {"num_calls": 2}
worker_states[3] = {"num_calls": 3}
parser.set_dataset_state(worker_states)

# worker state doesn't equal num workers setting
with self.assertRaises(AssertionError):
parser.get_state_dict()
parser.set_num_workers(4)

# last worker yielded id is greater than num workers
parser.set_last_worker_yielded_id(10)
with self.assertRaises(AssertionError):
parser.get_state_dict()
parser.set_last_worker_yielded_id(0)

# load the modified state
new_sd = parser.get_state_dict()
print(new_sd)
dl = StatefulDataLoader(ds, num_workers=4, collate_fn=identity)
dl.load_state_dict(new_sd)
it = iter(dl)
values = []
for _ in range(4):
values.extend(next(it))
print(values)
self.assertEqual(values, [1, 3, 4, 2])


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion torchdata/stateful_dataloader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .state_parser import StateParserUtil
from .stateful import Stateful
from .stateful_dataloader import StatefulDataLoader

__all__ = ["Stateful", "StatefulDataLoader"]
__all__ = ["Stateful", "StatefulDataLoader", "StateParserUtil"]
74 changes: 74 additions & 0 deletions torchdata/stateful_dataloader/state_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Union

logger = logging.getLogger(__name__)


class StateParserUtil:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we keep these as functions? I'm not sure of the benefit of having a class here

"""
Utility class that can be used to modify state returned by the dataloader
"""

def __init__(self, state_dict: Dict[str, Any]):
self._state_dict = state_dict
self._is_multiprocess_state = "_snapshot" in self._state_dict

def fetch_dataset_state(self) -> Dict[int, Any]:
# Handle both cases of single process and multiprocess
if not self._is_multiprocess_state:
return self._state_dict["dataset_state"]
return {
state["worker_id"]: state["dataset_state"]
for _, state in self._state_dict["_snapshot"]["_worker_snapshots"].items()
}

def set_last_worker_yielded_id(self, last_worker_yielded: int) -> None:
# Ensure that this number is within the number of workers
if not self._is_multiprocess_state:
logger.warning("Cannot set last worker yielded id on a single process state dict")
return
self._state_dict["_snapshot"]["_last_yielded_worker_id"] = last_worker_yielded

def set_num_workers(self, num_workers: int) -> None:
if not self._is_multiprocess_state:
logger.warning("Cannot set num_workers on a single process state dict")
return
self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"] = num_workers

def set_dataset_state(self, dataset_state: Union[Dict[int, Any], Any]) -> None:
if not self._is_multiprocess_state:
self._state_dict["dataset_state"] = dataset_state
return

for id, state in dataset_state.items():
worker_states = self._state_dict["_snapshot"]["_worker_snapshots"]
worker_key = f"worker_{id}"
if worker_key in worker_states:
worker_states[worker_key]["dataset_state"] = state
else:
worker_states[worker_key] = {"worker_id": id, "dataset_state": state, "fetcher_state": None}

def get_state_dict(self) -> Dict[str, Any]:
# Perform validations
# a) num_workers should match worker_snapshots
# b) last yielded worker id should be within num_workers
if not self._is_multiprocess_state:
return self._state_dict

last_yielded_id = self._state_dict["_snapshot"]["_last_yielded_worker_id"]
num_workers = self._state_dict["_snapshot"]["_main_snapshot"]["_num_workers"]
worker_ids = self._state_dict["_snapshot"]["_worker_snapshots"].keys()

assert (
len(worker_ids) == num_workers
), f"Number of worker states {len(worker_ids)} should be equal to num_workers setting {num_workers}"
assert (
len(set(worker_ids)) == num_workers
), f"Worker state for all from [0, {num_workers}) should be present. Instead found state for only {worker_ids} workers"
assert last_yielded_id < num_workers, "Last yielded id should be strictly within the number of workers"
return self._state_dict
Loading