diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 84e13c85a..8bcbc0bb2 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -85,6 +85,8 @@ "__use_unflattened_lengths_for_batching" ) +MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT: str = "__use_batching_hinted_output" + DEFAULT_ROW_ALIGNMENT = 16 @@ -913,7 +915,8 @@ def forward( lengths = _get_unflattened_lengths(lengths, len(embedding_names)) lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) else: - lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) + if getattr(self, MODULE_ATTR_USE_BATCHING_HINTED_OUTPUT, True): + lookup = _get_batching_hinted_output(lengths=lengths, output=lookup) lengths = _get_unflattened_lengths(lengths, len(embedding_names)) jt = construct_jagged_tensors_inference( embeddings=lookup,