Skip to content

Commit

Permalink
Refactor _build_args_kwargs into an instance method on CallArgs + Arg…
Browse files Browse the repository at this point in the history
…Info (#2742)

Summary:


TorchRec's rewriting logic has become a bit hairy over the years; this sequence of changes aims to refactor the rewrite logic to be less convoluted and more maintainable in the future.

This change: 
* almost all of the code in `_build_args_kwargs` deals with the fields of `ArgInfoStep`, and the remaining part handles looping over `ArgInfo.steps` - so this change just colocates the "behavior" (the `_build_args_kwargs` logic) with the data it belongs to.
* introduces helper functions/factory methods for various types of ArgInfoStep
* encapsulates the logic of handling a `List[ArgInfo]` into a `CallArgs` class (and changes it a bit: args and kwargs are now explicitly separated, rather than being distinguished by an empty/present `ArgInfo.name` field)

Internal

Diff stack navigation:
1. D69292525 and below - before refactoring
2. D69438143 - Refactor get_node_args and friends into a class 
3. D69461227 - refactor "joint lists" in ArgInfo into a list of ArgInfoStep
4. D69461226 - refactor `_build_args_kwargs` into instance methods on ArgInfo and ArgInfoStep (**you are here**)
5. D69461228 - split monolithic `ArgInfoStep` into a class hierarchy

Differential Revision: D69461226
  • Loading branch information
che-sh authored and facebook-github-bot committed Feb 13, 2025
1 parent aeb69a3 commit 212d6c1
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 200 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_wait_for_batch, # noqa
ArgInfo, # noqa
ArgInfoStep, # noqa
CallArgs, # noqa
DataLoadingThread, # noqa
In, # noqa
Out, # noqa
Expand Down
44 changes: 25 additions & 19 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,9 +1024,10 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(len(ebc.forward._args[0].steps), 2)
[step1, step2] = ebc.forward._args[0].steps
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
Expand All @@ -1036,13 +1037,13 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].steps[0].postproc_module,
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1053,9 +1054,10 @@ def test_pipeline_postproc_not_shared_with_arg_transform(self) -> None:
input_attr_names = {"idlist_features", "idscore_features"}
for i in range(len(pipeline._pipelined_postprocs)):
postproc_mod = pipeline._pipelined_postprocs[i]
self.assertEqual(len(postproc_mod._args), 1)
self.assertEqual(len(postproc_mod._args[0].steps), 2)
[step1, step2] = postproc_mod._args[0].steps
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
self.assertEqual(len(postproc_mod._args.args[0].steps), 2)
[step1, step2] = postproc_mod._args.args[0].steps

self.assertTrue(step2.input_attr in input_attr_names)

Expand Down Expand Up @@ -1112,9 +1114,10 @@ def test_pipeline_postproc_recursive(self) -> None:

# Check pipelined args
for ebc in [pipelined_ebc, pipelined_weighted_ebc]:
self.assertEqual(len(ebc.forward._args), 1)
self.assertEqual(len(ebc.forward._args[0].steps), 2)
[step1, step2] = ebc.forward._args[0].steps
self.assertEqual(len(ebc.forward._args.args), 1)
self.assertEqual(len(ebc.forward._args.kwargs), 0)
self.assertEqual(len(ebc.forward._args.args[0].steps), 2)
[step1, step2] = ebc.forward._args.args[0].steps

self.assertEqual(step1.input_attr, "")
self.assertEqual(step1.is_getitem, False)
Expand All @@ -1124,13 +1127,13 @@ def test_pipeline_postproc_recursive(self) -> None:
self.assertIsNone(step2.postproc_module)

self.assertEqual(
pipelined_ebc.forward._args[0].steps[0].postproc_module,
pipelined_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
pipelined_model.module.postproc_nonweighted,
)
self.assertEqual(
pipelined_weighted_ebc.forward._args[0].steps[0].postproc_module,
pipelined_weighted_ebc.forward._args.args[0].steps[0].postproc_module,
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
pipelined_model.module.postproc_weighted,
Expand All @@ -1147,8 +1150,9 @@ def test_pipeline_postproc_recursive(self) -> None:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_nonweighted`.
if postproc_mod == pipelined_model.module.postproc_nonweighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
[step.input_attr for step in args.steps], ["", "idlist_features"]
Expand All @@ -1161,8 +1165,9 @@ def test_pipeline_postproc_recursive(self) -> None:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `postproc_weighted`.
elif postproc_mod == pipelined_model.module.postproc_weighted:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 2)
self.assertEqual(
[step.input_attr for step in args.steps], ["", "idscore_features"]
Expand All @@ -1173,8 +1178,9 @@ def test_pipeline_postproc_recursive(self) -> None:
self.assertEqual(args.steps[0].postproc_module, parent_postproc_mod)
self.assertIsNone(args.steps[1].postproc_module)
elif postproc_mod == parent_postproc_mod:
self.assertEqual(len(postproc_mod._args), 1)
args = postproc_mod._args[0]
self.assertEqual(len(postproc_mod._args.args), 1)
self.assertEqual(len(postproc_mod._args.kwargs), 0)
args = postproc_mod._args.args[0]
self.assertEqual(len(args.steps), 1)
self.assertEqual(args.steps[0].input_attr, "")
self.assertFalse(args.steps[0].is_getitem)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
TrainPipelineSparseDistTestBase,
)
from torchrec.distributed.train_pipeline.utils import (
_build_args_kwargs,
_rewrite_model,
ArgInfo,
ArgInfoStep,
CallArgs,
NodeArgsHelper,
PipelinedForward,
PipelinedPostproc,
Expand Down Expand Up @@ -261,80 +261,80 @@ def test_restore_from_snapshot(self) -> None:
@parameterized.expand(
[
(
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
name="id_list_features",
),
ArgInfo(
steps=[],
name="id_score_list_features",
),
],
CallArgs(
args=[],
kwargs={
"id_list_features": ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
]
),
# Empty attrs to ignore any attr based logic.
"id_score_list_features": ArgInfo(
steps=[],
),
},
),
0,
["id_list_features", "id_score_list_features"],
),
(
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
name=None,
),
ArgInfo(
steps=[],
name=None,
),
],
CallArgs(
args=[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
),
ArgInfo(
steps=[],
),
],
kwargs={},
),
2,
[],
),
(
[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
name=None,
),
ArgInfo(
steps=[],
name="id_score_list_features",
),
],
CallArgs(
args=[
# Empty attrs to ignore any attr based logic.
ArgInfo(
steps=[
ArgInfoStep(
input_attr="",
is_getitem=False,
postproc_module=None,
constant=None,
)
],
)
],
kwargs={"id_score_list_features": ArgInfo(steps=[])},
),
1,
["id_score_list_features"],
),
]
)
def test_build_args_kwargs(
self,
fwd_args: List[ArgInfo],
fwd_args: CallArgs,
args_len: int,
kwarges_keys: List[str],
) -> None:
args, kwargs = _build_args_kwargs("initial_input", fwd_args)
args, kwargs = fwd_args.build_args_kwargs("initial_input")
self.assertEqual(len(args), args_len)
self.assertEqual(list(kwargs.keys()), kwarges_keys)

Expand Down
Loading

0 comments on commit 212d6c1

Please sign in to comment.