diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index c9b44581a..27f8c1b42 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -46,6 +46,7 @@ PartiallyMaterializedTensor, ) from torch import nn +from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard from torchrec.distributed.comm import get_local_rank, get_node_group_size from torchrec.distributed.composable.table_batched_embedding_slice import ( TableBatchedEmbeddingSlice, @@ -53,8 +54,10 @@ from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict from torchrec.distributed.embedding_types import ( compute_kernel_to_embedding_location, + DTensorMetadata, GroupedEmbeddingConfig, ) +from torchrec.distributed.shards_wrapper import LocalShardsWrapper from torchrec.distributed.types import ( Shard, ShardedTensor, @@ -213,6 +216,7 @@ class ShardParams: optimizer_states: List[Optional[Tuple[torch.Tensor]]] local_metadata: List[ShardMetadata] embedding_weights: List[torch.Tensor] + dtensor_metadata: List[DTensorMetadata] def get_optimizer_single_value_shard_metadata_and_global_metadata( table_global_metadata: ShardedTensorMetadata, @@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( continue if table_config.name not in table_to_shard_params: table_to_shard_params[table_config.name] = ShardParams( - optimizer_states=[], local_metadata=[], embedding_weights=[] + optimizer_states=[], + local_metadata=[], + embedding_weights=[], + dtensor_metadata=[], ) optimizer_state_values = None if optimizer_states: @@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( table_to_shard_params[table_config.name].local_metadata.append( local_metadata ) + table_to_shard_params[table_config.name].dtensor_metadata.append( + table_config.dtensor_metadata + ) table_to_shard_params[table_config.name].embedding_weights.append(weight) seen_tables = set() @@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( # pyre-ignore def get_sharded_optim_state( momentum_idx: int, state_key: str - ) -> ShardedTensor: + ) -> Union[ShardedTensor, DTensor]: assert momentum_idx > 0 momentum_local_shards: List[Shard] = [] optimizer_sharded_tensor_metadata: ShardedTensorMetadata @@ -528,12 +538,41 @@ def get_sharded_optim_state( ) ) - # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata. 
-            return ShardedTensor._init_from_local_shards_and_global_metadata(
-                local_shards=momentum_local_shards,
-                sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
-                process_group=self._pg,
-            )
+            # Convert optimizer state to DTensor if enabled
+            if table_config.dtensor_metadata:
+                # if the state is row-wise (1D), use Shard(0) regardless of how the table is sharded
+                if optim_state.dim() == 1:
+                    stride = (1,)
+                    placements = (
+                        (Replicate(), DTensorShard(0))
+                        if table_config.dtensor_metadata.mesh.ndim == 2
+                        else (DTensorShard(0),)
+                    )
+                else:
+                    stride = table_config.dtensor_metadata.stride
+                    placements = table_config.dtensor_metadata.placements
+
+                return DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=[x.tensor for x in momentum_local_shards],
+                        local_offsets=[  # pyre-ignore[6]
+                            x.metadata.shard_offsets
+                            for x in momentum_local_shards
+                        ],
+                    ),
+                    device_mesh=table_config.dtensor_metadata.mesh,
+                    placements=placements,
+                    shape=optimizer_sharded_tensor_metadata.size,
+                    stride=stride,
+                    run_check=False,
+                )
+            else:
+                # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
+                return ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=momentum_local_shards,
+                    sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                    process_group=self._pg,
+                )
 
         num_states: int = min(
             # pyre-ignore
diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py
index 15f0f65be..e7fc1e52b 100644
--- a/torchrec/distributed/shards_wrapper.py
+++ b/torchrec/distributed/shards_wrapper.py
@@ -68,10 +68,15 @@ def __new__(
         # we calculate the total tensor size by "concat" on second tensor dimension
         cat_tensor_shape = list(local_shards[0].size())
-        if len(local_shards) > 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 2:  # column-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[1] += shard.size()[1]
 
+        # when the optimizer state is sharded row-wise, we calculate the total tensor size by "concat" on the first tensor dimension
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
+            for shard in local_shards[1:]:
+                cat_tensor_shape[0] += shard.size()[0]
+
         wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
         wrapper_shape = torch.Size(cat_tensor_shape)
         chunks_meta = [
@@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.equal.default: cls.handle_equal,
             aten.detach.default: cls.handle_detach,
             aten.clone.default: cls.handle_clone,
+            aten.new_empty.default: cls.handle_new_empty,
         }
 
         if func in dispatcher:
@@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
     def handle_view(args, kwargs):
         view_shape = args[1]
         res_shards_list = []
-        if (
-            len(args[0].local_shards()) > 1
-            and args[0].storage_metadata().size[0] == view_shape[0]
-            and args[0].storage_metadata().size[1] == view_shape[1]
-        ):
-            # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
-            # init calls view_as() on the global tensor shape
-            # will fail because the view shape is not applicable to individual shards.
-            res_shards_list = [
-                aten.view.default(shard, shard.shape, **kwargs)
-                for shard in args[0].local_shards()
-            ]
+        if len(args[0].local_shards()) > 1:
+            if args[0].local_shards()[0].ndim == 2:
+                assert (
+                    args[0].storage_metadata().size[0] == view_shape[0]
+                    and args[0].storage_metadata().size[1] == view_shape[1]
+                )
+                # This accounts for a DTensor quirk: when multiple shards are present on a rank,
+                # DTensor on init calls view_as() on the global tensor shape, which would fail
+                # because the view shape is not applicable to individual shards.
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            elif args[0].local_shards()[0].ndim == 1:
+                assert args[0].storage_metadata().size[0] == view_shape[0]
+                # This case is for optimizer state sharding: regardless of the table's
+                # sharding type, optimizer state is row-wise sharded
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            else:
+                raise NotImplementedError("No support for view on tensors with ndim > 2")
         else:
             # view is called per shard
             res_shards_list = [
@@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
         ]
         return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_new_empty(args, kwargs):
+        self_ls = args[0]
+        return LocalShardsWrapper(
+            [torch.empty_like(shard) for shard in self_ls._local_shards],
+            self_ls.local_offsets(),
+        )
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 07a0f33d8..9bf975d3e 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -2158,7 +2158,6 @@ def test_sharded_quant_mc_ec_rw(
                     eviction_policy=DistanceLFU_EvictionPolicy(),
                 )
             },
-            # pyre-ignore [6] Incompatible parameter type
             embedding_configs=mi.tables,
         ),
     )
diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py
index 9bc272212..20f883e19 100644
--- a/torchrec/distributed/tests/test_mc_embedding.py
+++ b/torchrec/distributed/tests/test_mc_embedding.py
@@ -88,7 +88,6 @@ def __init__(
             ),
             ManagedCollisionCollection(
                 managed_collision_modules=mc_modules,
-                # pyre-ignore
                 embedding_configs=tables,
             ),
             return_remapped_features=self._return_remapped,
diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py
index 7dee58c33..e891e8841 100644
--- a/torchrec/distributed/tests/test_mc_embeddingbag.py
+++ b/torchrec/distributed/tests/test_mc_embeddingbag.py
@@ -78,7 +78,6 @@ def __init__(
             ),
             ManagedCollisionCollection(
                 managed_collision_modules=mc_modules,
-                # pyre-ignore
                 embedding_configs=tables,
             ),
             return_remapped_features=self._return_remapped,
diff --git a/torchrec/inference/inference_legacy/tests/test_modules.py b/torchrec/inference/inference_legacy/tests/test_modules.py
index 9557d4b52..2b4e97869 100644
--- a/torchrec/inference/inference_legacy/tests/test_modules.py
+++ b/torchrec/inference/inference_legacy/tests/test_modules.py
@@ -40,5 +40,6 @@ def test_quantize_shard_cuda(self) -> None:
         quantized_model = quantize_inference_model(model)
         sharded_model, _ = shard_quant_model(quantized_model)
 
+        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `sparse`.
         sharded_qebc = sharded_model._module.sparse.ebc
         self.assertEqual(len(sharded_qebc.tbes), 1)
diff --git a/torchrec/metrics/hindsight_target_pr.py b/torchrec/metrics/hindsight_target_pr.py
new file mode 100644
index 000000000..800052ecf
--- /dev/null
+++ b/torchrec/metrics/hindsight_target_pr.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, cast, Dict, List, Optional, Type
+
+import torch
+from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
+from torchrec.metrics.rec_metric import (
+    MetricComputationReport,
+    RecMetric,
+    RecMetricComputation,
+    RecMetricException,
+)
+
+
+TARGET_PRECISION = "target_precision"
+THRESHOLD_GRANULARITY = 1000
+
+
+def compute_precision(
+    num_true_positives: torch.Tensor, num_false_positives: torch.Tensor
+) -> torch.Tensor:
+    return torch.where(
+        num_true_positives + num_false_positives == 0.0,
+        0.0,
+        num_true_positives / (num_true_positives + num_false_positives).double(),
+    )
+
+
+def compute_recall(
+    num_true_positives: torch.Tensor, num_false_negatives: torch.Tensor
+) -> torch.Tensor:
+    return torch.where(
+        num_true_positives + num_false_negatives == 0.0,
+        0.0,
+        num_true_positives / (num_true_positives + num_false_negatives),
+    )
+
+
+def compute_threshold_idx(
+    num_true_positives: torch.Tensor,
+    num_false_positives: torch.Tensor,
+    target_precision: float,
+) -> int:
+    for i in range(THRESHOLD_GRANULARITY):
+        if (
+            compute_precision(num_true_positives[i], num_false_positives[i])
+            >= target_precision
+        ):
+            return i
+
+    return THRESHOLD_GRANULARITY - 1
+
+
+def compute_true_pos_sum(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+) -> torch.Tensor:
+    predictions = predictions.double()
+    tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+    thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+    for i, threshold in enumerate(thresholds):
+        tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1)
+    return tp_sum
+
+
+def compute_false_pos_sum(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+) -> torch.Tensor:
+    predictions = predictions.double()
+    fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+    thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+    for i, threshold in enumerate(thresholds):
+        fp_sum[i] = torch.sum(weights * ((predictions >= threshold) * (1 - labels)), -1)
+    return fp_sum
+
+
+def compute_false_neg_sum(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: torch.Tensor,
+) -> torch.Tensor:
+    predictions = predictions.double()
+    fn_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+    thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+    for i, threshold in enumerate(thresholds):
+        fn_sum[i] = torch.sum(weights * ((predictions <= threshold) * labels), -1)
+    return fn_sum
+
+
+def get_pr_states(
+    labels: torch.Tensor,
+    predictions: torch.Tensor,
+    weights: Optional[torch.Tensor],
+) -> Dict[str, torch.Tensor]:
+    if weights is None:
+        weights = torch.ones_like(predictions)
+    return {
+        "true_pos_sum": compute_true_pos_sum(labels, predictions, weights),
+        "false_pos_sum": compute_false_pos_sum(labels, predictions, weights),
+        "false_neg_sum": compute_false_neg_sum(labels, predictions, weights),
+    }
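+
+
+# Note on cost and semantics: each helper above sweeps all THRESHOLD_GRANULARITY
+# thresholds in [0, 1], so a single update costs
+# O(THRESHOLD_GRANULARITY * batch_size). Worked example: with
+# predictions=[0.2, 0.8], labels=[0, 1] and unit weights, true_pos_sum[i] is 1
+# for every threshold <= 0.8 and false_pos_sum[i] is 1 for every
+# threshold <= 0.2; compute_threshold_idx then returns the first index whose
+# precision meets the target.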
+
+
+class HindsightTargetPRMetricComputation(RecMetricComputation):
+    r"""
+    This class implements the RecMetricComputation for Hindsight Target PR.
+
+    The constructor arguments are defined in RecMetricComputation.
+    See the docstring of RecMetricComputation for more detail.
+
+    Args:
+        target_precision (float): the target precision used to select the smallest
+            qualifying threshold (default: 0.5).
+    """
+
+    def __init__(
+        self, *args: Any, target_precision: float = 0.5, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self._add_state(
+            "true_pos_sum",
+            torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
+            add_window_state=True,
+            dist_reduce_fx="sum",
+            persistent=True,
+        )
+        self._add_state(
+            "false_pos_sum",
+            torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
+            add_window_state=True,
+            dist_reduce_fx="sum",
+            persistent=True,
+        )
+        self._add_state(
+            "false_neg_sum",
+            torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double),
+            add_window_state=True,
+            dist_reduce_fx="sum",
+            persistent=True,
+        )
+        self._target_precision: float = target_precision
+
+    def update(
+        self,
+        *,
+        predictions: Optional[torch.Tensor],
+        labels: torch.Tensor,
+        weights: Optional[torch.Tensor],
+        **kwargs: Dict[str, Any],
+    ) -> None:
+        if predictions is None:
+            raise RecMetricException(
+                "Inputs 'predictions' should not be None for HindsightTargetPRMetricComputation update"
+            )
+        states = get_pr_states(labels, predictions, weights)
+        num_samples = predictions.shape[-1]
+
+        for state_name, state_value in states.items():
+            state = getattr(self, state_name)
+            state += state_value
+            self._aggregate_window_state(state_name, state_value, num_samples)
+
+    def _compute(self) -> List[MetricComputationReport]:
+        true_pos_sum = cast(torch.Tensor, self.true_pos_sum)
+        false_pos_sum = cast(torch.Tensor, self.false_pos_sum)
+        false_neg_sum = cast(torch.Tensor, self.false_neg_sum)
+        threshold_idx = compute_threshold_idx(
+            true_pos_sum,
+            false_pos_sum,
+            self._target_precision,
+        )
+        window_threshold_idx = compute_threshold_idx(
+            self.get_window_state("true_pos_sum"),
+            self.get_window_state("false_pos_sum"),
+            self._target_precision,
+        )
+        reports = [
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_PR,
+                metric_prefix=MetricPrefix.LIFETIME,
+                value=torch.tensor(threshold_idx),
+            ),
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_PR,
+                metric_prefix=MetricPrefix.WINDOW,
+                value=torch.tensor(window_threshold_idx),
+            ),
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_PRECISION,
+                metric_prefix=MetricPrefix.LIFETIME,
+                value=compute_precision(
+                    true_pos_sum[threshold_idx],
+                    false_pos_sum[threshold_idx],
+                ),
+            ),
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_PRECISION,
+                metric_prefix=MetricPrefix.WINDOW,
+                value=compute_precision(
+                    self.get_window_state("true_pos_sum")[window_threshold_idx],
+                    self.get_window_state("false_pos_sum")[window_threshold_idx],
+                ),
+            ),
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_RECALL,
+                metric_prefix=MetricPrefix.LIFETIME,
+                value=compute_recall(
+                    true_pos_sum[threshold_idx],
+                    false_neg_sum[threshold_idx],
+                ),
+            ),
+            MetricComputationReport(
+                name=MetricName.HINDSIGHT_TARGET_RECALL,
+                metric_prefix=MetricPrefix.WINDOW,
+                value=compute_recall(
+                    self.get_window_state("true_pos_sum")[window_threshold_idx],
+                    self.get_window_state("false_neg_sum")[window_threshold_idx],
+                ),
+            ),
+        ]
+        return reports
+
+
+class HindsightTargetPRMetric(RecMetric):
+    _namespace: MetricNamespace = MetricNamespace.HINDSIGHT_TARGET_PR
+    _computation_class: Type[RecMetricComputation] = HindsightTargetPRMetricComputation
diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py
index b7228acbe..0be8329f1 100644
--- a/torchrec/metrics/metric_module.py
+++ b/torchrec/metrics/metric_module.py
@@ -24,6 +24,7 @@ from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
 from torchrec.metrics.calibration import CalibrationMetric
 from torchrec.metrics.ctr import CTRMetric
+from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric
 from torchrec.metrics.mae import MAEMetric
 from torchrec.metrics.metrics_config import (
     BatchSizeStage,
@@ -94,6 +95,7 @@
     RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric,
     RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric,
     RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric,
+    RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric,
 }
diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py
index ac9edf440..e85867862 100644
--- a/torchrec/metrics/metrics_config.py
+++ b/torchrec/metrics/metrics_config.py
@@ -47,6 +47,7 @@ class RecMetricEnum(RecMetricEnumBase):
     TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"
     CALI_FREE_NE = "cali_free_ne"
     UNWEIGHTED_NE = "unweighted_ne"
+    HINDSIGHT_TARGET_PR = "hindsight_target_pr"
 
 @dataclass(unsafe_hash=True, eq=True)
diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py
index 55dbd72e2..1afd83e60 100644
--- a/torchrec/metrics/metrics_namespace.py
+++ b/torchrec/metrics/metrics_namespace.py
@@ -82,6 +82,10 @@ class MetricName(MetricNameBase):
     CALI_FREE_NE = "cali_free_ne"
     UNWEIGHTED_NE = "unweighted_ne"
 
+    HINDSIGHT_TARGET_PR = "hindsight_target_pr"
+    HINDSIGHT_TARGET_PRECISION = "hindsight_target_precision"
+    HINDSIGHT_TARGET_RECALL = "hindsight_target_recall"
+
 class MetricNamespaceBase(StrValueMixin, Enum):
     pass
@@ -131,6 +135,8 @@ class MetricNamespace(MetricNamespaceBase):
     CALI_FREE_NE = "cali_free_ne"
     UNWEIGHTED_NE = "unweighted_ne"
 
+    HINDSIGHT_TARGET_PR = "hindsight_target_pr"
+
 class MetricPrefix(StrValueMixin, Enum):
     DEFAULT = ""
diff --git a/torchrec/metrics/tests/test_hindsight_target_pr.py b/torchrec/metrics/tests/test_hindsight_target_pr.py
new file mode 100644
index 000000000..2fd9102c8
--- /dev/null
+++ b/torchrec/metrics/tests/test_hindsight_target_pr.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from typing import Dict, Type
+
+import torch
+from torchrec.metrics.hindsight_target_pr import (
+    compute_precision,
+    compute_recall,
+    compute_threshold_idx,
+    HindsightTargetPRMetric,
+)
+from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
+from torchrec.metrics.test_utils import (
+    metric_test_helper,
+    rec_metric_value_test_launcher,
+    TestMetric,
+)
+
+
+WORLD_SIZE = 4
+THRESHOLD_GRANULARITY = 1000
+
+
+class TestHindsightTargetPRMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        predictions = predictions.double()
+        tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+        for i, threshold in enumerate(thresholds):
+            tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1)
+            fp_sum[i] = torch.sum(
+                weights * ((predictions >= threshold) * (1 - labels)), -1
+            )
+        return {
+            "true_pos_sum": tp_sum,
+            "false_pos_sum": fp_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        threshold_idx = compute_threshold_idx(
+            states["true_pos_sum"], states["false_pos_sum"], 0.5
+        )
+        return torch.tensor(threshold_idx)
+
+
+class TestHindsightTargetPrecisionMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        predictions = predictions.double()
+        tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+        for i, threshold in enumerate(thresholds):
+            tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1)
+            fp_sum[i] = torch.sum(
+                weights * ((predictions >= threshold) * (1 - labels)), -1
+            )
+        return {
+            "true_pos_sum": tp_sum,
+            "false_pos_sum": fp_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        threshold_idx = compute_threshold_idx(
+            states["true_pos_sum"], states["false_pos_sum"], 0.5
+        )
+        return compute_precision(
+            states["true_pos_sum"][threshold_idx],
+            states["false_pos_sum"][threshold_idx],
+        )
+
+
+class TestHindsightTargetRecallMetric(TestMetric):
+    @staticmethod
+    def _get_states(
+        labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        predictions = predictions.double()
+        tp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        fp_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        fn_sum = torch.zeros(THRESHOLD_GRANULARITY, dtype=torch.double)
+        thresholds = torch.linspace(0, 1, steps=THRESHOLD_GRANULARITY)
+        for i, threshold in enumerate(thresholds):
+            tp_sum[i] = torch.sum(weights * ((predictions >= threshold) * labels), -1)
+            fp_sum[i] = torch.sum(
+                weights * ((predictions >= threshold) * (1 - labels)), -1
+            )
+            fn_sum[i] = torch.sum(weights * ((predictions <= threshold) * labels), -1)
+        return {
+            "true_pos_sum": tp_sum,
+            "false_pos_sum": fp_sum,
+            "false_neg_sum": fn_sum,
+        }
+
+    @staticmethod
+    def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
+        threshold_idx = compute_threshold_idx(
+            states["true_pos_sum"], states["false_pos_sum"], 0.5
+        )
+        return compute_recall(
+            states["true_pos_sum"][threshold_idx],
states["false_neg_sum"][threshold_idx], + ) + + +# Fused tests are not supported for this metric. +class TestHindsightTargetPRMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = HindsightTargetPRMetric + pr_task_name: str = "hindsight_target_pr" + precision_task_name: str = "hindsight_target_precision" + recall_task_name: str = "hindsight_target_recall" + + def test_unfused_hindsight_target_precision(self) -> None: + rec_metric_value_test_launcher( + target_clazz=HindsightTargetPRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestHindsightTargetPrecisionMetric, + metric_name=TestHindsightTargetPRMetricTest.precision_task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unfused_hindsight_target_recall(self) -> None: + rec_metric_value_test_launcher( + target_clazz=HindsightTargetPRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestHindsightTargetRecallMetric, + metric_name=TestHindsightTargetPRMetricTest.recall_task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py index c9621050c..4693fd39c 100644 --- a/torchrec/modules/mc_modules.py +++ b/torchrec/modules/mc_modules.py @@ -11,7 +11,7 @@ import abc from logging import getLogger, Logger -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import torch @@ -300,7 +300,7 @@ class ManagedCollisionCollection(nn.Module): def __init__( self, managed_collision_modules: Dict[str, ManagedCollisionModule], - embedding_configs: List[BaseEmbeddingConfig], + embedding_configs: Sequence[BaseEmbeddingConfig], need_preprocess: bool = True, ) -> None: super().__init__() @@ -351,7 +351,7 @@ def _create_feature_order( if features_order != list(range(len(features_order))): self._features_order = features_order - def embedding_configs(self) -> List[BaseEmbeddingConfig]: + def embedding_configs(self) -> Sequence[BaseEmbeddingConfig]: return self._embedding_configs def forward( diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py index cf04b86da..58c4fa466 100644 --- a/torchrec/modules/tests/test_mc_embedding_modules.py +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -81,13 +81,11 @@ def test_zch_ebc_ec_train(self) -> None: } mcc_ebc = ManagedCollisionCollection( managed_collision_modules=mc_modules, - # pyre-ignore[6] embedding_configs=embedding_bag_configs, ) mcc_ec = ManagedCollisionCollection( managed_collision_modules=deepcopy(mc_modules), - # pyre-ignore[6] embedding_configs=embedding_configs, ) mc_ebc = ManagedCollisionEmbeddingBagCollection( @@ -282,13 +280,11 @@ def test_zch_ebc_ec_eval(self) -> None: } mcc_ebc = ManagedCollisionCollection( managed_collision_modules=mc_modules, - # pyre-ignore[6] embedding_configs=embedding_bag_configs, ) mcc_ec = ManagedCollisionCollection( managed_collision_modules=deepcopy(mc_modules), - # pyre-ignore[6] embedding_configs=embedding_configs, ) mc_ebc = ManagedCollisionEmbeddingBagCollection( @@ -409,7 +405,6 @@ def test_mc_collection_traceable(self) -> None: } mcc = 
ManagedCollisionCollection( managed_collision_modules=mc_modules, - # pyre-ignore[6] embedding_configs=embedding_configs, ) mcc.train(False) diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py index edd587db2..a55bf6893 100644 --- a/torchrec/optim/keyed.py +++ b/torchrec/optim/keyed.py @@ -27,7 +27,7 @@ from torch import optim from torch.distributed._shard.sharded_tensor import ShardedTensor - +from torch.distributed.tensor import DTensor OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer] @@ -111,6 +111,8 @@ def _update_param_state_dict_object( param_state_dict_to_load: Dict[str, Any], parent_keys: List[Union[str, int, float, bool, None]], ) -> None: + # Import at function level to avoid circular dependency. + from torchrec.distributed.shards_wrapper import LocalShardsWrapper for k, v in current_param_state_dict.items(): new_v = param_state_dict_to_load[k] @@ -134,6 +136,23 @@ def _update_param_state_dict_object( ) for shard, new_shard in zip(v.local_shards(), new_v.local_shards()): shard.tensor.detach().copy_(new_shard.tensor) + elif isinstance(v, DTensor): + assert isinstance(new_v, DTensor) + if isinstance(v.to_local(), LocalShardsWrapper): + assert isinstance(new_v.to_local(), LocalShardsWrapper) + num_shards = len(v.to_local().local_shards()) # pyre-ignore[16] + num_new_shards = len(new_v.to_local().local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}" + ) + for shard, new_shard in zip( + v.to_local().local_shards(), new_v.to_local().local_shards() + ): + shard.detach().copy_(new_shard) + else: + assert isinstance(new_v.to_local(), torch.Tensor) + v.detach().copy_(new_v) elif isinstance(v, torch.Tensor): v.detach().copy_(new_v) else: diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 7368a04b5..9c9ed2faf 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -10,7 +10,7 @@ import copy import itertools from collections import defaultdict -from typing import Callable, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union import torch import torch.nn as nn @@ -251,7 +251,7 @@ def _get_device(module: nn.Module) -> torch.device: def _update_embedding_configs( - embedding_configs: List[BaseEmbeddingConfig], + embedding_configs: Sequence[BaseEmbeddingConfig], quant_config: Union[QuantConfig, torch.quantization.QConfig], tables_to_rows_post_pruning: Optional[Dict[str, int]] = None, ) -> None:
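
Illustrative usage (a minimal sketch, not part of this change): exercising the new metric directly, assuming the standard RecMetric constructor (world_size/my_rank/batch_size/tasks, with extra kwargs such as target_precision forwarded to the computation class). The task name "t1" and tensor shapes are placeholders.

    import torch
    from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric
    from torchrec.metrics.metrics_config import RecTaskInfo

    task = RecTaskInfo(name="t1")
    metric = HindsightTargetPRMetric(
        world_size=1,
        my_rank=0,
        batch_size=64,
        tasks=[task],
        target_precision=0.8,  # forwarded to HindsightTargetPRMetricComputation
    )
    # Unfused computation takes per-task dicts keyed by task name.
    metric.update(
        predictions={"t1": torch.rand(64)},
        labels={"t1": torch.randint(0, 2, (64,)).double()},
        weights={"t1": torch.ones(64)},
    )
    # compute() reports lifetime/window threshold index, precision, and recall.
    print(metric.compute())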