PR #19834: [ds-fusion] Add support for async dynamic slice fusion
Imported from GitHub PR #19834

This patch adds async handling to dynamic slice fusion when the hero operation is a collective. Currently, only reduce-scatter is supported as a hero operation in the dynamic slice thunk, so this patch follows suit.

Added a test with compute to ensure that communication and compute overlap in the emitted thunks.
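For context, a minimal sketch of the HLO pattern this enables (the add reduction computation is elided here; a full module, with a dynamic-update-slice and overlapping compute, appears in the test added below):

dynamic-slice-fusion {
  p0 = s32[8,32]{1,0} parameter(0)
  slice = s32[4,32]{1,0} slice(p0), slice={[4:8], [0:32]}
  ROOT rs = s32[2,32]{1,0} reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add
}

ENTRY main {
  source = s32[8,32]{1,0} parameter(0)
  fusion-start = ((s32[8,32]{1,0}), s32[2,32]{1,0}, u32[]) fusion-start(source), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}}
  ROOT fusion-done = s32[2,32]{1,0} fusion-done(fusion-start), backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"address_computation"}}}
}

With this change, fusion-start lowers to the collective's start thunk and fusion-done to the matching done thunk, so independent compute can be scheduled between the two.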
Copybara import of the project:

--
13553a1 by Shraiysh Vaishay <[email protected]>:

Add support for async dynamic slice fusion

This patch adds async handling to dynamic slice fusion when the hero
operation is a collective. Currently, only reduce-scatter is supported
as a hero operation in the dynamic slice thunk, so this patch follows
suit.

Added a test with compute to ensure that communication and compute
overlap in the emitted thunks.

--
e2c5986 by Shraiysh Vaishay <[email protected]>:

Addressed comments.

--
ae73bb0 by Shraiysh Vaishay <[email protected]>:

Rebase and fix build errors

--
c68b096 by Shraiysh Vaishay <[email protected]>:

Rebase

--
7609238 by Shraiysh Vaishay <[email protected]>:

Addressed comments

--
6b47e68 by Shraiysh Vaishay <[email protected]>:

Address comments

--
12890a5 by Shraiysh Vaishay <[email protected]>:

Address comments

--
a585534 by Shraiysh Vaishay <[email protected]>:

Fix Executable -> OpaqueExecutable

Merging this change closes #19834

COPYBARA_INTEGRATE_REVIEW=#19834 from shraiysh:async-dynamic-slice-fusion a585534
PiperOrigin-RevId: 725980396
shraiysh authored and Google-ML-Automation committed Feb 12, 2025
1 parent c68d132 commit 4394aa9
Showing 16 changed files with 503 additions and 153 deletions.
2 changes: 1 addition & 1 deletion xla/backends/gpu/codegen/BUILD
@@ -186,8 +186,8 @@ xla_test(
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_runner_interface",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_executable",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory",
33 changes: 25 additions & 8 deletions xla/backends/gpu/codegen/custom.cc
@@ -192,11 +192,18 @@ absl::Status CollectSliceInfo(
if (arg_slice_instr == nullptr) {
return absl::OkStatus();
}
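// When the fusion is wrapped in an async computation, the offset values are
// passed as operands of the enclosing async-start instruction rather than of
// the fusion instruction itself, so they must be read from the async caller.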
std::optional<HloInstruction*> async_caller = std::nullopt;
if (fusion_instr.parent()->IsAsyncComputation()) {
async_caller = fusion_instr.parent()->AsyncStart();
}

std::vector<DynamicSliceThunk::Offset> arg_offsets;
for (auto idx_op : arg_slice_instr->index_operands()) {
const auto* param = Cast<HloParameterInstruction>(idx_op);
const auto* offset_value = fusion_instr.operand(param->parameter_number());
const HloInstruction* offset_value =
async_caller.has_value()
? (*async_caller)->operand(param->parameter_number())
: fusion_instr.operand(param->parameter_number());

VLOG(2) << "Offset value:" << offset_value->ToString();

@@ -893,7 +900,9 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
fusion_instr.backend_config<xla::gpu::GpuBackendConfig>());
const std::string fusion_name =
backend_config.fusion_backend_config().custom_fusion_config().name();
TF_RET_CHECK(isDynamic == (fusion_name == "dynamic_address_computation"))
TF_RET_CHECK(isDynamic ==
(fusion_name ==
kDynamicSliceFusionWithDynamicAddressComputationConfigName))
<< "Dynamic index operation found in a fusion instruction that is not "
"labelled dynamic_address_computation";

@@ -961,13 +970,21 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
/*destination_value=*/nullptr});
auto collective_start_thunk =
std::make_unique<NcclThunkType>(thunk_info, instr, buffers);
auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
/*kind=*/collective_done_thunk_kind,
/*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*async_events=*/collective_start_thunk->async_events(),
/*async_stream_kind=*/AsyncStreamKind::kCollective);
std::shared_ptr<NcclCollectiveThunk::AsyncEvents> async_events =
collective_start_thunk->async_events();
seq.emplace_back(std::move(collective_start_thunk));
seq.emplace_back(std::move(collective_done_thunk));
// If the fusion is async, we do not emit the done thunk at the end.
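// Instead, the async events are recorded against the enclosing async-start
// instruction so that the matching done thunk can be emitted when the
// corresponding fusion-done instruction is lowered.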
if (fusion_instr.parent()->IsAsyncComputation()) {
ir_emitter_context.collectives_async_events().insert(
{fusion_instr.parent()->AsyncStart(), async_events});
} else {
auto collective_done_thunk = std::make_unique<NcclCollectiveDoneThunk>(
/*kind=*/collective_done_thunk_kind,
/*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*async_events=*/async_events,
/*async_stream_kind=*/AsyncStreamKind::kCollective);
seq.emplace_back(std::move(collective_done_thunk));
}
} else {
return implementable_status;
}
127 changes: 109 additions & 18 deletions xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
@@ -35,8 +35,8 @@ limitations under the License.
#include "xla/hlo/testlib/filecheck.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
@@ -83,6 +83,14 @@ namespace xla {
namespace gpu {
namespace {

using ::testing::ElementsAre;
using ::testing::Optional;
using ::testing::VariantWith;

MATCHER_P(ThunkKindIs, kind, "") {
return ExplainMatchResult(::testing::Eq(kind), arg->kind(), result_listener);
}

class DynamicSliceFusionTest : public HloTestBase {
public:
HloModuleConfig GetModuleConfigWithoutCommandBuffer() {
@@ -107,17 +115,8 @@ class DynamicSliceFusionTest : public HloTestBase {
if (!computation->IsFusionComputation()) {
continue;
}
auto backend_config = computation->FusionInstruction()
->backend_config<xla::gpu::GpuBackendConfig>();
if (backend_config.ok()) {
const FusionBackendConfig& fusion_backend_config =
backend_config.value().fusion_backend_config();
const std::string name =
fusion_backend_config.custom_fusion_config().name();
if (name == "dynamic_address_computation" ||
name == "address_computation") {
computations.push_back(computation);
}
if (IsDynamicSliceFusion(computation->FusionInstruction())) {
computations.push_back(computation);
}
}
return computations;
@@ -3138,11 +3137,22 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterSlice) {
test_runner_as_hlo_runner().ExecutableFromWrapped(
wrapped_executable.get()));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec);
ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 2ul);
auto& rs_start_thunk = gpu_exec->GetThunk().thunks()[0];
auto& rs_done_thunk = gpu_exec->GetThunk().thunks()[1];
ASSERT_EQ(rs_start_thunk->kind(), Thunk::kNcclReduceScatterStart);
ASSERT_EQ(rs_done_thunk->kind(), Thunk::kNcclReduceScatterDone);

// The pattern here is a static slice feeding a reduce-scatter operation.
// With this pattern, the offset can be computed at compile time, so no
// dynamic slice thunk is needed to compute it at runtime. We therefore
// expect to see kNcclReduceScatterStart and kNcclReduceScatterDone thunks,
// surrounded by kWaitForStreams thunks, because dynamic slice fusion with a
// collective hero is converted into an async operation.
ASSERT_EQ(gpu_exec->GetThunk().thunks().size(), 4ul);
EXPECT_THAT(
gpu_exec->GetThunk().thunks(),
::testing::ElementsAre(ThunkKindIs(Thunk::kWaitForStreams),
ThunkKindIs(Thunk::kNcclReduceScatterStart),
ThunkKindIs(Thunk::kNcclReduceScatterDone),
ThunkKindIs(Thunk::kWaitForStreams)));

ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3};
EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt),
@@ -3279,6 +3289,87 @@ TEST_F(DynamicSliceFusionTest,
/*run_hlo_passes=*/false, /*use_threads=*/true, error));
}

TEST_F(DynamicSliceFusionTest,
AsyncDynamicSliceFusionWithCollectiveOverlapsWithComputeThunk) {
const char* hlo = R"(
HloModule test-clone, replica_count=2
add {
x = s32[] parameter(0)
y = s32[] parameter(1)
ROOT add = s32[] add(x, y)
}
dynamic-slice-fusion {
p1 = s32[2,2,32]{2,1,0} parameter(1)
p0 = s32[8,32]{1,0} parameter(0)
slice = s32[4,32]{1,0} slice(p0), slice={[4:8], [0:32]}
rs = s32[2,32]{1,0} reduce-scatter(slice), replica_groups={{0,1}}, dimensions={0}, to_apply=add
bitcast = s32[1,2,32]{2,1,0} bitcast(rs)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT dynamic-update-slice = s32[2,2,32]{2,1,0} dynamic-update-slice(p1, bitcast, p2, p3, p3)
}
ENTRY main {
source = s32[8,32]{1,0} parameter(1)
destination = s32[2,2,32]{2,1,0} parameter(0)
copy = s32[2,2,32]{2,1,0} copy(destination)
c1 = s32[] constant(1)
c0 = s32[] constant(0)
fusion-start = ((s32[8,32]{1,0}, s32[2,2,32]{2,1,0}, s32[], s32[]), s32[2,2,32]{2,1,0}, u32[]) fusion-start(source, copy, c1, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
fusion-done = s32[2,2,32]{2,1,0} fusion-done(fusion-start), backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
a = s32[1024,1024]{1,0} parameter(2)
b = s32[1024,1024]{1,0} parameter(3)
dot = s32[1024,1024]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT tuple = (s32[2,2,32]{2,1,0}, s32[1024,1024]{1,0}) tuple(fusion-done, dot)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<OpaqueExecutable> wrapped_exec,
CreateExecutable(hlo_module->Clone(), /*run_hlo_passes=*/false));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
std::move(wrapped_exec)));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
const ThunkSequence& thunks = gpu_exec->GetThunk().thunks();

// This is only needed to ensure that the next checks don't fail.
ASSERT_EQ(thunks.size(), 6);

// In the following checks, only the order of the thunks matters.
EXPECT_THAT(thunks,
::testing::ElementsAre(ThunkKindIs(Thunk::kCopy),
ThunkKindIs(Thunk::kWaitForStreams),
ThunkKindIs(Thunk::kDynamicSlice),
ThunkKindIs(Thunk::kKernel),
ThunkKindIs(Thunk::kNcclReduceScatterDone),
ThunkKindIs(Thunk::kWaitForStreams)));

// Check that the dynamic slice thunk only produces a start thunk, and not a
// done thunk.
DynamicSliceThunk* dynamic_slice_thunk =
dynamic_cast<DynamicSliceThunk*>(thunks[2].get());
ASSERT_NE(dynamic_slice_thunk, nullptr);
const SequentialThunk* embedded_thunk = dynamic_cast<const SequentialThunk*>(
dynamic_slice_thunk->embedded_thunk());
ASSERT_NE(embedded_thunk, nullptr);
EXPECT_THAT(
embedded_thunk->thunks(),
::testing::ElementsAre(ThunkKindIs(Thunk::kNcclReduceScatterStart)));

// Check that the offsets were propagated as constants, not as
// device-allocated buffers.
auto offsets = dynamic_slice_thunk->get_offsets();
EXPECT_THAT(offsets,
ElementsAre(std::nullopt,
Optional(ElementsAre(VariantWith<int64_t>(1),
VariantWith<int64_t>(0),
VariantWith<int64_t>(0)))));
}

} // namespace
} // namespace gpu
} // namespace xla
8 changes: 6 additions & 2 deletions xla/backends/gpu/codegen/fusions.cc
@@ -89,8 +89,12 @@ std::unique_ptr<FusionInterface> GetFusionEmitter(

switch (analysis.GetEmitterFusionKind()) {
case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: {
const auto& config = backend_config.custom_fusion_config();
if (absl::StrContains(config.name(), "address_computation")) {
const absl::string_view& config_name =
backend_config.custom_fusion_config().name();
if (config_name ==
kDynamicSliceFusionWithStaticAddressComputationConfigName ||
config_name ==
kDynamicSliceFusionWithDynamicAddressComputationConfigName) {
return std::make_unique<DynamicSliceFusion>(analysis);
}
return std::make_unique<CustomFusion>();
21 changes: 20 additions & 1 deletion xla/hlo/parser/hlo_parser.cc
@@ -1473,18 +1473,37 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
}
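// For async instructions, each parsed attribute below is also propagated to
// the wrapped instruction.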
if (metadata) {
instruction->set_metadata(*metadata);
if (instruction->IsAsynchronous()) {
instruction->async_wrapped_instruction()->set_metadata(*metadata);
}
}
if (original_value) {
instruction->set_original_value(*original_value);
if (instruction->IsAsynchronous()) {
instruction->async_wrapped_instruction()->set_original_value(
*original_value);
}
}
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
instruction->set_raw_backend_config_string(*backend_config);
if (instruction->IsAsynchronous()) {
instruction->async_wrapped_instruction()->set_raw_backend_config_string(
*backend_config);
}
}
if (frontend_attributes) {
instruction->set_frontend_attributes(*frontend_attributes);
if (instruction->IsAsynchronous()) {
instruction->async_wrapped_instruction()->set_frontend_attributes(
*frontend_attributes);
}
}
if (statistics_viz) {
instruction->set_statistics_viz(*statistics_viz);
if (instruction->IsAsynchronous()) {
instruction->async_wrapped_instruction()->set_statistics_viz(
*statistics_viz);
}
}

return AddInstruction(name, instruction, name_loc);
52 changes: 52 additions & 0 deletions xla/hlo/parser/hlo_parser_test.cc
@@ -5924,5 +5924,57 @@ ENTRY main {
"error: unexpected attribute \"result_accuracy\"");
}

TEST_F(HloParserTest,
AsyncInstructionWithAttributesShouldPropagateToWrappedInstruction) {
const char* hlo = R"(
HloModule test
ENTRY main {
a = s32[] parameter(0), origin={{"v1"}}
b = s32[] parameter(1), origin={{"v2"}}
add-start = ((s32[], s32[]), s32[], s32[]) add-start(a, b),
metadata={op_type="add" op_name="sample name" source_file="path/to/test.cc" source_line=68},
backend_config="foo\" bar",
frontend_attributes={attr_a="test_a",attr_b="b"},
statistics={visualizing_index=1,stat-1=33,stat-2=44}
ROOT add-done = s32[] add-done(add-start), origin={{"v3"}}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(hlo));
// Check the wrapped instruction.
HloInstruction* wrapped_instr =
m->entry_computation()->root_instruction()->async_wrapped_instruction();
EXPECT_EQ(wrapped_instr->metadata().op_name(), "sample name");
EXPECT_EQ(wrapped_instr->metadata().op_type(), "add");
EXPECT_EQ(wrapped_instr->metadata().source_file(), "path/to/test.cc");
EXPECT_EQ(wrapped_instr->metadata().source_line(), 68);
EXPECT_EQ(wrapped_instr->raw_backend_config_string(), "foo\" bar");
EXPECT_EQ(wrapped_instr->frontend_attributes().map().size(), 2);
EXPECT_EQ(wrapped_instr->frontend_attributes().map().at("attr_a"), "test_a");
EXPECT_EQ(wrapped_instr->frontend_attributes().map().at("attr_b"), "b");
EXPECT_EQ(wrapped_instr->statistics_viz().stat_index_to_visualize(), 1);
EXPECT_EQ(wrapped_instr->statistics_viz().statistics_size(), 2);
EXPECT_EQ(wrapped_instr->statistics_viz().statistics(0).stat_name(),
"stat-1");
EXPECT_EQ(wrapped_instr->statistics_viz().statistics(0).stat_val(), 33);
EXPECT_EQ(wrapped_instr->statistics_viz().statistics(1).stat_name(),
"stat-2");
EXPECT_EQ(wrapped_instr->statistics_viz().statistics(1).stat_val(), 44);
EXPECT_EQ(OriginalValueToString(*wrapped_instr->original_value()),
"{\"v3\"}");
// Check the async-start and async-done instructions.
HloInstruction* async_done = m->entry_computation()->root_instruction();
HloInstruction* async_start = async_done->async_chain_start();
EXPECT_EQ(async_start->metadata().DebugString(),
wrapped_instr->metadata().DebugString());
EXPECT_EQ(async_start->raw_backend_config_string(),
wrapped_instr->raw_backend_config_string());
EXPECT_EQ(async_start->frontend_attributes().DebugString(),
wrapped_instr->frontend_attributes().DebugString());
EXPECT_EQ(async_start->statistics_viz().DebugString(),
wrapped_instr->statistics_viz().DebugString());
EXPECT_EQ(OriginalValueToString(*async_done->original_value()),
OriginalValueToString(*wrapped_instr->original_value()));
}

} // namespace
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/gpu/BUILD
@@ -385,6 +385,7 @@ cc_library(
"//xla/ffi:ffi_api",
"//xla/ffi/api:c_api",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/mlir/utils:error_util",
"//xla/mlir_hlo:transforms_gpu_passes",
"//xla/service:buffer_assignment",
@@ -640,6 +641,7 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":backend_configs_cc",
":target_util",
"//xla:literal",
"//xla:shape_util",
@@ -1487,6 +1489,7 @@ cc_library(
"//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
"//xla/hlo/translate/mhlo_to_hlo:location_exporter",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_traversal",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner",
@@ -1562,6 +1565,7 @@ cc_library(
"//xla/service:buffer_assignment",
"//xla/service:buffer_value",
"//xla/service:call_inliner",
"//xla/service:collective_ops_utils",
"//xla/service:collective_permute_decomposer",
"//xla/service:collective_pipeliner",
"//xla/service:collective_utils",
