From e96c27095f8619137cdacdeebdd4055058f393d8 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Wed, 12 Feb 2025 03:25:23 -0800 Subject: [PATCH] PR #22512: [XLA:GPU] Enable cuDNN kernel for NVFP4 block scaled dot Imported from GitHub PR https://github.com/openxla/xla/pull/22512 Support NVFP4 in addition to MXFP8 hardware acceleration for the "__op$block_scaled_dot" custom call. This PR also addresses some nits from the internal review (like renaming a generic `CompositeType` to a more specific `CudnnMxType`). Copybara import of the project: -- 32e76a88b2107c079e26826417d22664cbf809a3 by Sergey Kozub : [XLA:GPU] Enable cuDNN kernel for NVFP4 block scaled dot Merging this change closes #22512 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/22512 from openxla:skozub/block_scaling_nvfp4 32e76a88b2107c079e26826417d22664cbf809a3 PiperOrigin-RevId: 725985050 --- .../gpu/transforms/block_scaling_rewriter.cc | 50 +++++++++++------- .../gpu/transforms/block_scaling_rewriter.h | 1 + .../block_scaling_rewriter_cudnn_test.cc | 52 +++++++++++++++++++ .../transforms/cudnn_custom_call_compiler.cc | 12 +++-- 4 files changed, 92 insertions(+), 23 deletions(-) diff --git a/xla/service/gpu/transforms/block_scaling_rewriter.cc b/xla/service/gpu/transforms/block_scaling_rewriter.cc index bf05f8023b144..23bc6bba46133 100644 --- a/xla/service/gpu/transforms/block_scaling_rewriter.cc +++ b/xla/service/gpu/transforms/block_scaling_rewriter.cc @@ -198,17 +198,18 @@ absl::StatusOr ExpandDequantizeCustomCall( // ----- Block scaled dot (cuDNN) -enum class CompositeType { +enum class CudnnMxType { // Not a supported composite type. - CUSTOM_TYPE, - // Input: E4M3FN, scale: E8M0, block size: 32. + UNSUPPORTED_TYPE, + // Input: E4M3FN, scale: E8M0FNU, block size: 32. MXFP8_E4M3FN, - // Input: E5M2, scale: E8M0, block size: 32. + // Input: E5M2, scale: E8M0FNU, block size: 32. MXFP8_E5M2, + // Input: E2M1FN, scale: E4M3FN, block size: 16. + NVFP4, }; -CompositeType GetCompositeType(const Shape& input_shape, - const Shape& scale_shape) { +CudnnMxType GetCudnnMxType(const Shape& input_shape, const Shape& scale_shape) { // Determine the block size from shapes. int block_size = GetBlockSize(input_shape, scale_shape).value_or(0); @@ -216,26 +217,32 @@ CompositeType GetCompositeType(const Shape& input_shape, if (input_shape.element_type() == PrimitiveType::F8E4M3FN && scale_shape.element_type() == PrimitiveType::F8E8M0FNU && block_size == BlockScalingRewriter::kBlockSizeMXFP8) { - return CompositeType::MXFP8_E4M3FN; + return CudnnMxType::MXFP8_E4M3FN; } if (input_shape.element_type() == PrimitiveType::F8E5M2 && scale_shape.element_type() == PrimitiveType::F8E8M0FNU && block_size == BlockScalingRewriter::kBlockSizeMXFP8) { - return CompositeType::MXFP8_E5M2; + return CudnnMxType::MXFP8_E5M2; } - return CompositeType::CUSTOM_TYPE; + // NVFP4: the input is E2M1FN and the scale is E4M3FN. + if (input_shape.element_type() == PrimitiveType::F4E2M1FN && + scale_shape.element_type() == PrimitiveType::F8E4M3FN && + block_size == BlockScalingRewriter::kBlockSizeNVFP4) { + return CudnnMxType::NVFP4; + } + + return CudnnMxType::UNSUPPORTED_TYPE; } -bool IsSupportedByCudnn(CompositeType lhs, CompositeType rhs) { +bool IsSupportedByCudnn(CudnnMxType lhs, CudnnMxType rhs) { // cuDNN supports mixing input types for MXFP8, but the E5M2/E5M2 combination // is not supported. - return (lhs == CompositeType::MXFP8_E4M3FN && - rhs == CompositeType::MXFP8_E4M3FN) || - (lhs == CompositeType::MXFP8_E4M3FN && - rhs == CompositeType::MXFP8_E5M2) || - (lhs == CompositeType::MXFP8_E5M2 && - rhs == CompositeType::MXFP8_E4M3FN); + return (lhs == CudnnMxType::MXFP8_E4M3FN && + rhs == CudnnMxType::MXFP8_E4M3FN) || + (lhs == CudnnMxType::MXFP8_E4M3FN && rhs == CudnnMxType::MXFP8_E5M2) || + (lhs == CudnnMxType::MXFP8_E5M2 && rhs == CudnnMxType::MXFP8_E4M3FN) || + (lhs == CudnnMxType::NVFP4 && rhs == CudnnMxType::NVFP4); } // Reshape inputs to shapes compatible with cuDNN. @@ -293,7 +300,12 @@ absl::StatusOr> BuildScaledDotInputs( scale_padding_config); } - // Swizzle scales to match cuDNN kernel. + // Swizzle scales to match the cuDNN kernel. + // + // Transposing scales is necessary to match the `scale_vec::1X` layout in + // TMEM. This transpose can potentially be done in the kernel (at the cost of + // using non-vectorized loads or using an extra shared memory buffer). + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x TF_ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op)); int64_t scale_rows = scale_valid_shape.dimensions(1); int64_t scale_cols = scale_valid_shape.dimensions(2); @@ -369,8 +381,8 @@ absl::StatusOr BuildBlockScaledDot( // Use cuDNN kernel, if possible. if (allow_cudnn && rhs_scale_op.valid() && IsSupportedByCudnn( - GetCompositeType(lhs_input->shape(), lhs_scale->shape()), - GetCompositeType(rhs_input->shape(), rhs_scale->shape()))) { + GetCudnnMxType(lhs_input->shape(), lhs_scale->shape()), + GetCudnnMxType(rhs_input->shape(), rhs_scale->shape()))) { return BuildCudnnScaledDot(lhs_op, rhs_op, lhs_scale_op, rhs_scale_op, dnums, result_type); } diff --git a/xla/service/gpu/transforms/block_scaling_rewriter.h b/xla/service/gpu/transforms/block_scaling_rewriter.h index d6940e48dd3e6..ecdb1ad60e8a4 100644 --- a/xla/service/gpu/transforms/block_scaling_rewriter.h +++ b/xla/service/gpu/transforms/block_scaling_rewriter.h @@ -84,6 +84,7 @@ class BlockScalingRewriter : public OpExpanderPass { // Common block size constants. static constexpr int kBlockSizeMXFP8 = 32; + static constexpr int kBlockSizeNVFP4 = 16; private: bool allow_cudnn_; diff --git a/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc b/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc index 2f8304255754a..7f4e744671707 100644 --- a/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc +++ b/xla/service/gpu/transforms/block_scaling_rewriter_cudnn_test.cc @@ -50,6 +50,11 @@ ENTRY main { BlockScalingRewriter pass(/*allow_cudnn=*/true); EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true)); })); + + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false), + "CHECK-NOT: __cudnn$blockScaledDot"); + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true), + "CHECK: __cudnn$blockScaledDot"); } TEST_F(BlockScalingRewriterCudnnTest, Mxfp8_MixedTypes) { @@ -77,6 +82,53 @@ ENTRY main { BlockScalingRewriter pass(/*allow_cudnn=*/true); EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true)); })); + + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false), + "CHECK-NOT: __cudnn$blockScaledDot"); + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true), + "CHECK: __cudnn$blockScaledDot"); +} + +// Scale E2M1FN inputs, as otherwise they become all zeros for the random +// distribution produced by the test due to low type precision. +// Use positive block scale values, as Blackwell MMA discards the sign bit on +// the scale tensor. +TEST_F(BlockScalingRewriterCudnnTest, Nvfp4) { + constexpr absl::string_view hlo_string = R"( +HloModule test + +ENTRY main { + %mult_scalar = f16[] constant(6) + %mult = f16[256,256] broadcast(%mult_scalar), dimensions={} + %p0 = f16[256,256] parameter(0) + %p1 = f16[256,256] parameter(1) + %lhs = f4e2m1fn[256,256] convert(f16[256,256] multiply(%p0, %mult)) + %rhs = f4e2m1fn[256,256] convert(f16[256,256] multiply(%p1, %mult)) + %p2 = f8e4m3fn[256,16] parameter(2) + %p3 = f8e4m3fn[256,16] parameter(3) + %lhs_scale = f8e4m3fn[256,16] abs(%p2) + %rhs_scale = f8e4m3fn[256,16] abs(%p3) + ROOT %result = f32[256,256] custom-call(%lhs, %rhs, %lhs_scale, %rhs_scale), + custom_call_target="__op$block_scaled_dot" +})"; + + EXPECT_TRUE(RunAndCompare( + hlo_string, ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-5), + /*reference_preprocessor=*/ + [](HloModule* reference_module) { + BlockScalingRewriter pass(/*allow_cudnn=*/false); + EXPECT_THAT(RunHloPass(&pass, reference_module), IsOkAndHolds(true)); + }, + /*test_preprocessor=*/ + [](HloModule* test_module) { + BlockScalingRewriter pass(/*allow_cudnn=*/true); + EXPECT_THAT(RunHloPass(&pass, test_module), IsOkAndHolds(true)); + })); + + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(false), + "CHECK-NOT: __cudnn$blockScaledDot"); + RunAndFilecheckHloRewrite(hlo_string, BlockScalingRewriter(true), + "CHECK: __cudnn$blockScaledDot"); } } // namespace diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 18d5f39ffb51b..4c15388a0a49a 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -414,10 +414,14 @@ absl::StatusOr BuildGraphForCustomCallToBlockScaledDot( return absl::InternalError("Unsupported data type for block scaled dot"); } - // cuDNN supports MXFP8 (block size 32, E8M0 scales). - TF_RET_CHECK(lhs_scale.type() == DataType::kF8E8M0FNU && - rhs_scale.type() == DataType::kF8E8M0FNU); - const int block_size = BlockScalingRewriter::kBlockSizeMXFP8; + // cuDNN currently supports MXFP8 (block size 32, E8M0FNU scales) and NVFP4 + // (block size 16, E4M3FN scales). + TF_RET_CHECK(lhs_scale.type() == rhs_scale.type()); + TF_RET_CHECK(lhs_scale.type() == DataType::kF8E8M0FNU || + lhs_scale.type() == DataType::kF8E4M3FN); + const int block_size = lhs_scale.type() == DataType::kF8E8M0FNU + ? BlockScalingRewriter::kBlockSizeMXFP8 + : BlockScalingRewriter::kBlockSizeNVFP4; TF_ASSIGN_OR_RETURN(se::gpu::CudnnGraph graph, se::gpu::GetCudnnBlockScaledDotOperationGraph(