From 246ea235d954cf1063a31dbf492c59fe1e4e3ef3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 7 Nov 2024 11:34:53 +0000 Subject: [PATCH] 2024-11-07 nightly release (42c512c4bc9425aeacf9cca504b2212e65995c10) --- torchrec/quant/embedding_modules.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 84e13c85a..8bcbc0bb2 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -85,6 +85,8 @@ "__use_unflattened_lengths_for_batching" ) +MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output" + DEFAULT_ROW_ALIGNMENT = 16 @@ -913,7 +915,8 @@ def forward( lengths = _get_unflattened_lengths(lengths, len(embedding_names)) lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) else: - lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) + if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, True): + lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) lengths = _get_unflattened_lengths(lengths, len(embedding_names)) jt = construct_jagged_tensors_inference( embeddings=lookup,