2024-10-26 nightly release (f606d5c)
pytorchbot committed Oct 26, 2024
1 parent 0ec19ba commit 116d0df
Showing 21 changed files with 104 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pyre_configuration
@@ -13,5 +13,5 @@
}
],
"strict": true,
"version": "0.0.101703592829"
"version": "0.0.101729681899"
}
2 changes: 1 addition & 1 deletion torchrec/datasets/criteo.py
@@ -252,7 +252,7 @@ def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
# using int64. Numpy will automatically handle dense values >= 2 ** 31.
dense_np = np.array(dense, dtype=np.int32)
del dense
sparse_np = np.array(sparse, dtype=np.int32)
sparse_np = np.array(sparse, dtype=np.int64)
del sparse
labels_np = np.array(labels, dtype=np.int32)
del labels
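
The substantive change in this hunk is the sparse dtype widening from np.int32 to np.int64. As a rough illustration (not part of the diff, with made-up token values), Criteo-style categorical features are parsed from hex strings and can exceed the signed 32-bit range, so a 64-bit array is needed to hold them exactly:

import numpy as np

# Hypothetical categorical tokens; real Criteo sparse features are 8-digit hex strings.
raw_hex = ["ffffffff", "0a1b2c3d"]
sparse = [int(v, 16) for v in raw_hex]        # 4294967295 > 2**31 - 1, outside int32

sparse_np = np.array(sparse, dtype=np.int64)  # int64 preserves the full value range
print(sparse_np)                              # [4294967295  169552957]
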
2 changes: 1 addition & 1 deletion torchrec/datasets/tests/test_criteo.py
@@ -144,7 +144,7 @@ def test_tsv_to_npys(self) -> None:
self.assertEqual(dense.shape, (num_rows, INT_FEATURE_COUNT))
self.assertEqual(dense.dtype, np.float32)
self.assertEqual(sparse.shape, (num_rows, CAT_FEATURE_COUNT))
self.assertEqual(sparse.dtype, np.int32)
self.assertEqual(sparse.dtype, np.int64)
self.assertEqual(labels.shape, (num_rows, 1))
self.assertEqual(labels.dtype, np.int32)

22 changes: 22 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -674,13 +674,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
@@ -1080,13 +1091,24 @@ def __init__(
self.table_name_to_count: Dict[str, int] = {}
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
# `ShardedEmbeddingTable`.
for idx, config in enumerate(self._config.embedding_tables):
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
self._local_rows.append(config.local_rows)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_min`.
self._weight_init_mins.append(config.get_weight_init_min())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `get_weight_init_max`.
self._weight_init_maxs.append(config.get_weight_init_max())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
# `num_embeddings`.
self._num_embeddings.append(config.num_embeddings)
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
self._local_cols.append(config.local_cols)
self._feature_table_map.extend([idx] * config.num_features())
# pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
if config.name not in self.table_name_to_count:
self.table_name_to_count[config.name] = 0
self.table_name_to_count[config.name] += 1
3 changes: 2 additions & 1 deletion torchrec/distributed/benchmark/benchmark_utils.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[16]

#!/usr/bin/env python3

@@ -431,6 +432,7 @@ def transform_module(
compile_mode: CompileMode,
world_size: int,
batch_size: int,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
ctx: ContextManager,
benchmark_unsharded_module: bool = False,
) -> torch.nn.Module:
@@ -1051,7 +1053,6 @@ def benchmark_module(
for compile_mode in compile_modes:
if not benchmark_unsharded:
# Test sharders should have a singular sharding_type
# pyre-ignore [16]
sharder._sharding_type = sharding_type.value
# pyre-ignore [6]
benchmark_type = benchmark_type_name(compile_mode, sharding_type)
4 changes: 2 additions & 2 deletions torchrec/distributed/keyed_jagged_tensor_pool.py
@@ -667,9 +667,9 @@ def _update_local(
) -> None:
raise NotImplementedError("Inference does not support update")

# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value of
# `None`.
def _update_preproc(self, values: KeyedJaggedTensor) -> KeyedJaggedTensor:
# pyre-fixme[7]: Expected `KeyedJaggedTensor` but got implicit return value
# of `None`.
pass


10 changes: 5 additions & 5 deletions torchrec/distributed/object_pool.py
@@ -131,16 +131,16 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit return
# value of `None`.
) -> Awaitable[Awaitable[torch.Tensor]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[Tensor]]` but got implicit
# return value of `None`.
pass

# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
# pyre-fixme[7]: Expected `DistOut` but got implicit return value of `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Out]` but got implicit return value of
# `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Variable[Out]]` but got implicit
# return value of `None`.
pass
4 changes: 4 additions & 0 deletions torchrec/distributed/planner/tests/test_partitioners.py
@@ -773,6 +773,8 @@ def test_different_sharding_plan(self) -> None:
for shard in sharding_option.shards:
if shard.storage and shard.rank is not None:
greedy_perf_hbm_uses[
# pyre-fixme[6]: For 1st argument expected `SupportsIndex`
# but got `Optional[int]`.
shard.rank
] += shard.storage.hbm # pyre-ignore[16]

@@ -796,6 +798,8 @@ def test_different_sharding_plan(self) -> None:
for sharding_option in sharding_options:
for shard in sharding_option.shards:
if shard.storage and shard.rank:
# pyre-fixme[6]: For 1st argument expected `SupportsIndex` but
# got `Optional[int]`.
memory_balanced_hbm_uses[shard.rank] += shard.storage.hbm

self.assertTrue(max(memory_balanced_hbm_uses) < max(greedy_perf_hbm_uses))
4 changes: 2 additions & 2 deletions torchrec/distributed/shards_wrapper.py
@@ -27,8 +27,6 @@
aten = torch.ops.aten # pyre-ignore[5]


# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
class LocalShardsWrapper(torch.Tensor):
"""
A wrapper class to hold local shards of a DTensor.
@@ -37,7 +35,9 @@ class LocalShardsWrapper(torch.Tensor):
"""

__slots__ = ["_local_shards", "_storage_meta"]
# pyre-fixme[13]: Attribute `_local_shards` is never initialized.
_local_shards: List[torch.Tensor]
# pyre-fixme[13]: Attribute `_storage_meta` is never initialized.
_storage_meta: TensorStorageMetadata

@staticmethod
2 changes: 1 addition & 1 deletion torchrec/distributed/tensor_pool.py
@@ -459,8 +459,8 @@ def _update_local(
deduped_ids, dedup_permutation = deterministic_dedup(ids)
shard.update(deduped_ids, values[dedup_permutation])

# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
def _update_preproc(self, values: torch.Tensor) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
pass

def update(self, ids: torch.Tensor, values: torch.Tensor) -> None:
4 changes: 4 additions & 0 deletions torchrec/distributed/tests/test_awaitable.py
@@ -24,13 +24,17 @@ def _wait_impl(self) -> torch.Tensor:
class AwaitableTests(unittest.TestCase):
def test_callback(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
self.assertTrue(
torch.allclose(awaitable.wait(), torch.FloatTensor([2.0, 4.0, 6.0]))
)

def test_callback_chained(self) -> None:
awaitable = AwaitableInstance()
# pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got
# `(ret: Any) -> int`.
awaitable.callbacks.append(lambda ret: 2 * ret)
awaitable.callbacks.append(lambda ret: ret**2)
self.assertTrue(
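
The pyre-fixme[6] notes added here only flag that the lambdas are typed as (ret: Any) -> int rather than (Tensor) -> Tensor; the callback mechanism under test is unchanged. A minimal standalone sketch of the pattern the chained test exercises, assuming (as the first test implies) a base value of [1.0, 2.0, 3.0]; the class below is illustrative, not torchrec's Awaitable:

from typing import Callable, List

import torch


class ToyAwaitable:
    """Toy stand-in that applies registered callbacks to the awaited result."""

    def __init__(self, value: torch.Tensor) -> None:
        self._value = value
        self.callbacks: List[Callable[[torch.Tensor], torch.Tensor]] = []

    def wait(self) -> torch.Tensor:
        result = self._value
        for callback in self.callbacks:
            result = callback(result)
        return result


aw = ToyAwaitable(torch.FloatTensor([1.0, 2.0, 3.0]))
aw.callbacks.append(lambda ret: 2 * ret)   # doubles: [2, 4, 6]
aw.callbacks.append(lambda ret: ret**2)    # squares: [4, 16, 36]
print(aw.wait())                           # tensor([ 4., 16., 36.])
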
13 changes: 6 additions & 7 deletions torchrec/distributed/tests/test_embedding_types.py
@@ -29,9 +29,9 @@ def __init__(self) -> None:
torch.nn.Module(),
]

# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
def create_context(self) -> ShrdCtx:
# pyre-fixme[7]: Expected `EmbeddingBagCollectionContext` but got implicit
# return value of `None`.
pass

def input_dist(
@@ -41,19 +41,18 @@ def input_dist(
*input,
# pyre-ignore[2]
**kwargs,
) -> Awaitable[Awaitable[CompIn]]:
# pyre-fixme[7]: Expected `Awaitable[Awaitable[KJTList]]` but got implicit
# return value of `None`.
) -> Awaitable[Awaitable[CompIn]]:
pass

# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of `None`.
def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut:
# pyre-fixme[7]: Expected `List[Tensor]` but got implicit return value of
# `None`.
pass

# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got implicit
# return value of `None`.
def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
# implicit return value of `None`.
pass


2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_lazy_awaitable.py
@@ -244,8 +244,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

tempFile = None
with tempfile.NamedTemporaryFile(delete=False) as f:
# pyre-fixme[6]: For 2nd argument expected `SupportsWrite[bytes]` but
# got `_TemporaryFileWrapper[bytes]`.
pickle.dump(gm, f)
tempFile = f

1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -1613,6 +1613,7 @@ def __init__(

def get_compiled_autograd_ctx(
self,
# pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter.
) -> ContextManager:
# this allows for pipelining
# to avoid doing a sum on None
4 changes: 3 additions & 1 deletion torchrec/distributed/types.py
@@ -44,7 +44,6 @@
# other metaclasses (i.e. AwaitableMeta) for customized
# behaviors, as Generic is non-trival metaclass in
# python 3.6 and below
# pyre-fixme[21]: Could not find name `GenericMeta` in `typing` (stubbed).
from typing import GenericMeta
except ImportError:
# In python 3.7+, GenericMeta doesn't exist as it's no
@@ -975,6 +974,9 @@ def __init__(
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
self._qcomm_codecs_registry = qcomm_codecs_registry

# pyre-fixme[56]: Pyre doesn't yet support decorators with ParamSpec applied to
# generic functions. Consider using a context manager instead of a decorator, if
# possible.
@abc.abstractclassmethod
# pyre-ignore [3]
def shard(
6 changes: 5 additions & 1 deletion torchrec/distributed/utils.py
@@ -476,7 +476,11 @@ def maybe_reset_parameters(m: nn.Module) -> None:


def maybe_annotate_embedding_event(
event: EmbeddingEvent, module_fqn: Optional[str], sharding_type: Optional[str]
event: EmbeddingEvent,
module_fqn: Optional[str],
sharding_type: Optional[str],
# pyre-fixme[24]: Generic type `AbstractContextManager` expects 2 type parameters,
# received 1.
) -> AbstractContextManager[None]:
if module_fqn and sharding_type:
annotation = f"[{event.value}]_[{module_fqn}]_[{sharding_type}]"
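
The signature of maybe_annotate_embedding_event is only reflowed here to make room for the pyre-fixme[24] note about AbstractContextManager's type parameters; the annotation format in the body is unchanged. For orientation, a hedged sketch of how a helper with this shape could be put together (the name annotate_event_sketch and the use of record_function below are illustrative assumptions, not the torchrec implementation):

from contextlib import AbstractContextManager, nullcontext
from typing import Optional

from torch.profiler import record_function


def annotate_event_sketch(
    event_value: str,
    module_fqn: Optional[str],
    sharding_type: Optional[str],
) -> AbstractContextManager:
    # Annotate the surrounding region for the profiler when enough metadata is
    # available; otherwise fall back to a no-op context.
    if module_fqn and sharding_type:
        annotation = f"[{event_value}]_[{module_fqn}]_[{sharding_type}]"
        return record_function(annotation)
    return nullcontext()
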
2 changes: 2 additions & 0 deletions torchrec/inference/inference_legacy/__init__.py
@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[0, 21]

"""Torchrec Inference
Torchrec inference provides a Torch.Deploy based library for GPU inference.
2 changes: 2 additions & 0 deletions torchrec/linter/module_linter.py
@@ -37,7 +37,9 @@ def print_error_message(
"""
lint_item = {
"path": python_path,
# pyre-fixme[16]: `AST` has no attribute `lineno`.
"line": node.lineno,
# pyre-fixme[16]: `AST` has no attribute `col_offset`.
"char": node.col_offset + 1,
"severity": severity,
"name": name,
1 change: 1 addition & 0 deletions torchrec/metrics/tests/test_metric_module.py
@@ -528,6 +528,7 @@ def _test_adjust_compute_interval(
)
mock_time.time = MagicMock(return_value=0.0)

# pyre-fixme[53]: Captured variable `batch` is not annotated.
def _train(metric_module: RecMetricModule) -> float:
for _ in range(metric_module.compute_interval_steps):
metric_module.update(batch)
35 changes: 30 additions & 5 deletions torchrec/modules/utils.py
@@ -40,6 +40,22 @@ def _fx_to_list(tensor: torch.Tensor) -> List[int]:
return tensor.long().tolist()


@torch.fx.wrap
def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
"""
Unflatten lengths tensor from [F * B] to [F, B].
"""
return lengths.view(num_features, -1)


@torch.fx.wrap
def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor:
"""
Slice tensor.
"""
return tensor[start:end]


def extract_module_or_tensor_callable(
module_or_callable: Union[
Callable[[], torch.nn.Module],
@@ -133,6 +149,8 @@ def convert_list_of_modules_to_modulelist(
# `Iterable[torch.nn.Module]`.
len(modules)
== sizes[0]
# pyre-fixme[6]: For 1st argument expected `pyre_extensions.PyreReadOnly[Sized]`
# but got `Iterable[Module]`.
), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}"
if len(sizes) == 1:
return torch.nn.ModuleList(modules)
@@ -290,20 +308,27 @@ def construct_jagged_tensors_inference(
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
reverse_indices: Optional[torch.Tensor] = None,
remove_padding: bool = False,
) -> Dict[str, JaggedTensor]:
with record_function("## construct_jagged_tensors_inference ##"):
# [F * B] -> [F, B]
unflattened_lengths = _get_unflattened_lengths(lengths, len(embedding_names))

if reverse_indices is not None:
embeddings = torch.index_select(
embeddings, 0, reverse_indices.to(torch.int32)
)
elif remove_padding:
embeddings = _slice_1d_tensor(
embeddings, 0, unflattened_lengths.sum().item()
)

ret: Dict[str, JaggedTensor] = {}
length_per_key: List[int] = _fx_to_list(
torch.sum(lengths.view(len(embedding_names), -1), dim=1)
)

lengths = lengths.view(len(embedding_names), -1)
lengths_tuple = torch.unbind(lengths, dim=0)
length_per_key: List[int] = _fx_to_list(torch.sum(unflattened_lengths, dim=1))

lengths_tuple = torch.unbind(unflattened_lengths, dim=0)

embeddings_list = torch.split(embeddings, length_per_key, dim=0)
values_list = torch.split(values, length_per_key) if need_indices else None

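The bulk of this hunk swaps the inline lengths.view(len(embedding_names), -1) for the new fx-wrapped helpers _get_unflattened_lengths and _slice_1d_tensor, and adds the remove_padding path. A rough sketch with made-up shapes of what the unflattening and trimming amount to (the values below are illustrative and not taken from the diff):

import torch

F, B = 2, 3                                   # 2 features, batch size 3
lengths = torch.tensor([1, 0, 2, 1, 1, 0])    # flattened [F * B] lengths

unflattened = lengths.view(F, -1)             # what _get_unflattened_lengths returns
# tensor([[1, 0, 2],
#         [1, 1, 0]])

total = int(unflattened.sum().item())         # 5 real embedding rows in this example
padded_embeddings = torch.randn(8, 4)         # e.g. padded out to 8 rows of dim 4
trimmed = padded_embeddings[0:total]          # mirrors _slice_1d_tensor(embeddings, 0, total)
print(trimmed.shape)                          # torch.Size([5, 4])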