Refactor itep logger
Summary: Put the interface in the torchrec module and the Scuba implementation in fbgemm, because it appears every GenericITEPModule needs a logger, but only the references to GenericITEPModule in the fbgemm module use Scuba.

Differential Revision: D69308574
peterfu0 authored and facebook-github-bot committed Feb 10, 2025
1 parent 1afbf08 commit a05d412
Showing 3 changed files with 104 additions and 2 deletions.
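
As an aside (not part of this commit), a minimal sketch of what a downstream implementation of the new interface might look like; the class name and print-based backend are hypothetical stand-ins for something like the Scuba logger mentioned in the summary:

# Hypothetical implementation of the ITEPLogger interface added below in
# torchrec/modules/itep_logger.py; the backend here is a stand-in.
from typing import Mapping, Optional, Tuple, Union

from torchrec.modules.itep_logger import ITEPLogger


class MyBackendITEPLogger(ITEPLogger):
    """Illustrative logger that forwards ITEP stats to a custom backend."""

    def log_run_info(self) -> None:
        # Called once when GenericITEPModule is constructed.
        print("ITEP run started")

    def log_table_eviction_info(
        self,
        iteration: Optional[Union[bool, float, int]],
        rank: Optional[int],
        table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
        eviction_tables: Mapping[str, float],
    ) -> None:
        # One record per table; replace print with the backend of choice.
        for table_name, ratio in eviction_tables.items():
            print(f"table={table_name} eviction_ratio={ratio} rank={rank}")
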
1 change: 1 addition & 0 deletions torchrec/distributed/itep_embeddingbag.py
@@ -81,6 +81,7 @@ def __init__(
lookups=self._embedding_bag_collection._lookups,
pruning_interval=module._itep_module.pruning_interval,
enable_pruning=module._itep_module.enable_pruning,
itep_logger=module._itep_module.itep_logger,
)

def prefetch(
60 changes: 60 additions & 0 deletions torchrec/modules/itep_logger.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import ABC, abstractmethod
from typing import Mapping, Optional, Tuple, Union

logger: logging.Logger = logging.getLogger(__name__)


class ITEPLogger(ABC):
@abstractmethod
def log_table_eviction_info(
self,
iteration: Optional[Union[bool, float, int]],
rank: Optional[int],
table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
eviction_tables: Mapping[str, float],
) -> None:
pass

@abstractmethod
def log_run_info(
self,
) -> None:
pass


class ITEPLoggerDefault(ITEPLogger):
"""
No-op logger used as the default.
"""

def __init__(
self,
) -> None:
"""
Initialize ITEPLoggerDefault.
"""
pass

def log_table_eviction_info(
self,
iteration: Optional[Union[bool, float, int]],
rank: Optional[int],
table_to_sizes_mapping: Mapping[str, Tuple[int, int]],
eviction_tables: Mapping[str, float],
) -> None:
logger.info(
f"iteration={iteration}, rank={rank}, table_to_sizes_mapping={table_to_sizes_mapping}, eviction_tables={eviction_tables}"
)

def log_run_info(
self,
) -> None:
pass
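
A hedged wiring sketch: GenericITEPModule (modified below) takes the logger through its new itep_logger keyword argument and falls back to the no-op ITEPLoggerDefault when none is supplied. The argument values here are placeholders, not a complete or authoritative constructor call:

# Hypothetical usage; only itep_logger is the argument added by this commit,
# and the other values are illustrative placeholders.
from torchrec.modules.itep_modules import GenericITEPModule

itep_module = GenericITEPModule(
    table_name_to_unpruned_hash_sizes={"table_0": 1_000_000},  # illustrative
    lookups=None,
    enable_pruning=True,
    pruning_interval=1001,  # default shown in the diff
    itep_logger=MyBackendITEPLogger(),  # from the sketch above; omit for ITEPLoggerDefault
)
# log_run_info() is invoked once during construction, whichever logger is used.
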
45 changes: 43 additions & 2 deletions torchrec/modules/itep_modules.py
@@ -15,6 +15,8 @@
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_types import ShardedEmbeddingTable
from torchrec.modules.embedding_modules import reorder_inverse_indices
from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault

from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

try:
@@ -63,8 +65,8 @@ def __init__(
lookups: Optional[List[nn.Module]] = None,
enable_pruning: bool = True,
pruning_interval: int = 1001, # Default pruning interval 1001 iterations
itep_logger: Optional[ITEPLogger] = None,
) -> None:

super(GenericITEPModule, self).__init__()

# Construct in-training embedding pruning args
@@ -75,6 +77,11 @@
table_name_to_unpruned_hash_sizes
)

self.itep_logger: ITEPLogger = (
itep_logger if itep_logger is not None else ITEPLoggerDefault()
)
self.itep_logger.log_run_info()

# Map each feature to a physical address_lookup/row_util buffer
self.feature_table_map: Dict[str, int] = {}
self.table_name_to_idx: Dict[str, int] = {}
@@ -97,6 +104,8 @@ def print_itep_eviction_stats(
cur_iter: int,
) -> None:
table_name_to_eviction_ratio = {}
buffer_idx_to_eviction_ratio = {}
buffer_idx_to_sizes = {}

num_buffers = len(self.buffer_offsets_list) - 1
for buffer_idx in range(num_buffers):
@@ -113,6 +122,8 @@
table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
eviction_ratio
)
buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)

# Sort the mapping by eviction ratio in descending order
sorted_mapping = dict(
@@ -122,6 +133,34 @@
reverse=True,
)
)

logged_eviction_mapping = {}
for idx in sorted_mapping.keys():
try:
logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
sorted_mapping[idx]
)
except KeyError:
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
pass

table_to_sizes_mapping = {}
for idx in buffer_idx_to_sizes.keys():
try:
table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
buffer_idx_to_sizes[idx]
)
except KeyError:
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
pass

self.itep_logger.log_table_eviction_info(
iteration=None,
rank=None,
table_to_sizes_mapping=table_to_sizes_mapping,
eviction_tables=logged_eviction_mapping,
)

# Print the sorted mapping
logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")

@@ -263,8 +302,10 @@ def init_itep_state(self) -> None:
if self.current_device is None:
self.current_device = torch.device("cuda")

self.reversed_feature_table_map: Dict[int, str] = {
idx: feature_name for feature_name, idx in self.feature_table_map.items()
}
self.buffer_offsets_list = buffer_offsets

# Create buffers for address_lookup and row_util
self.create_itep_buffers(
buffer_size=buffer_size,
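
For orientation only, a sketch (with made-up numbers) of the payload that print_itep_eviction_stats above hands to log_table_eviction_info once buffer indices have been mapped back to table names through reversed_feature_table_map:

# Illustrative payload shapes; values are invented, not taken from a real run.
table_to_sizes_mapping = {
    "table_0": (812_000, 1_000_000),  # (pruned_length, buffer_length)
    "table_1": (455_000, 500_000),
}
eviction_tables = {
    "table_0": 0.19,  # eviction ratio, sorted in descending order
    "table_1": 0.09,
}
itep_logger = MyBackendITEPLogger()  # hypothetical logger from the first sketch
itep_logger.log_table_eviction_info(
    iteration=None,  # the commit currently passes None for iteration and rank
    rank=None,
    table_to_sizes_mapping=table_to_sizes_mapping,
    eviction_tables=eviction_tables,
)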
