Skip to content

Commit

Permalink
Add runtime support for host calculation of offsets in ds fusion
Browse files Browse the repository at this point in the history
This patch adds the support for calculating offset on the host at
runtime when the offset depends on the loop induction variable. This is
done by extracting the offset computation, the induction variable
initialization and the induction variable update as independent
computations, which are evaluated on the host at runtime. This avoids a
device-to-host copy for this fusion in these cases.
  • Loading branch information
shraiysh committed Dec 9, 2024
1 parent a041e1b commit 8c43e80
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 41 deletions.
6 changes: 4 additions & 2 deletions xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ cc_library(
"//xla/backends/gpu/collectives:gpu_clique_key",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/analysis:while_loop_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:hlo_proto_cc",
"//xla/service:pattern_matcher",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:hlo_fusion_analysis",
Expand All @@ -97,6 +97,7 @@ cc_library(
"//xla/service/gpu/runtime:nccl_collective_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/stream_executor:stream",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -135,6 +136,7 @@ xla_test(
"//xla/hlo/builder:xla_computation",
"//xla/hlo/builder/lib:constants",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/service:custom_call_target_registry",
"//xla/service:executable",
"//xla/service:hlo_module_config",
Expand All @@ -144,12 +146,12 @@ xla_test(
"//xla/service/gpu/runtime:dynamic_slice_thunk",
"//xla/service/gpu/runtime:sequential_thunk",
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/runtime:while_thunk",
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:gpu_types_header",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
"@com_google_absl//absl/algorithm:container",
Expand Down
183 changes: 161 additions & 22 deletions xla/service/gpu/fusions/custom.cc

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions xla/service/gpu/fusions/custom.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ class CustomFusion : public FusionInterface {
// nested tuple.
class DynamicSliceFusion : public FusionInterface {
public:
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis)
: analysis_(analysis) {}
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis,
const CallGraph& call_graph)
: analysis_(analysis), call_graph_(call_graph) {}

absl::StatusOr<FusionEmissionResult> Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const final;

private:
const HloFusionAnalysis& analysis_;
const CallGraph& call_graph_;
};

} // namespace gpu
Expand Down
147 changes: 146 additions & 1 deletion xla/service/gpu/fusions/dynamic_slice_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ limitations under the License.
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/service/gpu/runtime/while_thunk.h"
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_module_config.h"
Expand All @@ -42,7 +44,6 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/stream.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -3190,6 +3191,150 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDynamicSlice) {
false, true, error));
}

// Verifies that when a dynamic-slice fusion's offset depends on the while-loop
// induction variable, the emitted DynamicSliceThunk carries an extracted
// HloModule to evaluate that offset on the host at runtime (avoiding a
// device-to-host copy), while constant offsets remain plain values. Also
// checks that the fused module matches the unfused reference numerically.
TEST_F(DynamicSliceFusionTest,
       OffsetAsFunctionOfInductionVariableShouldUseOffsetModules) {
  const char* hlo_fused = R"(
HloModule test, replica_count=2
%add (a: s32[], b: s32[]) -> s32[] {
%a = s32[] parameter(0)
%b = s32[] parameter(1)
ROOT %add = s32[] add(%a, %b)
}
%dynamic-slice-fusion (p0: s32[32,32], p1: s32[32,32], p2: s32[], p3: s32[]) -> s32[32,32] {
%p1 = s32[32,32]{1,0} parameter(1)
%p0 = s32[32,32]{1,0} parameter(0)
%rs.1 = s32[16,32]{1,0} reduce-scatter(%p0), replica_groups={{0,1}}, dimensions={0}, to_apply=%add
%p2 = s32[] parameter(2)
%p3 = s32[] parameter(3)
ROOT %dus.1 = s32[32,32]{1,0} dynamic-update-slice(%p1, %rs.1, %p2, %p3)
}
%body (param.1: (s32[], s32[32,32], s32[32,32])) -> (s32[], s32[32,32], s32[32,32]) {
%param.1 = (s32[], s32[32,32]{1,0}, s32[32,32]{1,0}) parameter(0)
%iter.1 = s32[] get-tuple-element(%param.1), index=0
%c1 = s32[] constant(1)
%add.2 = s32[] add(%iter.1, %c1)
%src = s32[32,32]{1,0} get-tuple-element(%param.1), index=1
%dest = s32[32,32]{1,0} get-tuple-element(%param.1), index=2
// Offset calculation as a function of the induction variable.
%add.1 = s32[] add(%iter.1, %iter.1)
%c3 = s32[] constant(3)
%multiply.1 = s32[] multiply(%add.1, %c3)
%c16 = s32[] constant(16)
%offset.1 = s32[] subtract(%multiply.1, %c16)
%c0 = s32[] constant(0)
%address_computation = s32[32,32]{1,0} fusion(%src, %dest, %offset.1, %c0), kind=kCustom, calls=%dynamic-slice-fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation","kernel_index":0}},"force_earliest_schedule":false}
ROOT %tuple = (s32[], s32[32,32]{1,0}, s32[32,32]{1,0}) tuple(%add.2, %src, %address_computation)
}
%condition (param.2: (s32[], s32[32,32], s32[32,32])) -> pred[] {
%param.2 = (s32[], s32[32,32]{1,0}, s32[32,32]{1,0}) parameter(0)
%iter.2 = s32[] get-tuple-element(%param.2), index=0
%c16.1 = s32[] constant(16)
ROOT %compare = pred[] compare(%iter.2, %c16.1), direction=LT
}
ENTRY %main (src.1: s32[32,32], dest.1: s32[32,32]) -> (s32[], s32[32,32], s32[32,32]) {
%c0.1 = s32[] constant(0)
%src.1 = s32[32,32]{1,0} parameter(0)
%dest.1 = s32[32,32]{1,0} parameter(1)
%tuple.1 = (s32[], s32[32,32]{1,0}, s32[32,32]{1,0}) tuple(%c0.1, %src.1, %dest.1)
ROOT %while = (s32[], s32[32,32]{1,0}, s32[32,32]{1,0}) while(%tuple.1), condition=%condition, body=%body
})";

  const char* hlo_unfused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
body {
param.1 = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter.1 = s32[] get-tuple-element(param.1), index=0
src = s32[32,32] get-tuple-element(param.1), index=1
dest = s32[32,32] get-tuple-element(param.1), index=2
// Offset calculation as a function of the induction variable.
add.1 = s32[] add(iter.1, iter.1)
c3 = s32[] constant(3)
multiply.1 = s32[] multiply(add.1, c3)
c16 = s32[] constant(16)
offset.1 = s32[] subtract(multiply.1, c16)
c0 = s32[] constant(0)
rs = s32[16,32] reduce-scatter(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add
dus = s32[32,32] dynamic-update-slice(dest, rs, offset.1, c0)
c1 = s32[] constant(1)
add.2 = s32[] add(iter.1, c1)
ROOT tuple = tuple(add.2, src, dus)
}
condition {
param.2 = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter.2 = s32[] get-tuple-element(param.2), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter.2, c16), direction=LT
}
ENTRY main {
src = s32[32,32] parameter(0)
dest = s32[32,32] parameter(1)
c0 = s32[] constant(0)
tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), body=body, condition=condition
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> fused_module,
                          ParseAndReturnVerifiedModule(hlo_fused));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
                          CreateExecutable(fused_module->Clone(), false));
  GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
  ASSERT_NE(gpu_exec, nullptr);

  // Locate the while thunk in the top-level thunk sequence.
  auto while_thunk = absl::c_find_if(gpu_exec->GetThunk().thunks(),
                                     [](const std::unique_ptr<Thunk>& thunk) {
                                       return thunk->kind() == Thunk::kWhile;
                                     });
  ASSERT_NE(while_thunk, gpu_exec->GetThunk().thunks().end());
  WhileThunk* while_thunk_ptr = dynamic_cast<WhileThunk*>(while_thunk->get());
  // Guard the downcast before dereferencing.
  ASSERT_NE(while_thunk_ptr, nullptr);

  // Locate the dynamic-slice thunk inside the while-loop body.
  auto ds_thunk =
      absl::c_find_if(while_thunk_ptr->body_thunk_sequence()->thunks(),
                      [](const std::unique_ptr<Thunk>& thunk) {
                        return thunk->kind() == Thunk::kDynamicSlice;
                      });
  ASSERT_NE(ds_thunk, while_thunk_ptr->body_thunk_sequence()->thunks().end());
  DynamicSliceThunk* ds_thunk_ptr =
      dynamic_cast<DynamicSliceThunk*>(ds_thunk->get());
  // Guard the downcast before dereferencing.
  ASSERT_NE(ds_thunk_ptr, nullptr);
  std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets =
      ds_thunk_ptr->get_offsets();
  // At least two offsets: one for the input, and one for the outputs.
  ASSERT_GE(offsets.size(), 2);
  ASSERT_TRUE(offsets[1].has_value());
  std::vector<DynamicSliceThunk::Offset> output_offsets = *offsets[1];
  ASSERT_EQ(output_offsets.size(), 2);

  // The first value of offset must be an HloModule (the extracted host-side
  // offset computation).
  HloModule** offset1 = std::get_if<HloModule*>(&output_offsets[0]);
  ASSERT_NE(offset1, nullptr);
  ASSERT_NE(*offset1, nullptr);

  // The second offset must be a constant value
  ASSERT_EQ(output_offsets[1], DynamicSliceThunk::Offset(0ul));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unfused_module,
                          ParseAndReturnVerifiedModule(hlo_unfused));
  ErrorSpec error{1e-5, 1e-5};
  // Check the comparison result: without the wrapping macro a fused/unfused
  // mismatch would be silently ignored and the test could never fail on it.
  EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
      std::move(fused_module), std::move(unfused_module),
      /*run_hlo_passes=*/false, /*use_threads=*/true, error));
}

} // namespace
} // namespace gpu
} // namespace xla
5 changes: 4 additions & 1 deletion xla/service/gpu/fusions/fusions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ std::unique_ptr<FusionInterface> GetFusionEmitter(
case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: {
const auto& config = backend_config.custom_fusion_config();
if (absl::StrContains(config.name(), "address_computation")) {
return std::make_unique<DynamicSliceFusion>(analysis);
const HloFusionInfo* hlo_fusion_info =
dynamic_cast<const HloFusionInfo*>(&fusion_info);
return std::make_unique<DynamicSliceFusion>(
analysis, hlo_fusion_info->GetCallGraph());
}
return std::make_unique<CustomFusion>();
}
Expand Down
8 changes: 6 additions & 2 deletions xla/service/gpu/fusions/fusions.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,22 @@ class HloFusionInfo : public FusionInfo {
public:
HloFusionInfo(const HloFusionAnalysis& analysis,
const HloFusionInstruction* instr,
const BufferAssignment* buffer_assignment)
const BufferAssignment* buffer_assignment,
CallGraph& call_graph)
: FusionInfo(analysis),
instr_(instr),
buffer_assignment_(buffer_assignment) {}
buffer_assignment_(buffer_assignment),
call_graph_(call_graph) {}

bool CanEmitDynamicUpdateSliceInPlace() const override;
std::optional<std::unique_ptr<FusionInterface>> GetCopyFusion()
const override;
const CallGraph& GetCallGraph() const { return call_graph_; };

private:
const HloFusionInstruction* instr_;
const BufferAssignment* buffer_assignment_;
const CallGraph& call_graph_;
};

class PreBufferAssignmentFusionInfo : public FusionInfo {
Expand Down
8 changes: 5 additions & 3 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ IrEmitterUnnested::IrEmitterUnnested(IrEmitterContext* ir_emitter_context)
: IrEmitter(ir_emitter_context, /*is_nested=*/false),
send_recv_events_(std::make_shared<SendRecvAsyncEvents>()),
copy_events_(std::make_shared<CopyThunk::AsyncEvents>()),
elemental_emitter_(*ir_emitter_context, &b_) {}
elemental_emitter_(*ir_emitter_context, &b_),
call_graph(CallGraph::Build(&ir_emitter_context->hlo_module())) {}

std::unique_ptr<IrEmitterUnnested> IrEmitterUnnested::Create(
IrEmitterContext* ir_emitter_context) {
Expand Down Expand Up @@ -1505,8 +1506,9 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) {
const HloFusionAnalysis fusion_analysis =
HloFusionAnalysis::Create(*instr, device_info);

std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(HloFusionInfo(
fusion_analysis, instr, &ir_emitter_context_->buffer_assignment()));
std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(
HloFusionInfo(fusion_analysis, instr,
&ir_emitter_context_->buffer_assignment(), *call_graph));
TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));

const ExecutionStreamAssignment& stream_assignment =
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/ir_emitter_unnested.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ class IrEmitterUnnested : public IrEmitter {
std::shared_ptr<CopyThunk::AsyncEvents> copy_events_;

GpuElementalIrEmitter elemental_emitter_;

std::unique_ptr<CallGraph> call_graph;
};

} // namespace gpu
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,13 @@ cc_library(
deps = [
":sequential_thunk",
":thunk",
":while_thunk",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/service:buffer_assignment",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:ir_emission_utils",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:memory_allocation",
"//xla/stream_executor:stream",
Expand Down
Loading

0 comments on commit 8c43e80

Please sign in to comment.