PR #22512: [XLA:GPU] Enable cuDNN kernel for NVFP4 block scaled dot
Imported from GitHub PR #22512

Support NVFP4 in addition to MXFP8 hardware acceleration for the "__op$block_scaled_dot" custom call.

This PR also addresses some nits from the internal review (like renaming a generic `CompositeType` to a more specific `CudnnMxType`).
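For reference, the input/scale combinations the rewriter recognizes after this change (summarized from the CudnnMxType enum in the diff below):

  Format        | Input type | Scale type | Block size
  MXFP8_E4M3FN  | E4M3FN     | E8M0FNU    | 32
  MXFP8_E5M2    | E5M2       | E8M0FNU    | 32
  NVFP4         | E2M1FN     | E4M3FN     | 16

cuDNN accepts any mix of the MXFP8 operand types except E5M2 on both sides, and accepts NVFP4 only when both operands are NVFP4.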
Copybara import of the project:

--
32e76a8 by Sergey Kozub <[email protected]>:

[XLA:GPU] Enable cuDNN kernel for NVFP4 block scaled dot

Merging this change closes #22512

COPYBARA_INTEGRATE_REVIEW=#22512 from openxla:skozub/block_scaling_nvfp4 32e76a8
PiperOrigin-RevId: 725985050
sergey-kozub authored and Google-ML-Automation committed Feb 12, 2025
1 parent db17202 commit e96c270
Showing 4 changed files with 92 additions and 23 deletions.
50 changes: 31 additions & 19 deletions xla/service/gpu/transforms/block_scaling_rewriter.cc
@@ -198,44 +198,51 @@ absl::StatusOr<HloInstruction*> ExpandDequantizeCustomCall(
 
 // ----- Block scaled dot (cuDNN)
 
-enum class CompositeType {
+enum class CudnnMxType {
   // Not a supported composite type.
-  CUSTOM_TYPE,
-  // Input: E4M3FN, scale: E8M0, block size: 32.
+  UNSUPPORTED_TYPE,
+  // Input: E4M3FN, scale: E8M0FNU, block size: 32.
   MXFP8_E4M3FN,
-  // Input: E5M2, scale: E8M0, block size: 32.
+  // Input: E5M2, scale: E8M0FNU, block size: 32.
   MXFP8_E5M2,
+  // Input: E2M1FN, scale: E4M3FN, block size: 16.
+  NVFP4,
 };
 
-CompositeType GetCompositeType(const Shape& input_shape,
-                               const Shape& scale_shape) {
+CudnnMxType GetCudnnMxType(const Shape& input_shape, const Shape& scale_shape) {
   // Determine the block size from shapes.
   int block_size = GetBlockSize(input_shape, scale_shape).value_or(0);
 
   // MXFP8: the input could be either E4M3FN or E5M2.
   if (input_shape.element_type() == PrimitiveType::F8E4M3FN &&
       scale_shape.element_type() == PrimitiveType::F8E8M0FNU &&
       block_size == BlockScalingRewriter::kBlockSizeMXFP8) {
-    return CompositeType::MXFP8_E4M3FN;
+    return CudnnMxType::MXFP8_E4M3FN;
   }
   if (input_shape.element_type() == PrimitiveType::F8E5M2 &&
       scale_shape.element_type() == PrimitiveType::F8E8M0FNU &&
       block_size == BlockScalingRewriter::kBlockSizeMXFP8) {
-    return CompositeType::MXFP8_E5M2;
+    return CudnnMxType::MXFP8_E5M2;
   }
 
-  return CompositeType::CUSTOM_TYPE;
+  // NVFP4: the input is E2M1FN and the scale is E4M3FN.
+  if (input_shape.element_type() == PrimitiveType::F4E2M1FN &&
+      scale_shape.element_type() == PrimitiveType::F8E4M3FN &&
+      block_size == BlockScalingRewriter::kBlockSizeNVFP4) {
+    return CudnnMxType::NVFP4;
+  }
+
+  return CudnnMxType::UNSUPPORTED_TYPE;
 }
 
-bool IsSupportedByCudnn(CompositeType lhs, CompositeType rhs) {
+bool IsSupportedByCudnn(CudnnMxType lhs, CudnnMxType rhs) {
   // cuDNN supports mixing input types for MXFP8, but the E5M2/E5M2 combination
   // is not supported.
-  return (lhs == CompositeType::MXFP8_E4M3FN &&
-          rhs == CompositeType::MXFP8_E4M3FN) ||
-         (lhs == CompositeType::MXFP8_E4M3FN &&
-          rhs == CompositeType::MXFP8_E5M2) ||
-         (lhs == CompositeType::MXFP8_E5M2 &&
-          rhs == CompositeType::MXFP8_E4M3FN);
+  return (lhs == CudnnMxType::MXFP8_E4M3FN &&
+          rhs == CudnnMxType::MXFP8_E4M3FN) ||
+         (lhs == CudnnMxType::MXFP8_E4M3FN && rhs == CudnnMxType::MXFP8_E5M2) ||
+         (lhs == CudnnMxType::MXFP8_E5M2 && rhs == CudnnMxType::MXFP8_E4M3FN) ||
+         (lhs == CudnnMxType::NVFP4 && rhs == CudnnMxType::NVFP4);
 }
 
 // Reshape inputs to shapes compatible with cuDNN.
@@ -293,7 +300,12 @@ absl::StatusOr<std::tuple<XlaOp, XlaOp, int64_t>> BuildScaledDotInputs(
         scale_padding_config);
   }
 
-  // Swizzle scales to match cuDNN kernel.
+  // Swizzle scales to match the cuDNN kernel.
+  //
+  // Transposing scales is necessary to match the `scale_vec::1X` layout in
+  // TMEM. This transpose can potentially be done in the kernel (at the cost of
+  // using non-vectorized loads or using an extra shared memory buffer).
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
   TF_ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op));
   int64_t scale_rows = scale_valid_shape.dimensions(1);
   int64_t scale_cols = scale_valid_shape.dimensions(2);
@@ -369,8 +381,8 @@ absl::StatusOr<XlaOp> BuildBlockScaledDot(
   // Use cuDNN kernel, if possible.
   if (allow_cudnn && rhs_scale_op.valid() &&
       IsSupportedByCudnn(
-          GetCompositeType(lhs_input->shape(), lhs_scale->shape()),
-          GetCompositeType(rhs_input->shape(), rhs_scale->shape()))) {
+          GetCudnnMxType(lhs_input->shape(), lhs_scale->shape()),
+          GetCudnnMxType(rhs_input->shape(), rhs_scale->shape()))) {
     return BuildCudnnScaledDot(lhs_op, rhs_op, lhs_scale_op, rhs_scale_op,
                                dnums, result_type);
   }
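A minimal usage sketch of the renamed helpers (illustration only, not part of this change; the shapes mirror the new NVFP4 test below, and the snippet assumes XLA's ShapeUtil plus the pass-internal helpers above):

  // Hypothetical snippet: classify an NVFP4 operand/scale pair and check
  // whether the cuDNN kernel applies.
  Shape input = ShapeUtil::MakeShape(F4E2M1FN, {256, 256});
  Shape scale = ShapeUtil::MakeShape(F8E4M3FN, {256, 16});  // one scale per block of 16
  CudnnMxType lhs = GetCudnnMxType(input, scale);  // -> CudnnMxType::NVFP4
  CudnnMxType rhs = GetCudnnMxType(input, scale);  // -> CudnnMxType::NVFP4
  bool use_cudnn = IsSupportedByCudnn(lhs, rhs);   // -> true (NVFP4 x NVFP4)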
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/block_scaling_rewriter.h
@@ -84,6 +84,7 @@ class BlockScalingRewriter : public OpExpanderPass {
 
   // Common block size constants.
   static constexpr int kBlockSizeMXFP8 = 32;
+  static constexpr int kBlockSizeNVFP4 = 16;
 
  private:
   bool allow_cudnn_;
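A quick worked check of the new constant against the test shapes: with kBlockSizeNVFP4 = 16, one scale element covers 16 consecutive input elements along the contracting dimension, so a [256,256] E2M1FN operand carries a [256, 256/16] = [256,16] scale tensor, exactly the f8e4m3fn[256,16] shape used in the NVFP4 test below; MXFP8's block size of 32 would instead give [256,8].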
52 changes: 52 additions & 0 deletions xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc
@@ -50,6 +50,11 @@ ENTRY main {
     BlockScalingRewriter pass(/*allow_cudnn=*/true);
     EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true));
   }));
+
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false),
+                            "CHECK-NOT: __cudnn$blockScaledDot");
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true),
+                            "CHECK: __cudnn$blockScaledDot");
 }
 
 TEST_F(BlockScalingRewriterCudnnTest, Mxfp8_MixedTypes) {
@@ -77,6 +82,53 @@ ENTRY main {
     BlockScalingRewriter pass(/*allow_cudnn=*/true);
     EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true));
   }));
+
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false),
+                            "CHECK-NOT: __cudnn$blockScaledDot");
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true),
+                            "CHECK: __cudnn$blockScaledDot");
 }
+
+// Scale E2M1FN inputs, as otherwise they become all zeros for the random
+// distribution produced by the test due to low type precision.
+// Use positive block scale values, as Blackwell MMA discards the sign bit on
+// the scale tensor.
+TEST_F(BlockScalingRewriterCudnnTest, Nvfp4) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule test
+ENTRY main {
+  %mult_scalar = f16[] constant(6)
+  %mult = f16[256,256] broadcast(%mult_scalar), dimensions={}
+  %p0 = f16[256,256] parameter(0)
+  %p1 = f16[256,256] parameter(1)
+  %lhs = f4e2m1fn[256,256] convert(f16[256,256] multiply(%p0, %mult))
+  %rhs = f4e2m1fn[256,256] convert(f16[256,256] multiply(%p1, %mult))
+  %p2 = f8e4m3fn[256,16] parameter(2)
+  %p3 = f8e4m3fn[256,16] parameter(3)
+  %lhs_scale = f8e4m3fn[256,16] abs(%p2)
+  %rhs_scale = f8e4m3fn[256,16] abs(%p3)
+  ROOT %result = f32[256,256] custom-call(%lhs, %rhs, %lhs_scale, %rhs_scale),
+      custom_call_target="__op$block_scaled_dot"
+})";
+
+  EXPECT_TRUE(RunAndCompare(
+      hlo_string, ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5),
+      /*reference_preprocessor=*/
+      [](HloModule* reference_module) {
+        BlockScalingRewriter pass(/*allow_cudnn=*/false);
+        EXPECT_THAT(RunHloPass(&pass, reference_module), IsOkAndHolds(true));
+      },
+      /*test_preprocessor=*/
+      [](HloModule* test_module) {
+        BlockScalingRewriter pass(/*allow_cudnn=*/true);
+        EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true));
+      }));
+
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false),
+                            "CHECK-NOT: __cudnn$blockScaledDot");
+  RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true),
+                            "CHECK: __cudnn$blockScaledDot");
+}
 
 }  // namespace
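Background for the comments at the top of the Nvfp4 test: E2M1FN has a single mantissa bit and a maximum finite value of 6, so its representable magnitudes are only {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and the small random f16 inputs generated by the test harness would mostly round to zero on conversion. Multiplying by 6 first spreads the values across that range; taking abs() of the scales keeps them positive because the Blackwell MMA discards the sign bit of the scale tensor.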
12 changes: 8 additions & 4 deletions xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -414,10 +414,14 @@ absl::StatusOr<se::gpu::CudnnGraph> BuildGraphForCustomCallToBlockScaledDot(
     return absl::InternalError("Unsupported data type for block scaled dot");
   }
 
-  // cuDNN supports MXFP8 (block size 32, E8M0 scales).
-  TF_RET_CHECK(lhs_scale.type() == DataType::kF8E8M0FNU &&
-               rhs_scale.type() == DataType::kF8E8M0FNU);
-  const int block_size = BlockScalingRewriter::kBlockSizeMXFP8;
+  // cuDNN currently supports MXFP8 (block size 32, E8M0FNU scales) and NVFP4
+  // (block size 16, E4M3FN scales).
+  TF_RET_CHECK(lhs_scale.type() == rhs_scale.type());
+  TF_RET_CHECK(lhs_scale.type() == DataType::kF8E8M0FNU ||
+               lhs_scale.type() == DataType::kF8E4M3FN);
+  const int block_size = lhs_scale.type() == DataType::kF8E8M0FNU
+                             ? BlockScalingRewriter::kBlockSizeMXFP8
+                             : BlockScalingRewriter::kBlockSizeNVFP4;
 
   TF_ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph,
                       se::gpu::GetCudnnBlockScaledDotOperationGraph(
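In short, the compiler now derives the block size from the scale element type instead of hard-coding MXFP8: kF8E8M0FNU scales select kBlockSizeMXFP8 (32), kF8E4M3FN scales select kBlockSizeNVFP4 (16), and mismatched lhs/rhs scale types are rejected by the first TF_RET_CHECK.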
