Skip to content

Commit

Permalink
Allow configuring and testing XLA in model export.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723135244
  • Loading branch information
sdenton4 authored and copybara-github committed Feb 4, 2025
1 parent ac6e047 commit 827f0e1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
3 changes: 2 additions & 1 deletion chirp/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def export_tf_model(
tf_lite_dtype: str = "float16",
tf_lite_select_ops: bool = True,
export_dir: str | None = None,
enable_xla: bool = False,
):
"""Export SavedModel and TFLite."""
# Get model_ouput keys from output_head_metadatas and add the 'embedding' key
Expand Down Expand Up @@ -427,7 +428,7 @@ def infer_fn(audio_batch, variables):
else:
shape = (1,) + input_shape
converted_model = export_utils.Jax2TfModelWrapper(
infer_fn, variables, shape, False
infer_fn, variables, shape, enable_xla=enable_xla
)
class_lists = {
md.key: md.class_list for md in model_bundle.output_head_metadatas
Expand Down
11 changes: 9 additions & 2 deletions chirp/train_tests/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
"freq_range": (60, 10_000),
},
"atol": 1e-4,
"enable_xla": False,
},
{
"module_type": frontend.SimpleMelspec,
Expand Down Expand Up @@ -169,7 +170,12 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
},
)
def test_tflite_stft_export(
self, module_type, module_kwargs, signal_shape=None, atol=1e-6
self,
module_type,
module_kwargs,
signal_shape=None,
atol=1e-6,
enable_xla=False,
):
# Note that the TFLite stft requires power-of-two nfft, given by:
# nfft = 2 * (features - 1).
Expand All @@ -182,7 +188,8 @@ def test_tflite_stft_export(

tf_predict = tf.function(
jax2tf.convert(
lambda signal: fe.apply(params, signal), enable_xla=False
lambda signal: fe.apply(params, signal),
enable_xla=enable_xla,
),
input_signature=[
tf.TensorSpec(shape=signal.shape, dtype=tf.float32, name="input")
Expand Down
13 changes: 11 additions & 2 deletions chirp/train_tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,20 @@ def test_config_structure(self):
jax.tree_util.tree_structure(test_config.to_dict()),
)

def test_export_model(self):
@parameterized.named_parameters(
# Note that b0 tests tend to timeout.
# ("xla", True, False),
("no_xla", False, False),
)
def test_export_model(self, enable_xla, test_b0):
# NOTE: This test might fail when run on a machine that has a GPU but when
# CUDA is not linked (JAX will detect the GPU so jax2tf will try to create
# a TF graph on the GPU and fail)
config = self._get_test_config()
config = self._add_const_model_config(config)
if test_b0:
config = self._add_b0_model_config(config)
else:
config = self._add_const_model_config(config)
config = self._add_pcen_melspec_frontend(config)

model_bundle, train_state = classifier.initialize_model(
Expand All @@ -177,6 +185,7 @@ def test_export_model(self):
config.init_config.input_shape,
num_train_steps=0,
eval_sleep_s=0,
enable_xla=enable_xla,
)
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.train_dir, "model.tflite"))
Expand Down

0 comments on commit 827f0e1

Please sign in to comment.