PR #20332: [ds-fusion] Add runtime support for host calculation of offsets in ds fusion

Imported from GitHub PR #20332

This patch adds support for calculating an offset on the host at runtime when the offset depends on the loop induction variable. The offset computation, the induction-variable initialization, and the induction-variable update are extracted as independent computations that are evaluated on the host at runtime. This avoids a device-to-host copy for this fusion in these cases.
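For context, here is a minimal sketch (not the actual thunk code) of what evaluating such an offset on the host could look like, assuming the extracted offset module takes the current induction-variable value (s32, as in the new tests) as its single parameter. OffsetVariant and ResolveOffset are illustrative names introduced here, not the real DynamicSliceThunk API; the new test only confirms that DynamicSliceThunk::Offset can carry either a constant or an HloModule*.

  // Illustrative sketch only; assumes the extracted offset module has a
  // single s32 parameter holding the current loop induction variable.
  #include <cstdint>
  #include <variant>

  #include "absl/status/statusor.h"
  #include "xla/hlo/evaluator/hlo_evaluator.h"
  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/literal.h"
  #include "xla/literal_util.h"

  // Stand-in for an offset that is either a compile-time constant or a
  // small host-evaluated module.
  using OffsetVariant = std::variant<int64_t, xla::HloModule*>;

  absl::StatusOr<int64_t> ResolveOffset(const OffsetVariant& offset,
                                        int32_t induction_variable) {
    if (const int64_t* constant = std::get_if<int64_t>(&offset)) {
      return *constant;  // Static offset: nothing to evaluate.
    }
    // Module-backed offset: evaluate the extracted computation on the host,
    // feeding it the current induction-variable value.
    xla::HloModule* module = std::get<xla::HloModule*>(offset);
    xla::Literal iv = xla::LiteralUtil::CreateR0<int32_t>(induction_variable);
    xla::HloEvaluator evaluator;
    absl::StatusOr<xla::Literal> result = evaluator.Evaluate(*module, {&iv});
    if (!result.ok()) {
      return result.status();
    }
    return static_cast<int64_t>(result->GetFirstElement<int32_t>());
  }

In the first new test, for example, the offset works out to 6*i - 16, so a host-side evaluation like this would produce -16, -10, -4, ... as the loop advances (dynamic-update-slice then clamps the start index to the valid range).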
Copybara import of the project:

--
5c85fe7 by Shraiysh Vaishay <[email protected]>:

Add runtime support for host calculation of offsets in ds fusion

This patch adds support for calculating an offset on the host at
runtime when the offset depends on the loop induction variable. The
offset computation, the induction-variable initialization, and the
induction-variable update are extracted as independent computations
that are evaluated on the host at runtime. This avoids a device-to-host
copy for this fusion in these cases.
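
As a purely hypothetical illustration of what such an extracted offset computation can look like, the 6*i - 16 offset used in the new tests could be written out as the standalone module parsed below. The actual implementation builds these modules by extraction from the while-loop body (note the new //xla/tools:hlo_extractor dependency) rather than from text, so the names and the parsing path here are illustrative only.

  // Hypothetical standalone offset module, expressed as HLO text for
  // readability.
  #include <memory>

  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/hlo/parser/hlo_parser.h"

  std::unique_ptr<xla::HloModule> MakeExampleOffsetModule() {
    constexpr const char* kOffsetHlo = R"(
      HloModule offset_computation
      ENTRY main {
        iter = s32[] parameter(0)
        twice = s32[] add(iter, iter)
        c3 = s32[] constant(3)
        scaled = s32[] multiply(twice, c3)
        c16 = s32[] constant(16)
        ROOT offset = s32[] subtract(scaled, c16)  // 6*i - 16
      }
    )";
    return xla::ParseAndReturnUnverifiedModule(kOffsetHlo).value();
  }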

--
b5573b0 by Shraiysh Vaishay <[email protected]>:

Addressed comments

--
decde73 by Shraiysh Vaishay <[email protected]>:

Rebase

--
f98d9dc by Shraiysh Vaishay <[email protected]>:

Rebase

Merging this change closes #20332

COPYBARA_INTEGRATE_REVIEW=#20332 from shraiysh:ds_fusion_3 f98d9dc
PiperOrigin-RevId: 726440397
shraiysh authored and Google-ML-Automation committed Feb 13, 2025
1 parent 6b470af commit 988e55c
Showing 11 changed files with 512 additions and 31 deletions.
3 changes: 3 additions & 0 deletions xla/backends/gpu/codegen/BUILD
@@ -121,6 +121,7 @@ cc_library(
"//xla/backends/gpu/runtime:thunk",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/analysis:while_loop_analysis",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
@@ -139,6 +140,7 @@ cc_library(
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:custom_kernel_fusion",
"//xla/stream_executor:stream",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@@ -174,6 +176,7 @@ xla_test(
"//xla/backends/gpu/runtime:dynamic_slice_thunk",
"//xla/backends/gpu/runtime:sequential_thunk",
"//xla/backends/gpu/runtime:thunk",
"//xla/backends/gpu/runtime:while_thunk",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/hlo/builder:xla_builder",
239 changes: 219 additions & 20 deletions xla/backends/gpu/codegen/custom.cc

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions xla/backends/gpu/codegen/custom.h
@@ -62,15 +62,17 @@ class CustomFusion : public FusionInterface {
// nested tuple.
class DynamicSliceFusion : public FusionInterface {
public:
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis)
: analysis_(analysis) {}
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis,
const CallGraph& call_graph)
: analysis_(analysis), call_graph_(call_graph) {}

absl::StatusOr<FusionEmissionResult> Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const final;

private:
const HloFusionAnalysis& analysis_;
const CallGraph& call_graph_;
};

} // namespace gpu
262 changes: 262 additions & 0 deletions xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/error_spec.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
@@ -3370,6 +3371,267 @@ TEST_F(DynamicSliceFusionTest,
VariantWith<int64_t>(0)))));
}

TEST_F(DynamicSliceFusionTest,
OffsetAsFunctionOfInductionVariableShouldUseOffsetModules) {
const char* hlo_fused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
dynamic-slice-fusion {
p1 = s32[32,32] parameter(1)
p0 = s32[32,32] parameter(0)
rs = s32[16,32] reduce-scatter(p0), replica_groups={{0,1}}, dimensions={0}, to_apply=add
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT dus = s32[32,32] dynamic-update-slice(p1, rs, p2, p3)
}
body {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c1 = s32[] constant(1)
add = s32[] add(iter, c1)
src = s32[32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2
// Offset calculation as a function of the induction variable.
add.1 = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply = s32[] multiply(add.1, c3)
c16 = s32[] constant(16)
offset = s32[] subtract(multiply, c16)
c0 = s32[] constant(0)
address_computation = s32[32,32] fusion(src, dest, offset, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
ROOT tuple = (s32[], s32[32,32], s32[32,32]) tuple(add, src, address_computation)
}
condition {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
c0 = s32[] constant(0)
src = s32[32,32] parameter(0)
dest = s32[32,32] parameter(1)
tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), condition=condition, body=body
})";
const char* hlo_unfused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
body {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
src = s32[32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2
// Offset calculation as a function of the induction variable.
add = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply = s32[] multiply(add, c3)
c16 = s32[] constant(16)
offset = s32[] subtract(multiply, c16)
c0 = s32[] constant(0)
rs_start = ((s32[32,32]), s32[16,32]) reduce-scatter-start(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add
rs = s32[16,32] reduce-scatter-done(rs_start)
dus = s32[32,32] dynamic-update-slice(dest, rs, offset, c0)
c1 = s32[] constant(1)
add.1 = s32[] add(iter, c1)
ROOT tuple = tuple(add.1, src, dus)
}
condition {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
src = s32[32,32] parameter(0)
dest = s32[32,32] parameter(1)
c0 = s32[] constant(0)
tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), body=body, condition=condition
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> fused_module,
ParseAndReturnVerifiedModule(hlo_fused));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> wrapped_exec,
CreateExecutable(fused_module->Clone(), false));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
std::move(wrapped_exec)));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
ASSERT_NE(gpu_exec, nullptr);

auto while_thunk = absl::c_find_if(gpu_exec->GetThunk().thunks(),
[](const std::unique_ptr<Thunk>& thunk) {
return thunk->kind() == Thunk::kWhile;
});
ASSERT_NE(while_thunk, gpu_exec->GetThunk().thunks().end());
WhileThunk* while_thunk_ptr = dynamic_cast<WhileThunk*>(while_thunk->get());

auto ds_thunk =
absl::c_find_if(while_thunk_ptr->body_thunk_sequence()->thunks(),
[](const std::unique_ptr<Thunk>& thunk) {
return thunk->kind() == Thunk::kDynamicSlice;
});
ASSERT_NE(ds_thunk, while_thunk_ptr->body_thunk_sequence()->thunks().end());
DynamicSliceThunk* ds_thunk_ptr =
dynamic_cast<DynamicSliceThunk*>(ds_thunk->get());
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets =
ds_thunk_ptr->get_offsets();

// Expect two offsets: one for the input, and one for the outputs.
ASSERT_EQ(offsets.size(), 2);
ASSERT_TRUE(offsets[1].has_value());
std::vector<DynamicSliceThunk::Offset> output_offsets = *offsets[1];
ASSERT_EQ(output_offsets.size(), 2);

// The first offset must be an HloModule.
HloModule** offset_0 = std::get_if<HloModule*>(&output_offsets[0]);
ASSERT_NE(offset_0, nullptr);
ASSERT_NE(*offset_0, nullptr);

// The second offset must be a constant value
ASSERT_EQ(output_offsets[1], DynamicSliceThunk::Offset(0l));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unfused_module,
ParseAndReturnVerifiedModule(hlo_unfused));

EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(fused_module), std::move(unfused_module),
/*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt));
}

TEST_F(DynamicSliceFusionTest, MultipleOffsetsAsFunctionOfInductionVariable) {
const char* hlo_fused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
dynamic-slice-fusion {
p0 = s32[16,32,32] parameter(0)
p1 = s32[32,32] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
p4 = s32[] parameter(4)
ds = s32[1,32,32] dynamic-slice(p0, p2, p4, p4), dynamic_slice_sizes={1,32,32}
bitcast = s32[32,32] bitcast(ds)
rs = s32[16,32] reduce-scatter(bitcast), replica_groups={{0,1}}, dimensions={0}, to_apply=add
ROOT dus = s32[32,32] dynamic-update-slice(p1, rs, p3, p4)
}
body {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c1 = s32[] constant(1)
add = s32[] add(iter, c1)
src = s32[16,32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2
// Offset calculation as a function of the induction variable.
// offset.1 = 5i-32
c5 = s32[] constant(5)
c32 = s32[] constant(32)
multiply.1 = s32[] multiply(c5, iter)
offset.1 = s32[] subtract(multiply.1, c32)
// offset.2 = 6i-16
add.1 = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply.2 = s32[] multiply(add.1, c3)
c16 = s32[] constant(16)
offset.2 = s32[] subtract(multiply.2, c16)
c0 = s32[] constant(0)
address_computation = s32[32,32] fusion(src, dest, offset.1, offset.2, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
ROOT tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(add, src, address_computation)
}
condition {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
c0 = s32[] constant(0)
src = s32[16,32,32] parameter(0)
dest = s32[32,32] parameter(1)
tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[16,32,32], s32[32,32]) while(tuple), condition=condition, body=body
})";
const char* hlo_unfused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
body {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
src = s32[16,32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2
// Offset calculation as a function of the induction variable.
// offset.1 = 5i-32
c5 = s32[] constant(5)
c32 = s32[] constant(32)
multiply.1 = s32[] multiply(c5, iter)
offset.1 = s32[] subtract(multiply.1, c32)
// offset.2 = 6i-16
add = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply.2 = s32[] multiply(add, c3)
c16 = s32[] constant(16)
offset.2 = s32[] subtract(multiply.2, c16)
c0 = s32[] constant(0)
ds = s32[1,32,32] dynamic-slice(src, offset.1, c0, c0), dynamic_slice_sizes={1,32,32}
reshape = s32[32,32] reshape(ds)
rs_start = ((s32[32,32]), s32[16,32]) reduce-scatter-start(reshape), dimensions={0}, replica_groups={{0,1}}, to_apply=add
rs = s32[16,32] reduce-scatter-done(rs_start)
dus = s32[32,32] dynamic-update-slice(dest, rs, offset.2, c0)
c1 = s32[] constant(1)
add.1 = s32[] add(iter, c1)
ROOT tuple = tuple(add.1, src, dus)
}
condition {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
src = s32[16,32,32] parameter(0)
dest = s32[32,32] parameter(1)
c0 = s32[] constant(0)
tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[16,32,32], s32[32,32]) while(tuple), body=body, condition=condition
}
)";

EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
/*module_0_str=*/hlo_unfused, /*module_1_str=*/hlo_fused,
/*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt));
}

} // namespace
} // namespace gpu
} // namespace xla
5 changes: 4 additions & 1 deletion xla/backends/gpu/codegen/fusions.cc
@@ -95,7 +95,10 @@ std::unique_ptr<FusionInterface> GetFusionEmitter(
kDynamicSliceFusionWithStaticAddressComputationConfigName ||
config_name ==
kDynamicSliceFusionWithDynamicAddressComputationConfigName) {
return std::make_unique<DynamicSliceFusion>(analysis);
const HloFusionInfo* hlo_fusion_info =
dynamic_cast<const HloFusionInfo*>(&fusion_info);
return std::make_unique<DynamicSliceFusion>(
analysis, hlo_fusion_info->GetCallGraph());
}
return std::make_unique<CustomFusion>();
}
8 changes: 6 additions & 2 deletions xla/backends/gpu/codegen/fusions.h
@@ -54,18 +54,22 @@ class HloFusionInfo : public FusionInfo {
public:
HloFusionInfo(const HloFusionAnalysis& analysis,
const HloFusionInstruction* instr,
const BufferAssignment* buffer_assignment)
const BufferAssignment* buffer_assignment,
const CallGraph& call_graph)
: FusionInfo(analysis),
instr_(instr),
buffer_assignment_(buffer_assignment) {}
buffer_assignment_(buffer_assignment),
call_graph_(call_graph) {}

bool CanEmitDynamicUpdateSliceInPlace() const override;
std::optional<std::unique_ptr<FusionInterface>> GetCopyFusion()
const override;
const CallGraph& GetCallGraph() const { return call_graph_; };

private:
const HloFusionInstruction* instr_;
const BufferAssignment* buffer_assignment_;
const CallGraph& call_graph_;
};

class PreBufferAssignmentFusionInfo : public FusionInfo {
1 change: 0 additions & 1 deletion xla/backends/gpu/runtime/BUILD
@@ -194,7 +194,6 @@ cc_library(
deps = [
":sequential_thunk",
":thunk",
":while_thunk",
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
1 change: 0 additions & 1 deletion xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/buffer_allocations.h"
10 changes: 7 additions & 3 deletions xla/service/gpu/ir_emitter_unnested.cc
@@ -186,7 +186,8 @@ namespace gpu {
IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context)
: IrEmitter(ir_emitter_context, /*is_nested=*/false),
send_recv_events_(std::make_shared<SendRecvAsyncEvents>()),
copy_events_(std::make_shared<CopyThunk::AsyncEvents>()) {}
copy_events_(std::make_shared<CopyThunk::AsyncEvents>()),
call_graph_(CallGraph::Build(&ir_emitter_context->hlo_module())) {}

std::unique_ptr<IrEmitterUnnested> IrEmitterUnnested::Create(
IrEmitterContext* ir_emitter_context) {
@@ -1543,8 +1544,11 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) {
const HloFusionAnalysis fusion_analysis =
HloFusionAnalysis::Create(*instr, device_info);
VLOG(3) << "IrEmitterUnnested::EmitFusion:start";
std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(HloFusionInfo(
fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()));
std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(
/*fusion_info=*/HloFusionInfo(
/*analysis=*/fusion_analysis, instr,
/*buffer_assignment=*/&ir_emitter_context_->buffer_assignment(),
/*call_graph=*/*call_graph_));
TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));

const ExecutionStreamAssignment& stream_assignment =