diff --git a/chirp/train/classifier.py b/chirp/train/classifier.py
index b07e0ae4..847c11d7 100644
--- a/chirp/train/classifier.py
+++ b/chirp/train/classifier.py
@@ -400,6 +400,7 @@ def export_tf_model(
     tf_lite_dtype: str = "float16",
     tf_lite_select_ops: bool = True,
     export_dir: str | None = None,
+    enable_xla: bool = False,
 ):
   """Export SavedModel and TFLite."""
   # Get model_ouput keys from output_head_metadatas and add the 'embedding' key
@@ -427,7 +428,7 @@ def infer_fn(audio_batch, variables):
   else:
     shape = (1,) + input_shape
   converted_model = export_utils.Jax2TfModelWrapper(
-      infer_fn, variables, shape, False
+      infer_fn, variables, shape, enable_xla=enable_xla
   )
   class_lists = {
       md.key: md.class_list for md in model_bundle.output_head_metadatas
diff --git a/chirp/train_tests/frontend_test.py b/chirp/train_tests/frontend_test.py
index f4c26f3a..fd60fe6e 100644
--- a/chirp/train_tests/frontend_test.py
+++ b/chirp/train_tests/frontend_test.py
@@ -132,6 +132,7 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
               "freq_range": (60, 10_000),
           },
           "atol": 1e-4,
+          "enable_xla": False,
       },
       {
           "module_type": frontend.SimpleMelspec,
@@ -169,7 +170,12 @@ def test_inverse(self, module_type, inverse_module_type, module_kwargs):
       },
   )
   def test_tflite_stft_export(
-      self, module_type, module_kwargs, signal_shape=None, atol=1e-6
+      self,
+      module_type,
+      module_kwargs,
+      signal_shape=None,
+      atol=1e-6,
+      enable_xla=False,
   ):
     # Note that the TFLite stft requires power-of-two nfft, given by:
     # nfft = 2 * (features - 1).
@@ -182,7 +188,8 @@ def test_tflite_stft_export(
     tf_predict = tf.function(
         jax2tf.convert(
-            lambda signal: fe.apply(params, signal), enable_xla=False
+            lambda signal: fe.apply(params, signal),
+            enable_xla=enable_xla,
         ),
         input_signature=[
             tf.TensorSpec(shape=signal.shape, dtype=tf.float32, name="input")
diff --git a/chirp/train_tests/train_test.py b/chirp/train_tests/train_test.py
index 80d31010..0b06d942 100644
--- a/chirp/train_tests/train_test.py
+++ b/chirp/train_tests/train_test.py
@@ -157,12 +157,20 @@ def test_config_structure(self):
         jax.tree_util.tree_structure(test_config.to_dict()),
     )
 
-  def test_export_model(self):
+  @parameterized.named_parameters(
+      # Note that b0 tests tend to timeout.
+      # ("xla", True, False),
+      ("no_xla", False, False),
+  )
+  def test_export_model(self, enable_xla, test_b0):
     # NOTE: This test might fail when run on a machine that has a GPU but when
     # CUDA is not linked (JAX will detect the GPU so jax2tf will try to create
     # a TF graph on the GPU and fail)
     config = self._get_test_config()
-    config = self._add_const_model_config(config)
+    if test_b0:
+      config = self._add_b0_model_config(config)
+    else:
+      config = self._add_const_model_config(config)
     config = self._add_pcen_melspec_frontend(config)
 
     model_bundle, train_state = classifier.initialize_model(
@@ -177,6 +185,7 @@ def test_export_model(self):
         config.init_config.input_shape,
         num_train_steps=0,
         eval_sleep_s=0,
+        enable_xla=enable_xla,
     )
     self.assertTrue(
         tf.io.gfile.exists(os.path.join(self.train_dir, "model.tflite"))
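
A minimal sketch of the export path the new `enable_xla` flag feeds into, assuming the same jax2tf-to-TFLite flow exercised by the tests above. The names `fe`, `params`, and `signal_shape` are hypothetical stand-ins for a frontend module, its variables, and the input shape; the converter settings mirror the `tf_lite_select_ops=True` and `tf_lite_dtype="float16"` defaults, not the actual internals of `export_utils.Jax2TfModelWrapper`.

import tensorflow as tf
from jax.experimental import jax2tf


def export_tflite_sketch(fe, params, signal_shape, enable_xla=False):
  # Wrap the JAX apply function as a TF function. With enable_xla=False,
  # jax2tf lowers to plain TF ops, which TFLite can convert without XLA
  # custom calls; enable_xla=True keeps XLA ops and typically requires
  # falling back to select TF ops.
  tf_predict = tf.function(
      jax2tf.convert(lambda x: fe.apply(params, x), enable_xla=enable_xla),
      input_signature=[
          tf.TensorSpec(shape=signal_shape, dtype=tf.float32, name="input")
      ],
  )
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()]
  )
  # Allow fallback to select TF ops (tf_lite_select_ops=True).
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  # Float16 post-training quantization (tf_lite_dtype="float16").
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.target_spec.supported_types = [tf.float16]
  return converter.convert()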