Skip to content

Commit

Permalink
Execute host-to-host copies on the host.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726445733
  • Loading branch information
SandSnip3r authored and Google-ML-Automation committed Feb 13, 2025
1 parent e7c62c0 commit d5e9a18
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
14 changes: 3 additions & 11 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,6 @@ bool SetBuffersToMemorySpaceColor(
return changed;
}

void SetHostComputeFrontendAttribute(HloInstruction& host_instruction) {
FrontendAttributes frontend_attributes =
host_instruction.frontend_attributes();
frontend_attributes.mutable_map()->insert(
{kXlaComputeTypeAttr, kXlaComputeTypeHost});
host_instruction.set_frontend_attributes(frontend_attributes);
}

} // namespace

bool HostOffloader::InstructionIsAllowedBetweenMoveToHostAndDus(
Expand Down Expand Up @@ -243,7 +235,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
"memory space. Converting into host compute. This is likely to have "
"a very high overhead.",
instruction->name());
SetHostComputeFrontendAttribute(*instruction);
host_offload_utils::SetHostComputeFrontendAttribute(*instruction);
}
if (!already_saved_buffer) {
const HloInstruction* instruction =
Expand Down Expand Up @@ -1107,11 +1099,11 @@ absl::StatusOr<bool> HostOffloader::HandleDynamicUpdateSlices() {
operand_memory_space == Layout::kDefaultMemorySpace;
if (host_to_device) {
// This is only supported via host compute.
SetHostComputeFrontendAttribute(*dus);
host_offload_utils::SetHostComputeFrontendAttribute(*dus);
changed = true;
} else if (host_to_host) {
// Host to host. Execute as host compute. Also set as host memory space.
SetHostComputeFrontendAttribute(*dus);
host_offload_utils::SetHostComputeFrontendAttribute(*dus);
SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace);
changed = true;
} else if (device_to_host) {
Expand Down
3 changes: 3 additions & 0 deletions xla/hlo/transforms/host_offloading_prepare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ absl::StatusOr<bool> ConvertToCustomCall(HloModule* module) {
}
}
}
if (changed && module->has_schedule()) {
TF_RETURN_IF_ERROR(module->schedule().Update());
}
return changed;
}

Expand Down
8 changes: 8 additions & 0 deletions xla/service/host_offload_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,13 @@ bool ComputeTypeIsHost(const HloInstruction* hlo_instruction) {
kXlaComputeTypeHost);
}

void SetHostComputeFrontendAttribute(HloInstruction& host_instruction) {
FrontendAttributes frontend_attributes =
host_instruction.frontend_attributes();
frontend_attributes.mutable_map()->insert(
{kXlaComputeTypeAttr, kXlaComputeTypeHost});
host_instruction.set_frontend_attributes(frontend_attributes);
}

} // namespace host_offload_utils
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/host_offload_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ bool IsSynchronousCopyFromOrToHost(const HloInstruction* instruction);

bool ComputeTypeIsHost(const HloInstruction* hlo_instruction);

// Sets the frontend attribute of the instruction to indicate that the
// instruction should be lowered as host compute.
void SetHostComputeFrontendAttribute(HloInstruction& host_instruction);

} // namespace host_offload_utils
} // namespace xla

Expand Down

0 comments on commit d5e9a18

Please sign in to comment.