diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index f7b94b37b..2da7c3e86 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -269,6 +269,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # Mandatory dynamo configuration for Torchrec PT2 compilation torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.enable_trace_contextlib = False torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = ( True