From 1b9ab7ad99f30a7cf42d7e634d9126212e01929a Mon Sep 17 00:00:00 2001 From: xla authors Date: Fri, 14 Feb 2025 14:45:25 -0800 Subject: [PATCH] Support optimization_level and memory_fitting_level XLA compilation options. PiperOrigin-RevId: 727070422 --- xla/python/xla_client.py | 2 +- xla/python/xla_compiler.cc | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 71c15643e4cfa..9cf6e8b6c47bd 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 315 +_version = 316 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/python/xla_compiler.cc b/xla/python/xla_compiler.cc index 606aa83de8da9..a809b0746f008 100644 --- a/xla/python/xla_compiler.cc +++ b/xla/python/xla_compiler.cc @@ -1378,6 +1378,18 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .def_prop_rw("memory_fitting_effort", &ExecutableBuildOptions::memory_fitting_effort, &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) .def_prop_rw("use_spmd_partitioning", &ExecutableBuildOptions::use_spmd_partitioning, &ExecutableBuildOptions::set_use_spmd_partitioning)