Skip to content

Commit

Permalink
[XLA:GPU] Rewrite if else chains into switch statements.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726485936
  • Loading branch information
allanrenucci authored and Google-ML-Automation committed Feb 13, 2025
1 parent 83081f9 commit 5e7a670
Showing 1 changed file with 73 additions and 72 deletions.
145 changes: 73 additions & 72 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -732,19 +732,18 @@ absl::StatusOr<ModuleHandle> RocmExecutor::LoadModuleFromHsaco(
}

// Allocates `size` bytes in the memory space identified by `memory_space`,
// which is interpreted as a `MemoryType` value.
//
// - kCollective / kDevice: memory comes from DeviceAllocate() on this
//   executor's ROCm context.
// - kHost: memory comes from HostAllocate(); if that fails, a null
//   DeviceMemoryBase of size 0 is returned instead of an error.
// - Any other value is treated as a programming error and aborts the process
//   via LOG(FATAL).
DeviceMemoryBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) {
  switch (static_cast<MemoryType>(memory_space)) {
    case MemoryType::kCollective:
    case MemoryType::kDevice:
      return DeviceMemoryBase(DeviceAllocate(rocm_context_, size), size);
    case MemoryType::kHost:
      if (auto result = HostAllocate(rocm_context_, size); result.ok()) {
        return DeviceMemoryBase(*result, size);
      }
      // Host allocation failed: signal it with an empty (null, 0) buffer.
      return DeviceMemoryBase(nullptr, 0);
    default:
      LOG(FATAL) << "Unsupported memory space: " << memory_space;
  }
}
absl::StatusOr<std::unique_ptr<MemoryAllocation>>
RocmExecutor::HostMemoryAllocate(uint64_t size) {
Expand All @@ -757,68 +756,70 @@ void RocmExecutor::Deallocate(DeviceMemoryBase* mem) {

// Builds a MemoryAllocator for the requested memory `type`.
//
// Supported types:
// - kUnified:    allocates with hipMallocManaged(hipMemAttachGlobal) and
//                frees with hipFree, activating this executor's context
//                around both calls.
// - kCollective: allocates with hipMalloc and frees with hipFree; the
//                lambdas capture nothing from the executor.
// - kHost:       delegates to AllocateHostMemory() on this executor's ROCm
//                context.
//
// Returns an UnimplementedError for any other memory type. The kUnified and
// kHost allocators capture `this`, so they must not be used after this
// executor is destroyed.
absl::StatusOr<std::unique_ptr<MemoryAllocator>>
RocmExecutor::CreateMemoryAllocator(MemoryType type) {
  switch (type) {
    case MemoryType::kUnified:
      return std::make_unique<GenericMemoryAllocator>(
          [this](uint64_t size)
              -> absl::StatusOr<std::unique_ptr<MemoryAllocation>> {
            std::unique_ptr<ActivateContext> activation = Activate();
            hipDeviceptr_t result = nullptr;
            // "managed" memory is visible to both CPU and GPU.
            TF_RETURN_IF_ERROR(ToStatus(
                wrap::hipMallocManaged(&result, size, hipMemAttachGlobal),
                "Failed to allocate managed memory"));
            void* ptr = reinterpret_cast<void*>(result);
            VLOG(2) << "allocated " << ptr << " for context " << rocm_context_
                    << " of " << size << " bytes in unified memory";
            // The deleter re-activates the context before freeing; a failed
            // free is logged, not propagated.
            return std::make_unique<GenericMemoryAllocation>(
                ptr, size, [this](void* location, uint64_t size) {
                  std::unique_ptr<ActivateContext> activation = Activate();
                  hipDeviceptr_t pointer =
                      absl::bit_cast<hipDeviceptr_t>(location);
                  hipError_t res = wrap::hipFree(pointer);
                  if (res != hipSuccess) {
                    LOG(ERROR) << "failed to free unified memory at "
                               << location << "; result: " << ToString(res);
                  } else {
                    VLOG(2) << "deallocated unified memory at " << location
                            << " for context " << rocm_context_;
                  }
                });
          });
    case MemoryType::kCollective:
      return std::make_unique<GenericMemoryAllocator>(
          [](uint64_t size)
              -> absl::StatusOr<std::unique_ptr<MemoryAllocation>> {
            void* ptr = nullptr;
            auto hipResult = wrap::hipMalloc(&ptr, size);
            if (hipResult != hipSuccess) {
              return absl::InternalError(absl::StrFormat(
                  "failed to allocate %s (%llu bytes) from device collective "
                  "memory: %s, "
                  "Last NCCL warning(error)",
                  tsl::strings::HumanReadableNumBytes(size), size,
                  hipGetErrorString(hipResult)));
            }
            VLOG(2) << "allocated " << ptr << " of " << size
                    << " bytes of collective memory";
            // A failed free is logged, not propagated.
            return std::make_unique<GenericMemoryAllocation>(
                ptr, size, [](void* location, uint64_t size) {
                  auto status = wrap::hipFree(location);
                  if (status != hipSuccess) {
                    LOG(ERROR) << "failed to free collective memory at "
                               << location << "; result: " << status;
                  } else {
                    VLOG(2) << "deallocated collective memory at " << location;
                  }
                });
          });
    case MemoryType::kHost:
      return std::make_unique<GenericMemoryAllocator>([this](uint64_t size) {
        return AllocateHostMemory(rocm_context_, size);
      });
    default:
      return absl::UnimplementedError(
          absl::StrFormat("Unsupported memory type %d", type));
  }
}

bool RocmExecutor::SynchronizeAllActivity() {
Expand Down

0 comments on commit 5e7a670

Please sign in to comment.