Skip to content

Commit

Permalink
HloModule is converted to proto in MlirToXlaComputation so need to ma…
Browse files Browse the repository at this point in the history
…ke sure result accuracy is kept.

PiperOrigin-RevId: 726178797
  • Loading branch information
hanrach9 authored and Google-ML-Automation committed Feb 12, 2025
1 parent 5ef6822 commit b6e6e17
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
4 changes: 4 additions & 0 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4308,6 +4308,10 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_original_value() = OriginalValueToProto(*original_value_);
}

if (has_result_accuracy()) {
*proto.mutable_result_accuracy() = result_accuracy();
}

return proto;
}

Expand Down
21 changes: 20 additions & 1 deletion xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5915,7 +5915,7 @@ HloModule bitcast_to_smaller
ENTRY main {
p = s32[10] parameter(0)
ROOT out = s8[10,4] bitcast-convert(p), result_accuracy={tolerance={rtol=0.5, atol=1.0, ulps=2}
ROOT out = s8[10,4] bitcast-convert(p), result_accuracy={tolerance={rtol=0.5, atol=1.0, ulps=2}}
}
)";
auto result = ParseAndReturnUnverifiedModule(hlo_string);
Expand Down Expand Up @@ -5976,5 +5976,24 @@ TEST_F(HloParserTest,
OriginalValueToString(*wrapped_instr->original_value()));
}

TEST_F(HloParserTest, ResultAccuracyToProto) {
const char* const hlo_string = R"(
HloModule exponential_hw
ENTRY exponential_hw {
%exponent = f32[] parameter(0)
ROOT %exponential = f32[] exponential(f32[] %exponent), result_accuracy={tolerance={rtol=0.5, atol=1.0, ulps=2}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
HloInstruction* exp_hlo_instruction =
module->entry_computation()->root_instruction();
HloInstructionProto exp_hlo_inst_proto = exp_hlo_instruction->ToProto();
EXPECT_TRUE(exp_hlo_inst_proto.has_result_accuracy());
EXPECT_EQ(exp_hlo_inst_proto.result_accuracy().tolerance().rtol(), 0.5);
EXPECT_EQ(exp_hlo_inst_proto.result_accuracy().tolerance().atol(), 1.0);
}

} // namespace
} // namespace xla

0 comments on commit b6e6e17

Please sign in to comment.