Skip to content

Commit

Permalink
2024-10-25 nightly release (43a20d0)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Oct 25, 2024
1 parent 047654d commit 0ec19ba
Show file tree
Hide file tree
Showing 28 changed files with 88 additions and 12 deletions.
2 changes: 2 additions & 0 deletions torchrec/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -243,6 +244,7 @@ def _init_states(self) -> None:
if self._grouped_auc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -235,6 +236,7 @@ def _init_states(self) -> None:
if self._grouped_auprc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

CALIBRATION_NUM = "calibration_num"
CALIBRATION_DENOM = "calibration_denom"
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

CTR_NUM = "ctr_num"
CTR_DENOM = "ctr_denom"
Expand Down Expand Up @@ -61,6 +63,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


ERROR_SUM = "error_sum"
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
# pyre-fixme[14]: `update` overrides method defined in `RecMetricComputation`
# inconsistently.
def update(
Expand Down
4 changes: 3 additions & 1 deletion torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def _generate_rec_metrics(
if metric_def and metric_def.arguments is not None:
kwargs = metric_def.arguments

kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile

rec_tasks: List[RecTaskInfo] = []
if metric_def.rec_tasks and metric_def.rec_task_indices:
raise ValueError(
Expand Down Expand Up @@ -468,7 +470,7 @@ def generate_metric_module(
metrics_config, world_size, my_rank, batch_size, process_group
)
"""
Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so
Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so
different training jobs have aligned mertics.
TODO: update metrics other than ThroughputMetric if it has dependency on batch_size
"""
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class MetricsConfig:
should_validate_update (bool): whether to check the inputs of update() and skip
update if the inputs are invalid. Invalid inputs include the case where all
examples have 0 weights for a batch.
enable_pt2_compile (bool): whether to enable PT2 compilation for metrics.
"""

rec_tasks: List[RecTaskInfo] = field(default_factory=list)
Expand All @@ -171,6 +172,7 @@ class MetricsConfig:
max_compute_interval: float = float("inf")
compute_on_all_ranks: bool = False
should_validate_update: bool = False
enable_pt2_compile: bool = False


DefaultTaskInfo = RecTaskInfo(
Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


ERROR_SUM = "error_sum"
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/multiclass_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_true_positives_at_k(
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

SUM_NDCG = "sum_ndcg"
NUM_SESSIONS = "num_sessions"
Expand Down Expand Up @@ -331,6 +332,7 @@ def __init__(
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_cross_entropy(
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ne_positive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_cross_entropy_positive(
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricException,
RecTaskInfo,
)
from torchrec.pt2.utils import pt2_compile_callable


class OutputMetricComputation(RecMetricComputation):
Expand All @@ -46,6 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=False,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/rauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -287,6 +288,7 @@ def _init_states(self) -> None:
if self._grouped_rauc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
15 changes: 4 additions & 11 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import abc
import itertools
import logging
import math
from collections import defaultdict, deque
from dataclasses import dataclass
Expand Down Expand Up @@ -47,8 +46,8 @@
MetricNamespaceBase,
MetricPrefix,
)
from torchrec.pt2.utils import pt2_compile_callable

logger: logging.Logger = logging.getLogger(__name__)

RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]]

Expand Down Expand Up @@ -138,6 +137,7 @@ def __init__(
process_group: Optional[dist.ProcessGroup] = None,
fused_update_limit: int = 0,
allow_missing_label_with_zero_weight: bool = False,
enable_pt2_compile: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -161,6 +161,7 @@ def __init__(
dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(),
persistent=True,
)
self.enable_pt2_compile = enable_pt2_compile

@staticmethod
def get_window_state_name(state_name: str) -> str:
Expand Down Expand Up @@ -246,6 +247,7 @@ def pre_compute(self) -> None:
"""
return

@pt2_compile_callable
def compute(self) -> List[MetricComputationReport]:
with record_function(f"## {self.__class__.__name__}:compute ##"):
if self._my_rank == 0 or self._compute_on_all_ranks:
Expand Down Expand Up @@ -525,24 +527,15 @@ def _update(
task_names = [task.name for task in self._tasks]

if not isinstance(predictions, torch.Tensor):
logger.info(
"Converting predictions to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
predictions = torch.stack(
[predictions[task_name] for task_name in task_names]
)

if not isinstance(labels, torch.Tensor):
logger.info(
"Converting labels to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
labels = torch.stack(
[labels[task_name] for task_name in task_names]
)
if weights is not None and not isinstance(weights, torch.Tensor):
logger.info(
"Converting weights to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
weights = torch.stack(
[weights[task_name] for task_name in task_names]
)
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/recall_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
self.run_ranking_of_labels: bool = session_metric_def.run_ranking_of_labels
self.session_var_name: Optional[str] = session_metric_def.session_var_name

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetric,
RecMetricComputation,
)
from torchrec.pt2.utils import pt2_compile_callable


class ScalarMetricComputation(RecMetricComputation):
Expand All @@ -41,6 +42,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=False,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
Loading

0 comments on commit 0ec19ba

Please sign in to comment.