Skip to content

Commit

Permalink
Add method to fetch if a feature is weighted
Browse files Browse the repository at this point in the history
Summary: ATT, add a method to fetch per-feature weighted info, to help reduce dynamic checks at runtime

Differential Revision: D69222164
  • Loading branch information
ZhengkaiZ authored and facebook-github-bot committed Feb 6, 2025
1 parent 9269e73 commit 7fdc58b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]:
embedding_shard_metadata.append(table.local_metadata)
return embedding_shard_metadata

def features_weighted(self) -> List[bool]:
    """Return one weighted-flag per feature, in table order.

    Each table contributes ``num_features()`` copies of its
    ``is_weighted`` flag, so the result aligns index-for-index with
    the flattened feature list of this group's embedding tables.
    """
    return [
        tbl.is_weighted
        for tbl in self.embedding_tables
        for _ in range(tbl.num_features())
    ]


F = TypeVar("F", bound=Multistreamable)
T = TypeVar("T")
Expand Down
9 changes: 9 additions & 0 deletions torchrec/distributed/sharding/tw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ def features_per_rank(self) -> List[int]:
features_per_rank.append(num_features)
return features_per_rank

def is_weighted_per_rank(self) -> List[List[bool]]:
    """Return, for every rank, the weighted-flag of each feature on that rank.

    The outer list is ordered by rank; each inner list concatenates the
    per-feature flags of that rank's grouped embedding configs, in
    config order.
    """
    return [
        [flag for config in rank_configs for flag in config.features_weighted()]
        for rank_configs in self._grouped_embedding_configs_per_rank
    ]


class TwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
"""
Expand Down

0 comments on commit 7fdc58b

Please sign in to comment.