[XLA:GPU] Remove the restrictions in Triton tiling propagation that prevent us from fusing the subchannel dequantisation sequence.

There are two cases for the broadcast:
1) Channel quantisation case:
     a) We have 2d weights + 1d scales + 2d activations.
     b) In the Triton prologue we prepare the corresponding block pointers.
     c) Inside the for loop along the k dim we load the same 1d tile of scales on every iteration, expand it to a 2d [block_m, 1] tile, and broadcast it along the newly added dim to [block_m, block_k].
     d) Then we do the multiply and the dot.

2) Subchannel quantisation case:
     a) We have 2d weights [M, K] + 2d scales [M, K/q] + 2d activations, where q is the subchannel size.
     b) In the Triton prologue we prepare the corresponding block pointers.
     c) Inside the for loop along the k dim we load a 2d [block_m, 1] tile of scales on every iteration and broadcast it along the k dim to [block_m, block_k].
     d) Then we do the rest.

I.e. the difference is that the scales matrix is 2d from the very beginning, but it is smaller along the k dim, so we need to advance it along the k dim by only one column per iteration instead of block_k columns. Since it is already 2d, we don't need to add a dimension (see the sketch below).
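For illustration, here is a minimal Triton-language sketch of the subchannel scheme. It is hypothetical: the real emitter builds Triton IR from C++, and the kernel name, layouts, and block sizes below are only assumptions (row-major [M, K] weights, [M, K/q] scales, [K, N] activations, block_k == subchannel_size):

import triton
import triton.language as tl

@triton.jit
def subchannel_dequant_matmul(w_ptr, s_ptr, x_ptr, out_ptr, M, N, K,
                              SUBCHANNEL: tl.constexpr, BLOCK_M: tl.constexpr,
                              BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Prologue: block pointers for weights [M, K], activations [K, N] and the
    # already-2d scales [M, K // SUBCHANNEL].
    w_bp = tl.make_block_ptr(w_ptr, (M, K), (K, 1), (pid_m * BLOCK_M, 0),
                             (BLOCK_M, BLOCK_K), (1, 0))
    x_bp = tl.make_block_ptr(x_ptr, (K, N), (N, 1), (0, pid_n * BLOCK_N),
                             (BLOCK_K, BLOCK_N), (1, 0))
    s_bp = tl.make_block_ptr(s_ptr, (M, K // SUBCHANNEL), (K // SUBCHANNEL, 1),
                             (pid_m * BLOCK_M, 0), (BLOCK_M, 1), (1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):  # assumes BLOCK_K == SUBCHANNEL
        w = tl.load(w_bp)
        x = tl.load(x_bp)
        # The scales tile is already [BLOCK_M, 1]: no expand_dims is needed, only
        # the broadcast along k. (In the channel case the tile is a 1d [BLOCK_M]
        # vector, so it is expanded to [BLOCK_M, 1] first and never advanced.)
        s = tl.broadcast_to(tl.load(s_bp), (BLOCK_M, BLOCK_K))
        acc += tl.dot(w.to(tl.float32) * s, x.to(tl.float32))
        w_bp = tl.advance(w_bp, (0, BLOCK_K))
        x_bp = tl.advance(x_bp, (BLOCK_K, 0))
        s_bp = tl.advance(s_bp, (0, 1))  # advance scales by one column, not BLOCK_K
    out_bp = tl.make_block_ptr(out_ptr, (M, N), (N, 1),
                               (pid_m * BLOCK_M, pid_n * BLOCK_N),
                               (BLOCK_M, BLOCK_N), (1, 0))
    tl.store(out_bp, acc)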

We can emit the right code if we know that there was a subchannel broadcast and we know the size of the subchannel. We do this analysis in triton_tiling_propagation by detecting a broadcast with a follow-up bitcast, i.e. the combination [B,c,M] -> [B,c,q,M] -> [B,c*q,M]: we do the broadcast, and the follow-up bitcast merges the broadcasted dim with another non-trivial dim.
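As an illustration (a hypothetical Python shape check, not the actual triton_tiling_propagation code), detecting this broadcast-plus-bitcast combination and recovering the subchannel size q could look like:

def subchannel_size(scales_shape, broadcast_shape, bitcast_shape):
    """Returns the subchannel size q for [B,c,M] -> [B,c,q,M] -> [B,c*q,M], else None."""
    if len(scales_shape) != 3 or len(broadcast_shape) != 4 or len(bitcast_shape) != 3:
        return None
    b, c, m = scales_shape
    q = broadcast_shape[2]  # the newly added (broadcasted) dimension
    if broadcast_shape != (b, c, q, m):
        return None  # not a broadcast that inserts exactly one dim
    if bitcast_shape != (b, c * q, m):
        return None  # the bitcast must merge the broadcasted dim into the c dim
    return q

# Example: 4 subchannels of size 32, i.e. a contracting dim of 4 * 32 = 128.
assert subchannel_size((1, 4, 256), (1, 4, 32, 256), (1, 128, 256)) == 32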

This schema works when block_k == subchannel_size and when split_k == 1. (With block_k == subchannel_size, every k-loop iteration consumes exactly one subchannel, so advancing the scales block pointer by one column per iteration keeps it in sync with the weights.)

These two restrictions could be addressed in follow-up CLs.

PiperOrigin-RevId: 725211140
loislo authored and Google-ML-Automation committed Feb 10, 2025
1 parent 920cc8a commit f44ba78
Showing 13 changed files with 639 additions and 85 deletions.
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/BUILD
@@ -683,6 +683,7 @@ xla_test(
"gpu_b200",
"gpu_amd_any",
],
+ shard_count = 10,
tags = [
"large",
"no_mac",
@@ -696,6 +697,7 @@ xla_test(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log",
11 changes: 8 additions & 3 deletions xla/backends/gpu/codegen/triton/fusion_emitter.cc
@@ -1101,7 +1101,7 @@ absl::StatusOr<TritonModule> CreateTritonModule(
if (type == U16) {
ir_type = b.getI16Type();
} else if (type == S4) {
- ir_type = b.getI4Type();
+ ir_type = b.getI4Type();
} else {
TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type));
}
@@ -1149,13 +1149,17 @@ absl::StatusOr<TritonModule> CreateTritonModule(
b.create<ttir::ReturnOp>();

if (DumpingEnabledForHloModule(*hlo_computation->parent())) {
+ auto suffix = absl::StrCat(fusion->name(), ".before_validation.ttir");
DumpToFileInDirOrStdout(
- *hlo_computation->parent(), "triton_ir", "before_validation.ttir",
+ *hlo_computation->parent(), "", suffix,
DumpTritonIR(triton_module.get(),
fusion->GetModule()
->config()
.debug_options()
.xla_gpu_unsupported_annotate_with_emitter_loc()));
+ std::string fusion_suffix = absl::StrCat(hlo_computation->name(), ".hlo");
+ DumpToFileInDirOrStdout(*hlo_computation->parent(), "", fusion_suffix,
+                         hlo_computation->ToString());
}

if (mlir::failed(mlir::verify(*triton_module))) {
@@ -1179,8 +1183,9 @@ absl::StatusOr<TritonModule> CreateTritonModule(
// TODO(loislo): Remove this dump once we have the Triton IR dump in
// CompileTritonToLLVM after the Triton optimization passes.
if (DumpingEnabledForHloModule(*hlo_computation->parent())) {
+ std::string suffix = absl::StrCat(fusion->name(), ".ttir");
DumpToFileInDirOrStdout(
- *hlo_computation->parent(), "triton_ir", "ttir",
+ *hlo_computation->parent(), "", suffix,
DumpTritonIR(triton_module.get(),
fusion->GetModule()
->config()
