Skip to content

Commit

Permalink
Add a default value for write_frontend to avoid errors with old configs.
Browse files Browse the repository at this point in the history
Also write/read LogitsOutputHead config files.

PiperOrigin-RevId: 612630662
  • Loading branch information
sdenton4 authored and copybara-github committed Mar 5, 2024
1 parent f9a4eeb commit 3746672
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 11 deletions.
10 changes: 5 additions & 5 deletions chirp/inference/embed_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def __init__(
write_logits: bool | Sequence[str],
write_separated_audio: bool,
write_raw_audio: bool,
write_frontend: bool,
model_key: str,
model_config: config_dict.ConfigDict,
file_id_depth: int,
Expand All @@ -169,6 +168,7 @@ def __init__(
target_sample_rate: int = -2,
logits_head_config: config_dict.ConfigDict | None = None,
tensor_dtype: str = 'float32',
write_frontend: bool = False,
):
"""Initialize the embedding DoFn.
Expand All @@ -178,8 +178,6 @@ def __init__(
logit keys to write.
write_separated_audio: Whether to write out separated audio tracks.
write_raw_audio: If true, will add the original audio to the output.
write_frontend: If true, will add the model's frontend (spectrogram) to
the output.
model_key: String indicating which model wrapper to use. See MODEL_KEYS.
Only used for setting up the embedding model.
model_config: Keyword arg dictionary for the model wrapper class. Only
Expand All @@ -197,6 +195,8 @@ def __init__(
interface.LogitsOutputHead classifying the model embeddings.
tensor_dtype: Dtype to use for storing tensors (embeddings, logits, or
audio). Default to float32, but float16 approximately halves file size.
write_frontend: If true, will add the model's frontend (spectrogram) to
the output. (Defaults False for backwards compatibility.)
"""
self.model_key = model_key
self.model_config = model_config
Expand Down Expand Up @@ -381,10 +381,10 @@ def maybe_write_config(parsed_config, output_dir):
f.write(config_json)


def load_embedding_config(embeddings_path):
def load_embedding_config(embeddings_path, filename: str = 'config.json'):
"""Loads the configuration to generate unlabeled embeddings."""
embeddings_path = epath.Path(embeddings_path)
with (embeddings_path / 'config.json').open() as f:
with (embeddings_path / filename).open() as f:
embedding_config = config_dict.ConfigDict(json.loads(f.read()))
return embedding_config

Expand Down
16 changes: 16 additions & 0 deletions chirp/inference/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Interface for models producing embeddings."""

import dataclasses
import json
from typing import Any, Callable, Dict

from absl import logging
Expand Down Expand Up @@ -181,6 +182,14 @@ class LogitsOutputHead:
class_list: namespace.ClassList
channel_pooling: str = 'max'

@classmethod
def from_config_file(cls, model_path: str, filename='logits_config.json'):
  """Builds an instance from a JSON config file stored under model_path.

  Args:
    model_path: Directory containing the config file. It is also written
      onto the loaded config as `model_path` before construction.
    filename: Name of the JSON config file inside `model_path`.

  Returns:
    The result of `cls.from_config` applied to the loaded config.
  """
  config_file = epath.Path(model_path) / filename
  with config_file.open() as f:
    raw_config = json.load(f)
  config = config_dict.ConfigDict(raw_config)
  # Point the config at the directory it was read from, replacing any
  # value that may have been stored in the file.
  config.model_path = model_path
  return cls.from_config(config)

@classmethod
def from_config(cls, config: config_dict.ConfigDict):
logits_model = tf.saved_model.load(config.model_path)
Expand All @@ -198,6 +207,13 @@ def save_model(self, output_path: str, embeddings_path: str):
# Write the model.
tf.saved_model.save(self.logits_model, output_path)
output_path = epath.Path(output_path)
# Write a config file.
config_data = dataclasses.asdict(self)
for k in ['logits_model', 'class_list']:
# These are loaded automatically.
config_data.pop(k)
with (output_path / 'logits_config.json').open('w') as f:
json.dump(config_data, f)
# Copy the embeddings_config if provided
if embeddings_path:
(epath.Path(embeddings_path) / 'config.json').copy(
Expand Down
2 changes: 1 addition & 1 deletion chirp/inference/tf_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def model_outputs_to_tf_example(
write_logits: bool | Sequence[str],
write_separated_audio: bool,
write_raw_audio: bool,
write_frontend: bool,
write_frontend: bool = False,
tensor_dtype: str = 'float32',
) -> tf.train.Example:
"""Create a TFExample from InferenceOutputs."""
Expand Down
39 changes: 34 additions & 5 deletions chirp/tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,37 @@ def test_embed_fn(
else:
self.assertEqual(got_example[tf_examples.RAW_AUDIO].shape, (0,))

@parameterized.parameters(
    {'config_filename': 'embedding_config_v0'},
    # v1 additionally carries the frontend handling options.
    {'config_filename': 'embedding_config_v1'},
)
def test_embed_fn_from_config(self, config_filename):
  """Checks that a golden config can build an EmbedFn and embed audio."""
  golden_config_file = path_utils.get_absolute_path(
      f'tests/testdata/{config_filename}.json'
  )
  # Pass the full file path with an empty filename so the loader reads the
  # golden file directly.
  embed_config = embed_lib.load_embedding_config(
      os.fspath(golden_config_file), ''
  )
  embed_fn = embed_lib.EmbedFn(**embed_config)
  embed_fn.setup()
  self.assertIsNotNone(embed_fn.embedding_model)

  wav_file = path_utils.get_absolute_path(
      'tests/testdata/tfds_builder_wav_directory_test/clap.wav'
  )
  source_info = embed_lib.SourceInfo(os.fspath(wav_file), 0, 10)
  outputs = embed_fn.process(source_info, crop_s=10.0)
  serialized = outputs[0].SerializeToString()

  # Round-trip the serialized example through the parser.
  parse_example = tf_examples.get_example_parser(
      logit_names=['label', 'other_label'],
      tensor_dtype=embed_config.tensor_dtype,
  )
  self.assertIsNotNone(parse_example(serialized))

def test_embed_fn_source_variations(self):
"""Test processing with variations of SourceInfo."""
model_kwargs = {
Expand Down Expand Up @@ -303,11 +334,9 @@ def test_logits_output_head(self):
# Save and restore the model.
with tempfile.TemporaryDirectory() as logits_model_dir:
logits_model.save_model(logits_model_dir, '')
restore_config = config_dict.ConfigDict({
'model_path': logits_model_dir,
'logits_key': 'other_label',
})
restored_model = interface.LogitsOutputHead.from_config(restore_config)
restored_model = interface.LogitsOutputHead.from_config_file(
logits_model_dir
)
reupdated_outputs = restored_model.add_logits(
base_outputs, keep_original=True
)
Expand Down
16 changes: 16 additions & 0 deletions chirp/tests/testdata/embedding_config_v0.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"write_embeddings": true,
"write_logits": false,
"write_separated_audio": false,
"write_raw_audio": false,
"file_id_depth": 0,
"model_key": "placeholder_model",
"tensor_dtype": "float32",
"model_config": {
"sample_rate": 16000,
"embedding_size": 128,
"make_embeddings": true,
"make_logits": false,
"make_separated_audio": false
}
}
18 changes: 18 additions & 0 deletions chirp/tests/testdata/embedding_config_v1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"write_embeddings": true,
"write_logits": false,
"write_separated_audio": false,
"write_raw_audio": false,
"write_frontend": true,
"file_id_depth": 0,
"model_key": "placeholder_model",
"tensor_dtype": "float32",
"model_config": {
"sample_rate": 16000,
"embedding_size": 128,
"make_embeddings": true,
"make_logits": false,
"make_separated_audio": false,
"make_frontend": true
}
}

0 comments on commit 3746672

Please sign in to comment.