Skip to content

Commit

Permalink
Migrate fuzz tests to always use PjRt.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725844097
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Feb 12, 2025
1 parent a98b259 commit 8c3aea8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
16 changes: 13 additions & 3 deletions xla/tests/fuzz/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,25 @@ cc_library(
srcs = ["hlo_test_lib.cc"],
deps = [
"//xla:error_spec",
"//xla/tests:hlo_test_base",
"@tsl//tsl/platform:env",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:hlo_module_config",
"//xla/tests:hlo_pjrt_interpreter_reference_mixin",
"//xla/tests:hlo_pjrt_test_base",
"//xla/tsl/platform:env",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
],
)

[hlo_test(
name = hlo + "_test",
hlo = hlo,
tags = (["cuda-only"] if hlo == "rand_000079.hlo" else []), # No int8
tags = (
["cuda-only"] if hlo == "rand_000079.hlo" else [] # No int8
) + [
"test_migrated_to_hlo_runner_pjrt",
],
) for hlo in glob(
include = ["rand_*.hlo"],
exclude = [
Expand Down
15 changes: 11 additions & 4 deletions xla/tests/fuzz/hlo_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@ limitations under the License.

#include <cstdlib>
#include <iostream>
#include <memory>
#include <ostream>
#include <string>
#include <utility>

#include "xla/error_spec.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/env.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
#include "xla/tests/hlo_pjrt_test_base.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"

namespace xla {
namespace {

class HloTest : public HloTestBase {};
class HloTest : public HloPjRtInterpreterReferenceMixin<HloPjRtTestBase> {};

TEST_F(HloTest, HloTest) {
std::string path_to_hlo = std::getenv("HLO_PATH");
Expand All @@ -36,7 +43,7 @@ TEST_F(HloTest, HloTest) {
std::cerr << hlo << std::endl;
HloModuleConfig config;

TF_ASSERT_OK_AND_ASSIGN(auto module,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo, config));
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
}
Expand Down

0 comments on commit 8c3aea8

Please sign in to comment.