Skip to content

Commit

Permalink
Support postproc inputs to be list or dict with outputs from other …
Browse files Browse the repository at this point in the history
…postproc modules (pytorch#2733)

Summary:

Postproc modules with collection inputs (a list or a dict) containing non-static elements (i.e. elements derived from the model input or from another postproc module) were not properly rewritten: those elements remained fx.Nodes even during the actual model forward (i.e. outside the rewrite, during pipeline execution).

To illustrate:

```
# Illustration of the behavior before this change: calls marked "fails" pass a
# list/dict whose elements are derived from the model input (or from another
# computation on it); those elements were left as fx.Nodes after the rewrite.
# Calls marked "works" pass either a plain tensor argument or a collection of
# static values only.
def forward(model_input: ...) -> ...:
    modified_input = model_input.float_features + 1
    sharded_module_input = self.postproc(model_input, modified_input)  # works
    sharded_module_input = self.postproc(model_input, [123])  # works
    sharded_module_input = self.postproc(model_input, [torch.ones_like(modified_input)])  # fails
    sharded_module_input = self.postproc(model_input, [modified_input])  # fails
    sharded_module_input = self.postproc(model_input, { 'a': 123 })  # works
    sharded_module_input = self.postproc(model_input, { 'a': torch.ones_like(modified_input) })  # fails
    sharded_module_input = self.postproc(model_input, { 'a': modified_input })  # fails

    return self.ebc(sharded_module_input)
```

Differential Revision: D69292525
  • Loading branch information
che-sh authored and facebook-github-bot committed Feb 11, 2025
1 parent 0d9729b commit 17b0810
Show file tree
Hide file tree
Showing 3 changed files with 457 additions and 191 deletions.
97 changes: 97 additions & 0 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,103 @@ def forward(
return pred.sum(), pred


class TestModelWithPreprocCollectionArgs(nn.Module):
    """
    Test model whose nested postproc module receives collection (list and dict)
    arguments in addition to the model input.

    Postproc modules used in forward:
    - postproc_module_outer: applied to the model input; its output is placed
      into the list and dict passed to postproc_module_nested
    - postproc_module_nested: called with the model input, a list and a dict;
      each collection holds a constant, a tensor read from the model input and
      the output of postproc_module_outer (aka nested postproc)
    - postproc_nonweighted: applied to idlist_features for the non-weighted EBC
    - postproc_weighted: applied to idscore_features for the weighted EBC

    Args:
        tables,
        weighted_tables,
        device,
        postproc_module_outer,
        postproc_module_nested,
        num_float_features,

    Example:
        >>> TestModelWithPreprocCollectionArgs(
        ...     tables, weighted_tables, device, postproc_outer, postproc_nested
        ... )

    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """

    # Keys of the dict argument passed to the nested postproc module.
    CONST_DICT_KEY = "const"
    INPUT_TENSOR_DICT_KEY = "tensor_from_input"
    POSTPROC_TENSOR_DICT_KEY = "tensor_from_postproc"
    # Original (misspelled) name, kept as an alias so existing references
    # to it keep working.
    POSTPTOC_TENSOR_DICT_KEY = POSTPROC_TENSOR_DICT_KEY

    def __init__(
        self,
        tables: List[EmbeddingBagConfig],
        weighted_tables: List[EmbeddingBagConfig],
        device: torch.device,
        postproc_module_outer: nn.Module,
        postproc_module_nested: nn.Module,
        num_float_features: int = 10,
    ) -> None:
        super().__init__()
        # NOTE: the dense arch is constructed here but is not called in forward().
        self.dense = TestDenseArch(num_float_features, device)

        self.ebc: EmbeddingBagCollection = EmbeddingBagCollection(
            tables=tables,
            device=device,
        )
        self.weighted_ebc = EmbeddingBagCollection(
            tables=weighted_tables,
            is_weighted=True,
            device=device,
        )
        self.postproc_nonweighted = TestPreprocNonWeighted()
        self.postproc_weighted = TestPreprocWeighted()
        self._postproc_module_outer = postproc_module_outer
        self._postproc_module_nested = postproc_module_nested

    def forward(
        self,
        input: ModelInput,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs the outer postproc on the input, feeds its output (inside a list
        and a dict) to the nested postproc, then runs the per-EBC postprocs and
        both EBCs.

        Args:
            input

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        modified_input = input

        outer_postproc_input = self._postproc_module_outer(modified_input)

        # Each collection mixes the three kinds of elements:
        # a static constant, a tensor derived from the model input and
        # a tensor produced by another postproc module.
        preproc_input_list = [
            1,
            modified_input.float_features,
            outer_postproc_input,
        ]
        preproc_input_dict = {
            self.CONST_DICT_KEY: 1,
            self.INPUT_TENSOR_DICT_KEY: modified_input.float_features,
            self.POSTPROC_TENSOR_DICT_KEY: outer_postproc_input,
        }

        modified_input = self._postproc_module_nested(
            modified_input, preproc_input_list, preproc_input_dict
        )

        modified_idlist_features = self.postproc_nonweighted(
            modified_input.idlist_features
        )
        modified_idscore_features = self.postproc_weighted(
            modified_input.idscore_features
        )
        # Only the first element of each postproc output is passed to the EBCs.
        ebc_out = self.ebc(modified_idlist_features[0])
        weighted_ebc_out = self.weighted_ebc(modified_idscore_features[0])

        pred = torch.cat([ebc_out.values(), weighted_ebc_out.values()], dim=1)
        return pred.sum(), pred


class TestNegSamplingModule(torch.nn.Module):
"""
Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from contextlib import ExitStack
from dataclasses import dataclass
from functools import partial
from typing import cast, List, Optional, Tuple, Type, Union
from typing import cast, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import MagicMock

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torch.fx._symbolic_trace import is_fx_tracing
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
Expand All @@ -36,6 +37,7 @@
ModelInput,
TestEBCSharder,
TestModelWithPreproc,
TestModelWithPreprocCollectionArgs,
TestNegSamplingModule,
TestPositionWeightedPreprocModule,
TestSparseNN,
Expand Down Expand Up @@ -1448,6 +1450,81 @@ def forward(
self.assertEqual(len(pipeline._pipelined_modules), 2)
self.assertEqual(len(pipeline._pipelined_postprocs), 1)

# pyre-ignore
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
def test_pipeline_postproc_with_collection_args(self) -> None:
    """
    Exercises scenario when postproc module has arguments that are a list and
    a dict with elements being:
    * static scalars (the constant 1)
    * tensors derived from input batch (input.float_features)
    * tensors produced by another postproc module from the input batch
      (PostprocOuter(input))

    The nested postproc fails the test if, outside of fx tracing, any element
    of its collection arguments is still a torch.fx.Node.
    """
    # Captured by PostprocInner so it can report failures through the test case.
    test_runner = self

    class PostprocOuter(nn.Module):
        def forward(
            self,
            model_input: ModelInput,
        ) -> torch.Tensor:
            return model_input.float_features * 0.1

    class PostprocInner(nn.Module):
        def forward(
            self,
            model_input: ModelInput,
            input_list: List[Union[torch.Tensor, int]],
            input_dict: Dict[str, Union[torch.Tensor, int]],
        ) -> ModelInput:
            # Checks run only outside fx tracing, i.e. when forward is called
            # with concrete values.
            if not is_fx_tracing():
                for idx, value in enumerate(input_list):
                    if isinstance(value, torch.fx.Node):
                        test_runner.fail(
                            f"input_list[{idx}] was a fx.Node: {value}"
                        )
                    # In-place update of the batch's float features.
                    model_input.float_features += value

                for key, value in input_dict.items():
                    if isinstance(value, torch.fx.Node):
                        test_runner.fail(
                            f"input_dict[{key}] was a fx.Node: {value}"
                        )
                    model_input.float_features += value

            return model_input

    model = TestModelWithPreprocCollectionArgs(
        tables=self.tables[:-1],  # ignore last table as postproc will remove
        weighted_tables=self.weighted_tables[:-1],  # ignore last table
        device=self.device,
        postproc_module_outer=PostprocOuter(),
        postproc_module_nested=PostprocInner(),
    )

    _, pipeline = self._check_output_equal(
        model,
        self.sharding_type,
    )

    # both EBC and weighted EBC are pipelined
    self.assertEqual(len(pipeline._pipelined_modules), 2)
    # all four postprocs are expected to be pipelined: outer, nested,
    # and the non-weighted / weighted postprocs feeding the two EBCs
    self.assertEqual(len(pipeline._pipelined_postprocs), 4)


class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
@unittest.skipIf(
Expand Down
Loading

0 comments on commit 17b0810

Please sign in to comment.