diff --git a/xla/backends/cpu/nanort/ifrt_client_test.cc b/xla/backends/cpu/nanort/ifrt_client_test.cc index a30c7854dc1e7..aa374a862714b 100644 --- a/xla/backends/cpu/nanort/ifrt_client_test.cc +++ b/xla/backends/cpu/nanort/ifrt_client_test.cc @@ -57,13 +57,13 @@ TEST(NanoIfrtClientTest, BigResult) { // A program that is likely to need some temporary buffers to be allocated. absl::string_view kBigResult = R"(module { - func.func @main(%arg: tensor) -> tensor<1024x1024xf32> { - %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[1024, 1024]> : tensor<2xi64>} : (tensor) -> tensor<1024x1024xf32> - %1 = "mhlo.add"(%0, %0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %2 = "mhlo.dot"(%1, %1) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - return %2 : tensor<1024x1024xf32> - } - })"; + func.func @main(%arg0: tensor) -> tensor<1024x1024xf32> { + %0 = stablehlo.broadcast %arg0, sizes = [1024, 1024] : (tensor) -> tensor<1024x1024xf32> + %1 = stablehlo.add %0, %0 : tensor<1024x1024xf32> + %2 = stablehlo.dot %1, %1 : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %2 : tensor<1024x1024xf32> + } + })"; auto client = NanoIfrtClient::Create(); auto compiler = client->GetDefaultCompiler(); @@ -142,7 +142,7 @@ static void BM_IfRtAddScalars(benchmark::State& state) { constexpr absl::string_view program = R"(module { func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.add %arg0, %arg1 : tensor + %0 = stablehlo.add %arg0, %arg1 : tensor return %0 : tensor } })"; @@ -185,16 +185,16 @@ static void BM_IfRtAddManyScalars(benchmark::State& state) { -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { - %0 = mhlo.add %arg0, %arg1 : tensor - %1 = mhlo.add %arg0, %0 : tensor - %2 = mhlo.add %arg0, %1 : tensor - %3 = mhlo.add %arg0, %2 : tensor - %4 = mhlo.add %arg0, %3 : tensor - %5 = mhlo.add %arg0, %4 : tensor - %6 = mhlo.add %arg0, %5 : tensor - %7 = mhlo.add %arg0, %6 : tensor - %8 = mhlo.add %arg0, %7 : tensor - %9 = mhlo.add %arg0, %8 : tensor + %0 = stablehlo.add %arg0, %arg1 : tensor + %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.add %arg0, %1 : tensor + %3 = stablehlo.add %arg0, %2 : tensor + %4 = stablehlo.add %arg0, %3 : tensor + %5 = stablehlo.add %arg0, %4 : tensor + %6 = stablehlo.add %arg0, %5 : tensor + %7 = stablehlo.add %arg0, %6 : tensor + %8 = stablehlo.add %arg0, %7 : tensor + %9 = stablehlo.add %arg0, %8 : tensor return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 374e306e5e686..801ed23db5cd8 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -292,11 +292,8 @@ cc_library( "//xla/hlo/transforms/simplifiers:tree_reduction_rewriter", "//xla/hlo/transforms/simplifiers:tuple_simplifier", "//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination", - "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:mhlo_passes", "//xla/mlir_hlo:transforms_passes", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 2d01f7de78917..00bc0a05bc743 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -139,7 +139,6 @@ limitations under the License. #include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" #include "xla/hlo/transforms/while_loop_trip_count_annotator.h" -#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/literal.h" #include "xla/literal_pool.h" #include "xla/map_util.h"