From 0ec19ba48ee681ed51ab5b1962adb3a18b77d9d1 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Fri, 25 Oct 2024 11:35:01 +0000 Subject: [PATCH] 2024-10-25 nightly release (43a20d0e564a5d032d5ec0b653b0311b5e828e96) --- torchrec/metrics/accuracy.py | 2 ++ torchrec/metrics/auc.py | 2 ++ torchrec/metrics/auprc.py | 2 ++ torchrec/metrics/calibration.py | 2 ++ torchrec/metrics/ctr.py | 3 +++ torchrec/metrics/mae.py | 2 ++ torchrec/metrics/metric_module.py | 4 +++- torchrec/metrics/metrics_config.py | 2 ++ torchrec/metrics/mse.py | 3 +++ torchrec/metrics/multiclass_recall.py | 2 ++ torchrec/metrics/ndcg.py | 2 ++ torchrec/metrics/ne.py | 2 ++ torchrec/metrics/ne_positive.py | 2 ++ torchrec/metrics/output.py | 2 ++ torchrec/metrics/precision.py | 2 ++ torchrec/metrics/rauc.py | 2 ++ torchrec/metrics/rec_metric.py | 15 ++++--------- torchrec/metrics/recall.py | 2 ++ torchrec/metrics/recall_session.py | 2 ++ torchrec/metrics/scalar.py | 2 ++ torchrec/metrics/segmented_ne.py | 2 ++ torchrec/metrics/serving_calibration.py | 2 ++ torchrec/metrics/serving_ne.py | 2 ++ torchrec/metrics/tensor_weighted_avg.py | 2 ++ torchrec/metrics/tower_qps.py | 2 ++ torchrec/metrics/weighted_avg.py | 2 ++ torchrec/metrics/xauc.py | 2 ++ torchrec/pt2/utils.py | 29 +++++++++++++++++++++++++ 28 files changed, 88 insertions(+), 12 deletions(-) diff --git a/torchrec/metrics/accuracy.py b/torchrec/metrics/accuracy.py index 95537aab3..7d09f1736 100644 --- a/torchrec/metrics/accuracy.py +++ b/torchrec/metrics/accuracy.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable THRESHOLD = "threshold" @@ -84,6 +85,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: ) self._threshold: float = threshold + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py index 3ce5d054b..7cbb45b7e 100644 --- a/torchrec/metrics/auc.py +++ b/torchrec/metrics/auc.py @@ -21,6 +21,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable PREDICTIONS = "predictions" LABELS = "labels" @@ -243,6 +244,7 @@ def _init_states(self) -> None: if self._grouped_auc: getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/auprc.py b/torchrec/metrics/auprc.py index 89ab5f799..309b1b4b3 100644 --- a/torchrec/metrics/auprc.py +++ b/torchrec/metrics/auprc.py @@ -21,6 +21,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable PREDICTIONS = "predictions" LABELS = "labels" @@ -235,6 +236,7 @@ def _init_states(self) -> None: if self._grouped_auprc: getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/calibration.py b/torchrec/metrics/calibration.py index fc6bb91d9..d7a7e2b7a 100644 --- a/torchrec/metrics/calibration.py +++ b/torchrec/metrics/calibration.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable CALIBRATION_NUM = "calibration_num" CALIBRATION_DENOM = "calibration_denom" @@ -65,6 +66,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/ctr.py b/torchrec/metrics/ctr.py index c3ac30758..f4e1568c9 100644 --- a/torchrec/metrics/ctr.py +++ b/torchrec/metrics/ctr.py @@ -10,6 +10,7 @@ from typing import Any, cast, Dict, List, Optional, Type import torch + from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( MetricComputationReport, @@ -17,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable CTR_NUM = "ctr_num" CTR_DENOM = "ctr_denom" @@ -61,6 +63,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/mae.py b/torchrec/metrics/mae.py index 9b0439d45..89fc74ff7 100644 --- a/torchrec/metrics/mae.py +++ b/torchrec/metrics/mae.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable ERROR_SUM = "error_sum" @@ -72,6 +73,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable # pyre-fixme[14]: `update` overrides method defined in `RecMetricComputation` # inconsistently. def update( diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index f757d4ad7..d04a4aa5e 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -393,6 +393,8 @@ def _generate_rec_metrics( if metric_def and metric_def.arguments is not None: kwargs = metric_def.arguments + kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile + rec_tasks: List[RecTaskInfo] = [] if metric_def.rec_tasks and metric_def.rec_task_indices: raise ValueError( @@ -468,7 +470,7 @@ def generate_metric_module( metrics_config, world_size, my_rank, batch_size, process_group ) """ - Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so + Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so different training jobs have aligned mertics. TODO: update metrics other than ThroughputMetric if it has dependency on batch_size """ diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index 6875f2907..7ff5af552 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -158,6 +158,7 @@ class MetricsConfig: should_validate_update (bool): whether to check the inputs of update() and skip update if the inputs are invalid. Invalid inputs include the case where all examples have 0 weights for a batch. + enable_pt2_compile (bool): whether to enable PT2 compilation for metrics. """ rec_tasks: List[RecTaskInfo] = field(default_factory=list) @@ -171,6 +172,7 @@ class MetricsConfig: max_compute_interval: float = float("inf") compute_on_all_ranks: bool = False should_validate_update: bool = False + enable_pt2_compile: bool = False DefaultTaskInfo = RecTaskInfo( diff --git a/torchrec/metrics/mse.py b/torchrec/metrics/mse.py index 4941c5e40..86a92beeb 100644 --- a/torchrec/metrics/mse.py +++ b/torchrec/metrics/mse.py @@ -10,6 +10,7 @@ from typing import Any, cast, Dict, List, Optional, Type import torch + from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix from torchrec.metrics.rec_metric import ( MetricComputationReport, @@ -17,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable ERROR_SUM = "error_sum" @@ -80,6 +82,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/multiclass_recall.py b/torchrec/metrics/multiclass_recall.py index 97fe71c57..b91c83d98 100644 --- a/torchrec/metrics/multiclass_recall.py +++ b/torchrec/metrics/multiclass_recall.py @@ -18,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable def compute_true_positives_at_k( @@ -109,6 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/ndcg.py b/torchrec/metrics/ndcg.py index dca38b0d9..84174dfa0 100644 --- a/torchrec/metrics/ndcg.py +++ b/torchrec/metrics/ndcg.py @@ -19,6 +19,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable SUM_NDCG = "sum_ndcg" NUM_SESSIONS = "num_sessions" @@ -331,6 +332,7 @@ def __init__( persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/ne.py b/torchrec/metrics/ne.py index 41f14a92e..9dc58f80b 100644 --- a/torchrec/metrics/ne.py +++ b/torchrec/metrics/ne.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable def compute_cross_entropy( @@ -148,6 +149,7 @@ def __init__( ) self.eta = 1e-12 + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/ne_positive.py b/torchrec/metrics/ne_positive.py index 2d2147f3d..1ec2e7a1c 100644 --- a/torchrec/metrics/ne_positive.py +++ b/torchrec/metrics/ne_positive.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable def compute_cross_entropy_positive( @@ -130,6 +131,7 @@ def __init__( ) self.eta = 1e-12 + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/output.py b/torchrec/metrics/output.py index 5f9d0de54..8b3a636dd 100644 --- a/torchrec/metrics/output.py +++ b/torchrec/metrics/output.py @@ -21,6 +21,7 @@ RecMetricException, RecTaskInfo, ) +from torchrec.pt2.utils import pt2_compile_callable class OutputMetricComputation(RecMetricComputation): @@ -46,6 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=False, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/precision.py b/torchrec/metrics/precision.py index c069077bb..3104f6039 100644 --- a/torchrec/metrics/precision.py +++ b/torchrec/metrics/precision.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable THRESHOLD = "threshold" @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: ) self._threshold: float = threshold + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/rauc.py b/torchrec/metrics/rauc.py index 03d711657..7a62bc0a1 100644 --- a/torchrec/metrics/rauc.py +++ b/torchrec/metrics/rauc.py @@ -21,6 +21,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable PREDICTIONS = "predictions" LABELS = "labels" @@ -287,6 +288,7 @@ def _init_states(self) -> None: if self._grouped_rauc: getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device)) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index 67df6c29f..4d0e7dc7b 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -11,7 +11,6 @@ import abc import itertools -import logging import math from collections import defaultdict, deque from dataclasses import dataclass @@ -47,8 +46,8 @@ MetricNamespaceBase, MetricPrefix, ) +from torchrec.pt2.utils import pt2_compile_callable -logger: logging.Logger = logging.getLogger(__name__) RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]] @@ -138,6 +137,7 @@ def __init__( process_group: Optional[dist.ProcessGroup] = None, fused_update_limit: int = 0, allow_missing_label_with_zero_weight: bool = False, + enable_pt2_compile: bool = False, *args: Any, **kwargs: Any, ) -> None: @@ -161,6 +161,7 @@ def __init__( dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(), persistent=True, ) + self.enable_pt2_compile = enable_pt2_compile @staticmethod def get_window_state_name(state_name: str) -> str: @@ -246,6 +247,7 @@ def pre_compute(self) -> None: """ return + @pt2_compile_callable def compute(self) -> List[MetricComputationReport]: with record_function(f"## {self.__class__.__name__}:compute ##"): if self._my_rank == 0 or self._compute_on_all_ranks: @@ -525,24 +527,15 @@ def _update( task_names = [task.name for task in self._tasks] if not isinstance(predictions, torch.Tensor): - logger.info( - "Converting predictions to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" - ) predictions = torch.stack( [predictions[task_name] for task_name in task_names] ) if not isinstance(labels, torch.Tensor): - logger.info( - "Converting labels to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" - ) labels = torch.stack( [labels[task_name] for task_name in task_names] ) if weights is not None and not isinstance(weights, torch.Tensor): - logger.info( - "Converting weights to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" - ) weights = torch.stack( [weights[task_name] for task_name in task_names] ) diff --git a/torchrec/metrics/recall.py b/torchrec/metrics/recall.py index 5031045c7..a52b517ce 100644 --- a/torchrec/metrics/recall.py +++ b/torchrec/metrics/recall.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable THRESHOLD = "threshold" @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None: ) self._threshold: float = threshold + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/recall_session.py b/torchrec/metrics/recall_session.py index 29cd4085d..282813fa7 100644 --- a/torchrec/metrics/recall_session.py +++ b/torchrec/metrics/recall_session.py @@ -21,6 +21,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable logger: logging.Logger = logging.getLogger(__name__) @@ -128,6 +129,7 @@ def __init__( self.run_ranking_of_labels: bool = session_metric_def.run_ranking_of_labels self.session_var_name: Optional[str] = session_metric_def.session_var_name + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/scalar.py b/torchrec/metrics/scalar.py index 6f8759b88..cd51bff06 100644 --- a/torchrec/metrics/scalar.py +++ b/torchrec/metrics/scalar.py @@ -17,6 +17,7 @@ RecMetric, RecMetricComputation, ) +from torchrec.pt2.utils import pt2_compile_callable class ScalarMetricComputation(RecMetricComputation): @@ -41,6 +42,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=False, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/segmented_ne.py b/torchrec/metrics/segmented_ne.py index f75565665..b5aad2994 100644 --- a/torchrec/metrics/segmented_ne.py +++ b/torchrec/metrics/segmented_ne.py @@ -19,6 +19,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable PREDICTIONS = "predictions" LABELS = "labels" @@ -206,6 +207,7 @@ def __init__( ) self.eta = 1e-12 + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/serving_calibration.py b/torchrec/metrics/serving_calibration.py index 5b5cf1fdc..dc9277b78 100644 --- a/torchrec/metrics/serving_calibration.py +++ b/torchrec/metrics/serving_calibration.py @@ -18,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable CALIBRATION_NUM = "calibration_num" CALIBRATION_DENOM = "calibration_denom" @@ -57,6 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/serving_ne.py b/torchrec/metrics/serving_ne.py index 37b868828..38a7f902f 100644 --- a/torchrec/metrics/serving_ne.py +++ b/torchrec/metrics/serving_ne.py @@ -18,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable NUM_EXAMPLES = "num_examples" @@ -98,6 +99,7 @@ def _get_bucket_metric_states( eta=self.eta, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/tensor_weighted_avg.py b/torchrec/metrics/tensor_weighted_avg.py index e979bbff3..35144e9e2 100644 --- a/torchrec/metrics/tensor_weighted_avg.py +++ b/torchrec/metrics/tensor_weighted_avg.py @@ -18,6 +18,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor: @@ -54,6 +55,7 @@ def __init__( persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/tower_qps.py b/torchrec/metrics/tower_qps.py index 8e72824c6..853657a72 100644 --- a/torchrec/metrics/tower_qps.py +++ b/torchrec/metrics/tower_qps.py @@ -22,6 +22,7 @@ RecMetricException, RecModelOutput, ) +from torchrec.pt2.utils import pt2_compile_callable WARMUP_STEPS = 100 @@ -78,6 +79,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._previous_ts = 0 self._steps = 0 + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/weighted_avg.py b/torchrec/metrics/weighted_avg.py index 4d8466a3a..295ede81f 100644 --- a/torchrec/metrics/weighted_avg.py +++ b/torchrec/metrics/weighted_avg.py @@ -16,6 +16,7 @@ RecMetric, RecMetricComputation, ) +from torchrec.pt2.utils import pt2_compile_callable def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor: @@ -40,6 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable def update( self, *, diff --git a/torchrec/metrics/xauc.py b/torchrec/metrics/xauc.py index 31c747ced..fcf91c5ec 100644 --- a/torchrec/metrics/xauc.py +++ b/torchrec/metrics/xauc.py @@ -17,6 +17,7 @@ RecMetricComputation, RecMetricException, ) +from torchrec.pt2.utils import pt2_compile_callable ERROR_SUM = "error_sum" @@ -101,6 +102,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: persistent=True, ) + @pt2_compile_callable # pyre-fixme[14]: `update` overrides method defined in `RecMetricComputation` # inconsistently. def update( diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index 4f62b0998..e62a9a6a4 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -8,6 +8,9 @@ # pyre-strict +import functools +from typing import Any, Callable + import torch from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -151,3 +154,29 @@ def size(self): def deregister_fake_classes() -> None: torch._library.fake_class_registry.deregister_fake_class("fbgemm::AtomicCounter") torch._library.fake_class_registry.deregister_fake_class("fbgemm::TensorQueue") + + +# pyre-ignore[24] +def pt2_compile_callable(f: Callable) -> Callable: + """ + This method is used to decorate the update and compute methods of a metric computation class. + If the metric computation class has enable_pt2_compile attribute set to True, + then the update and compute methods will be compiled using torch.compile. + """ + + @functools.wraps(f) + # pyre-ignore[3] + def inner_forward( + ref: torch.nn.Module, + *args: Any, + **kwargs: Any, + ) -> Any: + if hasattr(ref, "enable_pt2_compile") and ref.enable_pt2_compile: + pt2_compiled_attr_name = f"_{f.__name__}_pt2_compiled" + if not hasattr(ref, pt2_compiled_attr_name): + setattr(ref, pt2_compiled_attr_name, torch.compile(f)) + return getattr(ref, pt2_compiled_attr_name)(ref, *args, **kwargs) + + return f(ref, *args, **kwargs) + + return inner_forward