Skip to content

Commit

Permalink
In cudnn_fused_conv_rewriter.h, allow clamp to be omitted when convertin…
Browse files Browse the repository at this point in the history
…g f32 to s8.

As an implementation detail, XLA already clamps when converting float to int, so it's ok to pattern-match a fused_conv_outputting_f32->convert_to_s8 into a fused_conv_outputting_s8, even without a clamp in between the fused conv and convert.

Even so, I would still recommend users have a clamp in their code, since the implicit clamping behavior is unspecified.

PiperOrigin-RevId: 726550567
  • Loading branch information
reedwm authored and Google-ML-Automation committed Feb 13, 2025
1 parent 95cd71a commit 5d0a237
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
18 changes: 17 additions & 1 deletion xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,

PrimitiveType conv_output_ty;
if (MatchAndLogIfFailed(
instr, "s8->s8 conv",
instr, "s8->s8 conv with clamp",
m::Convert(m::Clamp(m::Broadcast(m::ConstantEffectiveScalar(-128)),
m::GetTupleElement(
&gte,
Expand All @@ -1301,6 +1301,22 @@ absl::StatusOr<bool> FuseConvertToS8(HloComputation* comp,
m::Op()))
.WithElementType(S8))) {
conv_output_ty = S8;
} else if (MatchAndLogIfFailed(
instr, "s8->s8 conv without clamp",
m::Convert(m::GetTupleElement(
&gte,
conv_pattern.WithOperandIfPresent(
3, m::Op().WithPredicate(
IsLosslesslyConvertibleToS8)),
0)
.WithElementType(F32)
.WithOneUse())
.WithElementType(S8),
VLOG_IS_ON(3),
m::Convert(m::GetTupleElement(
m::Op().WithPredicate(IsConvCustomCall)))
.WithElementType(S8))) {
conv_output_ty = S8;
} else if (MatchAndLogIfFailed(
instr, "s8->f32 conv",
m::GetTupleElement(&gte,
Expand Down
5 changes: 5 additions & 0 deletions xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ namespace gpu {
// In the `result_s8` case where there's no bias, side-input, or alpha1, you can
// skip the convert_f32 on conv.
//
// In the `result_s8` case, you can skip the clamp as long as the convert_f32
// is not skipped. The reason is XLA implicitly clamps when converting from
// float to int (although this is an implementation detail and not guaranteed by
// the spec).
//
// If you have an integer convolution that doesn't fit one of these idioms, this
// pass returns an error -- cudnn will not be able to run it.
class CudnnFusedConvRewriter : public HloModulePass {
Expand Down
41 changes: 41 additions & 0 deletions xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,47 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
.WithShape(S8, {1, 32, 9, 9})));
}

// Verifies that an s8 conv with bias and side-input whose f32 result is
// converted straight to s8 -- with NO explicit clamp(-128, 127) between the
// convert and the conv -- is still fused into a single s8-outputting
// cudnn$convBiasActivationForward custom call. XLA's float->int convert
// clamps implicitly, so the pattern is safe to match without the clamp.
TEST_F(CudnnFusedConvRewriterHloTest,
TestConvInt8ToInt8BiasSideInputWithoutClamp) {
MAYBE_SKIP_TEST("I8");
// NOTE: no clamp around the final convert -- that absence is the point of
// this test. (Comments must stay outside the raw string; its text is the
// HLO module being parsed.)
const std::string module_str = R"(
HloModule Test

ENTRY Test {
input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))

conv = s32[1,32,9,9] convolution(input, filter),
window={size=3x3 pad=1_1x1_1},
dim_labels=bf01_01io->bf01
conv_f32 = f32[1,32,9,9] convert(conv)
ROOT root = s8[1,32,9,9] convert(add(add(conv_f32, bias), side_input))
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

// First rewrite the plain convolution into a cudnn custom call, then run the
// pass under test to fuse bias/side-input and the s8 output convert.
ConvRewriter rewriter{GetCudaComputeCapability()};
TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(),
GetToolkitVersion()};
TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

// Simplify new `convert`'s that may be added to the graph.
AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

SCOPED_TRACE(m->ToString());
// The root must be element 0 of the fused custom call's result tuple, with
// all four parameters fed in directly (no leftover converts/adds) and an S8
// output shape -- i.e. the convert was fused despite the missing clamp.
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::GetTupleElement(
m::CustomCall({kCudnnConvBiasActivationForwardCallTarget},
m::Parameter(0), m::Parameter(1),
m::Parameter(2), m::Parameter(3)),
0)
.WithShape(S8, {1, 32, 9, 9})));
}

TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
MAYBE_SKIP_TEST("I8");
const std::string module_str = R"(
Expand Down

0 comments on commit 5d0a237

Please sign in to comment.