Skip to content

Commit

Permalink
Replace "synchronized lists" with a list of objects in ArgInfo (pytor…
Browse files Browse the repository at this point in the history
…ch#2742)

Summary:


Torchrec rewriting logic got a bit hairy over the years, this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future.

This change: ArgInfo uses a "synchronized lists" pattern, having 4 attributes, each being a list, semantically representing different fields of a data structure (i.e. input_attrs[0], is_getitems[0], ... all relate to a single transformation on the input; all lists must have same number of elements). This diff refactors them into an actual list of a (new) `ArgInfoStep` class instances that encapsulate the related fields.

Internal

Diff stack navigation:
1. D69292525 and below - before refactoring
2. D69438143 - Refactor get_node_args and friends into a class 
3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep (**you are here**)
4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep
5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy

Differential Revision: D69461227
  • Loading branch information
che-sh authored and facebook-github-bot committed Feb 13, 2025
1 parent 4e39f34 commit 16d5d77
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 118 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_to_device, # noqa
_wait_for_batch, # noqa
ArgInfo, # noqa
ArgInfoStep, # noqa
DataLoadingThread, # noqa
In, # noqa
Out, # noqa
Expand Down
89 changes: 50 additions & 39 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.utils import (
ArgInfoStep,
DataLoadingThread,
get_h2d_func,
PipelinedForward,
Expand Down Expand Up @@ -1024,22 +1025,24 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
self.assertEqual(len(ebc.forward._args[0].steps), 2)
[step1, step2] = ebc.forward._args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].postproc_modules[0],
pipelined_ebc.forward._args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1051,15 +1054,18 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
for i in range(len(pipeline._pipelined_postprocs)):
postproc_mod = pipeline._pipelined_postprocs[i]
self.assertEqual(len(postproc_mod._args), 1)
self.assertEqual(len(postproc_mod._args[0].steps), 2)
[step1, step2] = postproc_mod._args[0].steps

self.assertTrue(step2.input_attr in input_attr_names)

input_attr_name = postproc_mod._args[0].input_attrs[1]
self.assertTrue(input_attr_name in input_attr_names)
self.assertEqual(postproc_mod._args[0].input_attrs, ["", input_attr_name])
input_attr_names.remove(input_attr_name)
input_attr_names.remove(step2.input_attr)

self.assertEqual(postproc_mod._args[0].is_getitems, [False, False])
self.assertFalse(step1.is_getitem)
self.assertFalse(step2.is_getitem)
# no parent postproc module in FX graph
self.assertEqual(postproc_mod._args[0].postproc_modules, [None, None])
self.assertIsNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

# pyre-ignore
@unittest.skipIf(
Expand Down Expand Up @@ -1107,22 +1113,24 @@ def test_pipeline_postproc_recursive(self) -> None:
# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(ebc.forward._args[0].input_attrs, ["", 0])
self.assertEqual(ebc.forward._args[0].is_getitems, [False, True])
self.assertEqual(len(ebc.forward._args[0].postproc_modules), 2)
self.assertIsInstance(
ebc.forward._args[0].postproc_modules[0], PipelinedPostproc
)
self.assertEqual(ebc.forward._args[0].postproc_modules[1], None)
self.assertEqual(len(ebc.forward._args[0].steps), 2)
[step1, step2] = ebc.forward._args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
self.assertEqual(step2.input_attr, 0)
self.assertEqual(step2.is_getitem, True)
self.assertIsNotNone(step1.postproc_module)
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].postproc_modules[0],
pipelined_ebc.forward._args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].postproc_modules[0],
pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1141,33 +1149,36 @@ def test_pipeline_postproc_recursive(self) -> None:
if postproc_mod == pipelined_model.module.postproc_nonweighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idlist_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(len(args.steps), 2)
self.assertEqual(
args.postproc_modules[0],
parent_postproc_mod,
[step.input_attr for step in args.steps], ["", "idlist_features"]
)
self.assertEqual(args.postproc_modules[1], None)
self.assertEqual(
[step.is_getitem for step in args.steps], [False, False]
)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
elif postproc_mod == pipelined_model.module.postproc_weighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, ["", "idscore_features"])
self.assertEqual(args.is_getitems, [False, False])
self.assertEqual(len(args.postproc_modules), 2)
self.assertEqual(len(args.steps), 2)
self.assertEqual(
[step.input_attr for step in args.steps], ["", "idscore_features"]
)
self.assertEqual(
args.postproc_modules[0],
parent_postproc_mod,
[step.is_getitem for step in args.steps], [False, False]
)
self.assertEqual(args.postproc_modules[1], None)
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
elif postproc_mod == parent_postproc_mod:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(args.input_attrs, [""])
self.assertEqual(args.is_getitems, [False])
self.assertEqual(args.postproc_modules, [None])
self.assertEqual(len(args.steps), 1)
self.assertEqual(args.steps[0].input_attr, "")
self.assertFalse(args.steps[0].is_getitem)
self.assertIsNone(args.steps[0].postproc_module)

# pyre-ignore
@unittest.skipIf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_build_args_kwargs,
_rewrite_model,
ArgInfo,
ArgInfoStep,
NodeArgsHelper,
PipelinedForward,
PipelinedPostproc,
Expand Down Expand Up @@ -110,17 +111,17 @@ def test_rewrite_model(self) -> None:
self.assertEqual(
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `sparse`.
sharded_model.module.sparse.ebc.forward._args[0].postproc_modules[0],
sharded_model.module.sparse.ebc.forward._args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_module`.
sharded_model.module.postproc_module,
)
self.assertEqual(
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `sparse`.
sharded_model.module.sparse.weighted_ebc.forward._args[0].postproc_modules[
0
],
sharded_model.module.sparse.weighted_ebc.forward._args[0]
.steps[0]
.postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_module`.
sharded_model.module.postproc_module,
Expand Down Expand Up @@ -263,19 +264,18 @@ def test_restore_from_snapshot(self) -> None:
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name="id_list_features",
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
steps=[],
name="id_score_list_features",
),
],
Expand All @@ -286,19 +286,18 @@ def test_restore_from_snapshot(self) -> None:
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name=None,
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
steps=[],
name=None,
),
],
Expand All @@ -309,19 +308,18 @@ def test_restore_from_snapshot(self) -> None:
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
input_attrs=[
"",
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
is_getitems=[False],
postproc_modules=[None],
constants=[None],
name=None,
),
ArgInfo(
input_attrs=[],
is_getitems=[],
postproc_modules=[],
constants=[],
steps=[],
name="id_score_list_features",
),
],
Expand Down
Loading

0 comments on commit 16d5d77

Please sign in to comment.