
Fix collection inputs to postproc modules #2733

Open
wants to merge 2 commits into main
Conversation

che-sh
Contributor

@che-sh commented Feb 7, 2025

Summary:
Postproc modules with collection inputs (list or dict) containing non-static elements (i.e. elements derived from the model input or from another postproc module) were not properly rewritten: the collection elements remained fx.Nodes even during the actual model forward pass (i.e. outside the rewrite, during pipeline execution).

To illustrate:

```
def forward(model_input: ...) -> ...:
    modified_input = model_input.float_features + 1
    sharded_module_input = self.postproc(model_input, modified_input)  # works
    sharded_module_input = self.postproc(model_input, [123])  # works
    sharded_module_input = self.postproc(model_input, [torch.tensor([1,2,3])])  # works
    sharded_module_input = self.postproc(model_input, [torch.ones_like(modified_input)])  # fails
    sharded_module_input = self.postproc(model_input, [modified_input])  # fails
    sharded_module_input = self.postproc(model_input, { 'a': 123 })  # works
    sharded_module_input = self.postproc(model_input, { 'a': torch.tensor([1,2,3]) })  # works
    sharded_module_input = self.postproc(model_input, { 'a': torch.ones_like(modified_input) })  # fails
    sharded_module_input = self.postproc(model_input, { 'a': modified_input })  # fails

    return self.ebc(sharded_module_input)
```

Differential Revision: D69292525
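To see why the collection cases fail, here is a minimal, self-contained sketch of the shallow-vs-recursive rewrite distinction. It uses plain Python with a `FakeFxNode` stand-in for `torch.fx.Node`; all names here are hypothetical illustrations, not torchrec's actual implementation.

```python
class FakeFxNode:
    """Stand-in for an fx.Node captured during graph rewrite."""
    def __init__(self, value):
        self.value = value

def materialize_shallow(arg):
    """Buggy behavior: only a top-level node is replaced with its value."""
    return arg.value if isinstance(arg, FakeFxNode) else arg

def materialize_recursive(arg):
    """Fixed behavior: also descend into lists and dicts."""
    if isinstance(arg, FakeFxNode):
        return arg.value
    if isinstance(arg, list):
        return [materialize_recursive(a) for a in arg]
    if isinstance(arg, dict):
        return {k: materialize_recursive(v) for k, v in arg.items()}
    return arg  # static values (e.g. ints) pass through unchanged

node = FakeFxNode(42)

# Top-level node: both versions work.
assert materialize_shallow(node) == 42

# Node inside a list: the shallow version leaks the node object through,
# which mirrors the "fails" cases above.
leaked = materialize_shallow([node])
assert isinstance(leaked[0], FakeFxNode)

# The recursive version materializes elements inside collections too.
assert materialize_recursive([node]) == [42]
assert materialize_recursive({"a": node, "b": 123}) == {"a": 42, "b": 123}
```

The static cases (`[123]`, `{'a': 123}`) never fail because their elements were never fx.Nodes to begin with; only derived elements hidden inside a collection need the recursive descent.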


@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 7, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69292525

che-sh added commits to che-sh/torchrec that referenced this pull request on Feb 11, Feb 12, and Feb 13, 2025, each repeating the summary above (Differential Revision: D69292525).
Summary:

The `_shard_modules` function is used in fx_traceability tests for the SDD and SemiSync pipelines. It uses a default ShardingPlanner and Topology with a hardcoded batch size (512) and HBM memory limit (32 GB), respectively. This change allows specifying the ShardingPlanner and Topology so tests can more accurately reflect the machine's capabilities. The change is intentionally limited to the private `_shard_modules` and does not touch the public `shard_modules`, to avoid changing the latter's contract.

Reviewed By: sarckk

Differential Revision: D69163227

che-sh added a commit to che-sh/torchrec that referenced this pull request Feb 14, 2025, repeating the summary above (Differential Revision: D69292525).
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
2 participants