From ccfcc944c953f39a67806ed0530c4b69f599d368 Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Wed, 6 Nov 2024 11:34:50 +0000
Subject: [PATCH] 2024-11-06 nightly release
 (509b0d2277f3b6c119a78c61cf78ffb572b0ad41)

---
 torchrec/distributed/planner/stats.py         | 14 +++++++----
 .../include/torchrec/inference/GPUExecutor.h  |  2 +-
 .../inference_legacy/src/GPUExecutor.cpp      |  2 +-
 torchrec/inference/modules.py                 |  2 ++
 torchrec/modules/fp_embedding_modules.py      | 23 ++++++++++++++++++-
 .../tests/test_fp_embedding_modules.py        | 18 +++++++++++++++
 torchrec/modules/utils.py                     | 11 +++++++++
 7 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py
index 9455b7549..bc3f090f9 100644
--- a/torchrec/distributed/planner/stats.py
+++ b/torchrec/distributed/planner/stats.py
@@ -16,6 +16,7 @@
 
 from torch import nn
 
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner.constants import BIGINT_DTYPE, NUM_POOLINGS
 from torchrec.distributed.planner.shard_estimators import _calculate_shard_io_sizes
 from torchrec.distributed.planner.storage_reservations import (
@@ -421,11 +422,14 @@ def log(
                 if hasattr(sharder, "fused_params") and sharder.fused_params
                 else None
             )
-            cache_load_factor = str(
-                so.cache_load_factor
-                if so.cache_load_factor is not None
-                else sharder_cache_load_factor
-            )
+            cache_load_factor = "None"
+            # Surfacing the cache load factor only makes sense when UVM caching is in use.
+            if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
+                cache_load_factor = str(
+                    so.cache_load_factor
+                    if so.cache_load_factor is not None
+                    else sharder_cache_load_factor
+                )
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
diff --git a/torchrec/inference/include/torchrec/inference/GPUExecutor.h b/torchrec/inference/include/torchrec/inference/GPUExecutor.h
index 00c93668b..d2d289670 100644
--- a/torchrec/inference/include/torchrec/inference/GPUExecutor.h
+++ b/torchrec/inference/include/torchrec/inference/GPUExecutor.h
@@ -32,7 +32,7 @@
 #include "torchrec/inference/BatchingQueue.h"
 #include "torchrec/inference/Observer.h"
 #include "torchrec/inference/ResultSplit.h"
-#include "torchrec/inference/include/torchrec/inference/Observer.h"
+#include "torchrec/inference/include/torchrec/inference/Observer.h" // @manual
 
 namespace torchrec {
 
diff --git a/torchrec/inference/inference_legacy/src/GPUExecutor.cpp b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
index 38b00ad21..8178ed3f0 100644
--- a/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
+++ b/torchrec/inference/inference_legacy/src/GPUExecutor.cpp
@@ -25,7 +25,7 @@
 #include
 #include
 #include
-#include
+#include // @manual
 
 // remove this after we switch over to multipy externally for torchrec
 #ifdef FBCODE_CAFFE2
diff --git a/torchrec/inference/modules.py b/torchrec/inference/modules.py
index 1dd1735bc..fb8c9c21d 100644
--- a/torchrec/inference/modules.py
+++ b/torchrec/inference/modules.py
@@ -488,6 +488,7 @@ def shard_quant_model(
     sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
     device_memory_size: Optional[int] = None,
     constraints: Optional[Dict[str, ParameterConstraints]] = None,
+    ddr_cap: Optional[int] = None,
 ) -> Tuple[torch.nn.Module, ShardingPlan]:
     """
     Shard a quantized TorchRec model, used for generating the most optimal model for inference and
@@ -557,6 +558,7 @@ def shard_quant_model(
         compute_device=compute_device,
         local_world_size=world_size,
         hbm_cap=hbm_cap,
+        ddr_cap=ddr_cap,
     )
     batch_size = 1
     model_plan = trec_dist.planner.EmbeddingShardingPlanner(
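For the modules.py hunk above, a minimal usage sketch of the new ddr_cap knob follows. This is hypothetical, not part of the patch: `quant_model` (an already-quantized TorchRec model) and the byte figures are assumptions, and the other keyword arguments are assumed to keep the names visible in the surrounding diff. ddr_cap is simply forwarded to the planner Topology, bounding the DDR storage the planner assumes is available.

    # Hypothetical sketch: threading ddr_cap through shard_quant_model.
    # `quant_model` and the memory figures are illustrative assumptions.
    from torchrec.inference.modules import shard_quant_model

    sharded_model, plan = shard_quant_model(
        quant_model,
        world_size=2,
        compute_device="cuda",
        device_memory_size=16 * 1024**3,  # per-device HBM budget
        ddr_cap=64 * 1024**3,             # new: DDR budget passed to the planner Topology
    )
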
diff --git a/torchrec/modules/fp_embedding_modules.py b/torchrec/modules/fp_embedding_modules.py
index b3de7d66f..2f1a3abdb 100644
--- a/torchrec/modules/fp_embedding_modules.py
+++ b/torchrec/modules/fp_embedding_modules.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Dict, List, Set, Union
+from typing import Dict, List, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -55,6 +55,15 @@ def apply_feature_processors_to_kjt(
     )
 
 
+class FeatureProcessorDictWrapper(FeatureProcessorsCollection):
+    def __init__(self, feature_processors: nn.ModuleDict) -> None:
+        super().__init__()
+        self._feature_processors = feature_processors
+
+    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
+        return apply_feature_processors_to_kjt(features, self._feature_processors)
+
+
 class FeatureProcessedEmbeddingBagCollection(nn.Module):
     """
     FeatureProcessedEmbeddingBagCollection represents a EmbeddingBagCollection module and a set of feature processor modules.
@@ -125,6 +134,18 @@ def __init__(
             feature_names_set.update(table_config.feature_names)
         self._feature_names: List[str] = list(feature_names_set)
 
+    def split(
+        self,
+    ) -> Tuple[FeatureProcessorsCollection, EmbeddingBagCollection]:
+        if isinstance(self._feature_processors, nn.ModuleDict):
+            return (
+                FeatureProcessorDictWrapper(self._feature_processors),
+                self._embedding_bag_collection,
+            )
+        else:
+            assert isinstance(self._feature_processors, FeatureProcessorsCollection)
+            return self._feature_processors, self._embedding_bag_collection
+
     def forward(
         self,
         features: KeyedJaggedTensor,
diff --git a/torchrec/modules/tests/test_fp_embedding_modules.py b/torchrec/modules/tests/test_fp_embedding_modules.py
index 77c63759f..ccb6d7175 100644
--- a/torchrec/modules/tests/test_fp_embedding_modules.py
+++ b/torchrec/modules/tests/test_fp_embedding_modules.py
@@ -95,6 +95,15 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None:
         self.assertEqual(pooled_embeddings.values().size(), (3, 16))
         self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16])
 
+        # Test split method, FP then EBC
+        fp, ebc = fp_ebc.split()
+        fp_kjt = fp(features)
+        pooled_embeddings_split = ebc(fp_kjt)
+
+        self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
+        self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
+        self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])
+
 
 class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase):
     def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
@@ -144,3 +153,12 @@ def test_position_weighted_collection_module_ebc(self) -> None:
             pooled_embeddings_gm_script.offset_per_key(),
             pooled_embeddings.offset_per_key(),
         )
+
+        # Test split method, FP then EBC
+        fp, ebc = fp_ebc.split()
+        fp_kjt = fp(features)
+        pooled_embeddings_split = ebc(fp_kjt)
+
+        self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
+        self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
+        self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])
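Both test additions exercise the new split() entry point the same way. As a standalone sketch, with `fp_ebc` and `features` assumed to be built exactly as in the test setup above:

    # Sketch of the new split() API: separate the feature-processor stage from
    # the EmbeddingBagCollection so each half can be called (or placed) on its own.
    # `fp_ebc` and `features` are assumed from the test setup above.
    fp, ebc = fp_ebc.split()

    fp_kjt = fp(features)  # apply the feature processors to the KeyedJaggedTensor
    pooled = ebc(fp_kjt)   # pool with the plain EmbeddingBagCollection

Running the two stages back to back is equivalent to calling fp_ebc(features) directly, which is what the assertions above verify.
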
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index d83ed9254..c2fb835e3 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -48,6 +48,17 @@ def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor
     return tensor[start:end]
 
 
+# PLEASE DO NOT USE THIS FUNCTION; IT EXISTS FOR BACKWARD COMPATIBILITY ONLY.
+# USE THE VERSION IN torchrec/quant/embedding_modules.py INSTEAD.
+# TODO(@shuaoxiong): remove this function after we make sure all models switch to the new reference.
+@torch.fx.wrap
+def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
+    """
+    Unflatten lengths tensor from [F * B] to [F, B].
+    """
+    return lengths.view(num_features, -1)
+
+
 def extract_module_or_tensor_callable(
     module_or_callable: Union[
         Callable[[], torch.nn.Module],
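The helper above is a bare reshape. A short, self-contained illustration of the [F * B] -> [F, B] semantics follows; the sizes are made up, and per the comment in the hunk, new code should use the version in torchrec/quant/embedding_modules.py rather than this helper:

    import torch

    F, B = 2, 3                                  # 2 features, batch size 3
    lengths = torch.tensor([1, 2, 0, 3, 1, 1])   # flattened per-feature lengths, shape [F * B]
    unflattened = lengths.view(F, -1)            # the same reshape _get_unflattened_lengths performs
    assert unflattened.shape == (F, B)           # rows are features, columns are batch entries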