Skip to content

Commit

Permalink
2025-02-01 nightly release (27fdfd6)
Browse files: browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Feb 1, 2025
1 parent 7fca948 commit ee19f31
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
45 changes: 32 additions & 13 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,35 +147,48 @@ def _connect(self, dataloader_iter: Iterator[In]) -> None:
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
self._connected = True

def _next_batch(self, dataloader_iter: Iterator[In]) -> In:
    """Fetch the next batch from ``dataloader_iter``.

    The fetch runs inside a ``## next_batch ##`` profiler range.

    Args:
        dataloader_iter: iterator the batch is pulled from.

    Returns:
        The item produced by ``next(dataloader_iter)``.

    Raises:
        StopIteration: propagated from ``next`` when the iterator is
            exhausted (no default is supplied).
    """
    with record_function("## next_batch ##"):
        return next(dataloader_iter)

def _wait_for_batch(self, cur_batch: In) -> None:
    """Wait on ``cur_batch`` with respect to ``self._memcpy_stream``.

    Delegates to the free function ``_wait_for_batch`` (a bare name inside a
    method does not resolve to this method, so the call is not recursive) and
    wraps it in a ``## wait_for_batch ##`` profiler range.

    Args:
        cur_batch: the batch handed to the helper.
    """
    with record_function("## wait_for_batch ##"):
        memcpy_stream = self._memcpy_stream
        _wait_for_batch(cur_batch, memcpy_stream)

def _backward(self, losses: torch.Tensor) -> None:
    """Run the backward pass on the sum of ``losses`` along dim 0.

    The reduction and the ``backward()`` call both execute inside a
    ``## backward ##`` profiler range.

    Args:
        losses: tensor of losses; it is reduced with ``torch.sum(..., dim=0)``
            before ``backward()`` is called on the result.
    """
    with record_function("## backward ##"):
        summed_loss = torch.sum(losses, dim=0)
        summed_loss.backward()

def _copy_batch_to_gpu(self, cur_batch: In) -> None:
    """Move ``cur_batch`` to ``self._device`` and store it as ``self._cur_batch``.

    The transfer is issued through ``_to_device(..., non_blocking=True)`` while
    the context returned by ``self._stream_context(self._memcpy_stream)`` is
    active, all inside a ``## copy_batch_to_gpu ##`` profiler range.

    Args:
        cur_batch: the batch to transfer.
    """
    with record_function("## copy_batch_to_gpu ##"):
        memcpy_ctx = self._stream_context(self._memcpy_stream)
        with memcpy_ctx:
            device_batch = _to_device(cur_batch, self._device, non_blocking=True)
            self._cur_batch = device_batch

def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._connected:
self._connect(dataloader_iter)

# Fetch next batch
with record_function("## next_batch ##"):
next_batch = next(dataloader_iter)
next_batch = self._next_batch(dataloader_iter)
cur_batch = self._cur_batch
assert cur_batch is not None

if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

with record_function("## wait_for_batch ##"):
_wait_for_batch(cur_batch, self._memcpy_stream)
self._wait_for_batch(cur_batch)

with record_function("## forward ##"):
(losses, output) = self._model(cur_batch)

if self._model.training:
with record_function("## backward ##"):
torch.sum(losses, dim=0).backward()
self._backward(losses)

# Copy the next batch to GPU
self._cur_batch = cur_batch = next_batch
with record_function("## copy_batch_to_gpu ##"):
with self._stream_context(self._memcpy_stream):
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
self._copy_batch_to_gpu(cur_batch)

# Update
if self._model.training:
Expand Down Expand Up @@ -471,6 +484,14 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
if not self.enqueue_batch(dataloader_iter):
return

def _wait_for_batch(self) -> None:
    """Wait on the first queued batch with respect to ``self._data_dist_stream``.

    ``self.batches[0]`` is cast to ``In`` for the type checker and passed to
    the free function ``_wait_for_batch`` (the bare name does not resolve to
    this method), inside a ``## wait_for_batch ##`` profiler range.

    Raises:
        IndexError: presumably, if ``self.batches`` is an empty sequence --
            the element at index 0 is read unconditionally.
    """
    with record_function("## wait_for_batch ##"):
        head_batch = cast(In, self.batches[0])
        _wait_for_batch(head_batch, self._data_dist_stream)

def _backward(self, losses: torch.Tensor) -> None:
    """Back-propagate through the dim-0 sum of ``losses``.

    Both the ``torch.sum`` reduction and the ``backward()`` call are recorded
    under a ``## backward ##`` profiler range.

    Args:
        losses: tensor of losses to reduce and differentiate.
    """
    with record_function("## backward ##"):
        reduced = torch.sum(losses, dim=0)
        reduced.backward()

def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._model_attached:
self.attach(self._model)
Expand All @@ -486,8 +507,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

with record_function("## wait_for_batch ##"):
_wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream)
self._wait_for_batch()

if len(self.batches) >= 2:
self.start_sparse_data_dist(self.batches[1], self.contexts[1])
Expand All @@ -504,8 +524,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

if self._model.training:
# backward
with record_function("## backward ##"):
torch.sum(losses, dim=0).backward()
self._backward(losses)

# update
with record_function("## optimizer ##"):
Expand Down
2 changes: 1 addition & 1 deletion torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def compute(self) -> Dict[str, MetricValue]:
right before logging the metrics results to the data sink.
"""
self.compute_count += 1
self.check_memory_usage(self.compute_count)
with record_function("## RecMetricModule:compute ##"):
self.check_memory_usage(self.compute_count)
ret: Dict[str, MetricValue] = {}
if self.rec_metrics:
self._adjust_compute_interval()
Expand Down

0 comments on commit ee19f31

Please sign in to comment.