[ds-fusion] Add runtime support for host calculation of offsets in ds fusion #20332

Closed · wants to merge 4 commits
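
Context for the diff below: offsets that are functions of the enclosing while loop's induction variable appear to be captured as small standalone HLO modules (the new test asserts an HloModule* offset on the DynamicSliceThunk) and evaluated on the host at runtime. A minimal sketch of such host-side evaluation follows; it is an illustration, not the PR's implementation. HloEvaluator, Literal, and LiteralUtil are existing XLA APIs; EvaluateOffsetModule and the convention that the extracted module takes the induction variable as its only parameter are assumptions.

#include <cstdint>
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"

// Sketch only: evaluate an extracted offset module on the host for the
// current loop iteration. Assumes the module takes the induction variable
// as its single s32 parameter and returns a scalar integer offset.
absl::StatusOr<int64_t> EvaluateOffsetModule(
    const xla::HloModule& offset_module, int32_t induction_variable) {
  xla::HloEvaluator evaluator;
  xla::Literal iv = xla::LiteralUtil::CreateR0<int32_t>(induction_variable);
  absl::StatusOr<xla::Literal> result =
      evaluator.Evaluate(offset_module, {&iv});
  if (!result.ok()) {
    return result.status();
  }
  // The result is a scalar; read it back as a 64-bit host integer.
  std::optional<int64_t> offset = result->GetFirstInteger();
  if (!offset.has_value()) {
    return absl::InternalError("offset module did not produce an integer");
  }
  return *offset;
}
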
3 changes: 3 additions & 0 deletions xla/backends/gpu/codegen/BUILD
@@ -121,6 +121,7 @@ cc_library(
"//xla/backends/gpu/runtime:thunk",
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/analysis:while_loop_analysis",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
@@ -139,6 +140,7 @@ cc_library(
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:custom_kernel_fusion",
"//xla/stream_executor:stream",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@@ -174,6 +176,7 @@ xla_test(
"//xla/backends/gpu/runtime:dynamic_slice_thunk",
"//xla/backends/gpu/runtime:sequential_thunk",
"//xla/backends/gpu/runtime:thunk",
"//xla/backends/gpu/runtime:while_thunk",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/hlo/builder:xla_builder",
239 changes: 219 additions & 20 deletions xla/backends/gpu/codegen/custom.cc

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions xla/backends/gpu/codegen/custom.h
@@ -62,15 +62,17 @@ class CustomFusion : public FusionInterface {
// nested tuple.
class DynamicSliceFusion : public FusionInterface {
public:
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis)
: analysis_(analysis) {}
explicit DynamicSliceFusion(const HloFusionAnalysis& analysis,
const CallGraph& call_graph)
: analysis_(analysis), call_graph_(call_graph) {}

absl::StatusOr<FusionEmissionResult> Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const final;

private:
const HloFusionAnalysis& analysis_;
const CallGraph& call_graph_;
};

} // namespace gpu
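
The new CallGraph reference gives the emitter a way to look outside the fusion computation, for example to find the while instruction whose body calls the fusion and whose induction variable the offsets depend on. Below is a sketch of that lookup using the existing CallGraph::GetNode and CallGraphNode::caller_callsites APIs; the helper name FindEnclosingWhile and the surrounding control flow are illustrative and not taken from custom.cc, whose diff is not rendered here.

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/call_graph.h"

// Illustrative helper: return the while instruction whose body contains the
// given fusion, or nullptr if the fusion is not called from a while body.
const xla::HloInstruction* FindEnclosingWhile(
    const xla::CallGraph& call_graph, const xla::HloInstruction& fusion) {
  const xla::CallGraphNode& node = call_graph.GetNode(fusion.parent());
  for (const xla::CallSite& call_site : node.caller_callsites()) {
    if (call_site.instruction()->opcode() == xla::HloOpcode::kWhile) {
      // fusion.parent() is a while-loop body, so offsets computed inside it
      // may be functions of the loop's induction variable.
      return call_site.instruction();
    }
  }
  return nullptr;
}
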
262 changes: 262 additions & 0 deletions xla/backends/gpu/codegen/dynamic_slice_fusion_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/error_spec.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
@@ -3279,6 +3280,267 @@ TEST_F(DynamicSliceFusionTest,
/*run_hlo_passes=*/false, /*use_threads=*/true, error));
}

TEST_F(DynamicSliceFusionTest,
OffsetAsFunctionOfInductionVariableShouldUseOffsetModules) {
const char* hlo_fused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
dynamic-slice-fusion {
p1 = s32[32,32] parameter(1)
p0 = s32[32,32] parameter(0)
rs = s32[16,32] reduce-scatter(p0), replica_groups={{0,1}}, dimensions={0}, to_apply=add
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT dus = s32[32,32] dynamic-update-slice(p1, rs, p2, p3)
}
body {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c1 = s32[] constant(1)
add = s32[] add(iter, c1)
src = s32[32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2

// Offset calculation as a function of the induction variable.
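// offset = (iter + iter) * 3 - 16 = 6i - 16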
add.1 = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply = s32[] multiply(add.1, c3)
c16 = s32[] constant(16)
offset = s32[] subtract(multiply, c16)

c0 = s32[] constant(0)
address_computation = s32[32,32] fusion(src, dest, offset, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
ROOT tuple = (s32[], s32[32,32], s32[32,32]) tuple(add, src, address_computation)
}
condition {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
c0 = s32[] constant(0)
src = s32[32,32] parameter(0)
dest = s32[32,32] parameter(1)
tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), condition=condition, body=body
})";
const char* hlo_unfused = R"(
HloModule test, replica_count=2

add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}

body {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
src = s32[32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2

// Offset calculation as a function of the induction variable.
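// offset = (iter + iter) * 3 - 16 = 6i - 16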
add = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply = s32[] multiply(add, c3)
c16 = s32[] constant(16)
offset = s32[] subtract(multiply, c16)

c0 = s32[] constant(0)
rs_start = ((s32[32,32]), s32[16,32]) reduce-scatter-start(src), dimensions={0}, replica_groups={{0,1}}, to_apply=add
rs = s32[16,32] reduce-scatter-done(rs_start)
dus = s32[32,32] dynamic-update-slice(dest, rs, offset, c0)
c1 = s32[] constant(1)
add.1 = s32[] add(iter, c1)
ROOT tuple = tuple(add.1, src, dus)
}

condition {
param = (s32[], s32[32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}

ENTRY main {
src = s32[32,32] parameter(0)
dest = s32[32,32] parameter(1)
c0 = s32[] constant(0)
tuple = (s32[], s32[32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[32,32], s32[32,32]) while(tuple), body=body, condition=condition
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> fused_module,
ParseAndReturnVerifiedModule(hlo_fused));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<OpaqueExecutable> wrapped_exec,
CreateExecutable(fused_module->Clone(), false));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> exec,
test_runner_as_hlo_runner().ExecutableFromWrapped(
std::move(wrapped_exec)));
GpuExecutable* gpu_exec = dynamic_cast<GpuExecutable*>(exec.get());
ASSERT_NE(gpu_exec, nullptr);

auto while_thunk = absl::c_find_if(gpu_exec->GetThunk().thunks(),
[](const std::unique_ptr<Thunk>& thunk) {
return thunk->kind() == Thunk::kWhile;
});
ASSERT_NE(while_thunk, gpu_exec->GetThunk().thunks().end());
WhileThunk* while_thunk_ptr = dynamic_cast<WhileThunk*>(while_thunk->get());

auto ds_thunk =
absl::c_find_if(while_thunk_ptr->body_thunk_sequence()->thunks(),
[](const std::unique_ptr<Thunk>& thunk) {
return thunk->kind() == Thunk::kDynamicSlice;
});
ASSERT_NE(ds_thunk, while_thunk_ptr->body_thunk_sequence()->thunks().end());
DynamicSliceThunk* ds_thunk_ptr =
dynamic_cast<DynamicSliceThunk*>(ds_thunk->get());
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets =
ds_thunk_ptr->get_offsets();

// Expect offsets for two arguments: one entry for the input and one for the output.
ASSERT_EQ(offsets.size(), 2);
ASSERT_TRUE(offsets[1].has_value());
std::vector<DynamicSliceThunk::Offset> output_offsets = *offsets[1];
ASSERT_EQ(output_offsets.size(), 2);

// The first output offset must be an HloModule, to be evaluated on the host.
HloModule** offset_0 = std::get_if<HloModule*>(&output_offsets[0]);
ASSERT_NE(offset_0, nullptr);
ASSERT_NE(*offset_0, nullptr);

// The second output offset must be the constant 0.
ASSERT_EQ(output_offsets[1], DynamicSliceThunk::Offset(0l));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> unfused_module,
ParseAndReturnVerifiedModule(hlo_unfused));

EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
std::move(fused_module), std::move(unfused_module),
/*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt));
}

TEST_F(DynamicSliceFusionTest, MultipleOffsetsAsFunctionOfInductionVariable) {
const char* hlo_fused = R"(
HloModule test, replica_count=2
add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}
dynamic-slice-fusion {
p0 = s32[16,32,32] parameter(0)
p1 = s32[32,32] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
p4 = s32[] parameter(4)
ds = s32[1,32,32] dynamic-slice(p0, p2, p4, p4), dynamic_slice_sizes={1,32,32}
bitcast = s32[32,32] bitcast(ds)
rs = s32[16,32] reduce-scatter(bitcast), replica_groups={{0,1}}, dimensions={0}, to_apply=add
ROOT dus = s32[32,32] dynamic-update-slice(p1, rs, p3, p4)
}
body {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c1 = s32[] constant(1)
add = s32[] add(iter, c1)
src = s32[16,32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2

// Offset calculation as a function of the induction variable.
// offset.1 = 5i-32
c5 = s32[] constant(5)
c32 = s32[] constant(32)
multiply.1 = s32[] multiply(c5, iter)
offset.1 = s32[] subtract(multiply.1, c32)
// offset.2 = 6i-16
add.1 = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply.2 = s32[] multiply(add.1, c3)
c16 = s32[] constant(16)
offset.2 = s32[] subtract(multiply.2, c16)

c0 = s32[] constant(0)
address_computation = s32[32,32] fusion(src, dest, offset.1, offset.2, c0), kind=kCustom, calls=dynamic-slice-fusion, backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}}
ROOT tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(add, src, address_computation)
}
condition {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}
ENTRY main {
c0 = s32[] constant(0)
src = s32[16,32,32] parameter(0)
dest = s32[32,32] parameter(1)
tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[16,32,32], s32[32,32]) while(tuple), condition=condition, body=body
})";
const char* hlo_unfused = R"(
HloModule test, replica_count=2

add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}

body {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
src = s32[16,32,32] get-tuple-element(param), index=1
dest = s32[32,32] get-tuple-element(param), index=2

// Offset calculation as a function of the induction variable.
// offset.1 = 5i-32
c5 = s32[] constant(5)
c32 = s32[] constant(32)
multiply.1 = s32[] multiply(c5, iter)
offset.1 = s32[] subtract(multiply.1, c32)
// offset.2 = 6i-16
add = s32[] add(iter, iter)
c3 = s32[] constant(3)
multiply.2 = s32[] multiply(add, c3)
c16 = s32[] constant(16)
offset.2 = s32[] subtract(multiply.2, c16)

c0 = s32[] constant(0)
ds = s32[1,32,32] dynamic-slice(src, offset.1, c0, c0), dynamic_slice_sizes={1,32,32}
reshape = s32[32,32] reshape(ds)
rs_start = ((s32[32,32]), s32[16,32]) reduce-scatter-start(reshape), dimensions={0}, replica_groups={{0,1}}, to_apply=add
rs = s32[16,32] reduce-scatter-done(rs_start)
dus = s32[32,32] dynamic-update-slice(dest, rs, offset.2, c0)
c1 = s32[] constant(1)
add.1 = s32[] add(iter, c1)
ROOT tuple = tuple(add.1, src, dus)
}

condition {
param = (s32[], s32[16,32,32], s32[32,32]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c16 = s32[] constant(16)
ROOT compare = pred[] compare(iter, c16), direction=LT
}

ENTRY main {
src = s32[16,32,32] parameter(0)
dest = s32[32,32] parameter(1)
c0 = s32[] constant(0)
tuple = (s32[], s32[16,32,32], s32[32,32]) tuple(c0, src, dest)
ROOT while = (s32[], s32[16,32,32], s32[32,32]) while(tuple), body=body, condition=condition
}
)";

EXPECT_TRUE(RunAndCompareTwoModulesReplicated(
/*module_0_str=*/hlo_unfused, /*module_1_str=*/hlo_fused,
/*run_hlo_passes=*/false, /*use_threads=*/true, std::nullopt));
}

} // namespace
} // namespace gpu
} // namespace xla
5 changes: 4 additions & 1 deletion xla/backends/gpu/codegen/fusions.cc
@@ -91,7 +91,10 @@ std::unique_ptr<FusionInterface> GetFusionEmitter(
case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: {
const auto& config = backend_config.custom_fusion_config();
if (absl::StrContains(config.name(), "address_computation")) {
return std::make_unique<DynamicSliceFusion>(analysis);
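// Note: this path is assumed to run only after buffer assignment, so
// `fusion_info` is expected to be an HloFusionInfo and the cast below to
// yield a non-null pointer.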
const HloFusionInfo* hlo_fusion_info =
dynamic_cast<const HloFusionInfo*>(&fusion_info);
return std::make_unique<DynamicSliceFusion>(
analysis, hlo_fusion_info->GetCallGraph());
}
return std::make_unique<CustomFusion>();
}
8 changes: 6 additions & 2 deletions xla/backends/gpu/codegen/fusions.h
@@ -54,18 +54,22 @@ class HloFusionInfo : public FusionInfo {
public:
HloFusionInfo(const HloFusionAnalysis& analysis,
const HloFusionInstruction* instr,
const BufferAssignment* buffer_assignment)
const BufferAssignment* buffer_assignment,
const CallGraph& call_graph)
: FusionInfo(analysis),
instr_(instr),
buffer_assignment_(buffer_assignment) {}
buffer_assignment_(buffer_assignment),
call_graph_(call_graph) {}

bool CanEmitDynamicUpdateSliceInPlace() const override;
std::optional<std::unique_ptr<FusionInterface>> GetCopyFusion()
const override;
const CallGraph& GetCallGraph() const { return call_graph_; };

private:
const HloFusionInstruction* instr_;
const BufferAssignment* buffer_assignment_;
const CallGraph& call_graph_;
};

class PreBufferAssignmentFusionInfo : public FusionInfo {
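
For call sites, HloFusionInfo now stores only a reference to the CallGraph, so the caller has to build the graph (typically via xla::CallGraph::Build) and keep it alive at least as long as the fusion info. A hypothetical sketch of the updated call-site shape, assuming GetFusionEmitter takes a const FusionInfo& as declared in fusions.h; the wrapper name MakeDynamicSliceEmitter is an assumption:

#include <memory>

#include "xla/backends/gpu/codegen/fusions.h"

// Hypothetical call site: the CallGraph must outlive the HloFusionInfo,
// which keeps only a reference to it. Parameter types come from the
// HloFusionInfo constructor above; fusions.h declares them.
std::unique_ptr<xla::gpu::FusionInterface> MakeDynamicSliceEmitter(
    const xla::gpu::HloFusionAnalysis& analysis,
    const xla::HloFusionInstruction* fusion_instr,
    const xla::BufferAssignment* buffer_assignment,
    const xla::CallGraph& call_graph) {
  xla::gpu::HloFusionInfo fusion_info(analysis, fusion_instr,
                                      buffer_assignment, call_graph);
  return xla::gpu::GetFusionEmitter(fusion_info);
}
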
3 changes: 1 addition & 2 deletions xla/backends/gpu/runtime/BUILD
@@ -194,14 +194,13 @@ cc_library(
deps = [
":sequential_thunk",
":thunk",
":while_thunk",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla/hlo/evaluator:hlo_evaluator",
"//xla/service:buffer_assignment",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:ir_emission_utils",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:memory_allocation",
"//xla/stream_executor:stream",
1 change: 0 additions & 1 deletion xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/buffer_allocations.h"