From 8af934a4e52118c10970ea834ba4e47ce7a16065 Mon Sep 17 00:00:00 2001 From: Vincent Dumoulin Date: Tue, 28 Mar 2023 05:05:26 -0700 Subject: [PATCH] Improve classwise metric logging PiperOrigin-RevId: 519994859 --- chirp/models/cmap.py | 12 +++------ chirp/train/classifier.py | 51 ++++++++++++++++------------------- chirp/train/hubert.py | 57 +++++++++++++++++++-------------------- chirp/train/separator.py | 33 +++++++++++++---------- chirp/train/utils.py | 2 +- 5 files changed, 74 insertions(+), 81 deletions(-) diff --git a/chirp/models/cmap.py b/chirp/models/cmap.py index 1087781a..15964eaf 100644 --- a/chirp/models/cmap.py +++ b/chirp/models/cmap.py @@ -44,16 +44,12 @@ def compute(self, sample_threshold: int = 0): # Same as sklearn's average_precision_score(label, logits, average=None) # but that implementation doesn't scale to 10k+ classes class_aps = metrics.average_precision( - values["label_logits"][:, mask].T, values["label"][:, mask].T + values["label_logits"].T, values["label"].T ) + class_aps = jnp.where(mask, class_aps, jnp.nan) return { - "macro": jnp.mean(class_aps), - **{ - str(i): ap - for i, ap in zip( - jnp.arange(values["label"].shape[1])[mask], class_aps - ) - }, + "macro": jnp.mean(class_aps, where=mask), + "individual": class_aps, } diff --git a/chirp/train/classifier.py b/chirp/train/classifier.py index 66dacc99..c5202f8e 100644 --- a/chirp/train/classifier.py +++ b/chirp/train/classifier.py @@ -239,16 +239,20 @@ def update_step(key, batch, train_state): train_metrics, train_state = update_step(step_key, batch, train_state) if step % log_every_steps == 0: - train_metrics = flax_utils.unreplicate(train_metrics).compute() + train_metrics = utils.flatten_dict( + flax_utils.unreplicate(train_metrics).compute() + ) - metrics_kept = {} - for k, v in train_metrics.items(): - if "xentropy" in k and not add_class_wise_metrics: - continue - metrics_kept[k] = v - train_metrics = metrics_kept + classwise_metrics = { + k: v for k, v in train_metrics.items() if "individual" in k + } + train_metrics = { + k: v for k, v in train_metrics.items() if k not in classwise_metrics + } - writer.write_scalars(step, utils.flatten_dict(train_metrics)) + writer.write_scalars(step, train_metrics) + if add_class_wise_metrics: + writer.write_summaries(step, classwise_metrics) reporter(step) if (step + 1) % checkpoint_every_steps == 0 or step == num_train_steps: @@ -361,26 +365,17 @@ def remainder_batch_fn(x): break # Log validation loss - valid_metrics = valid_metrics.compute() - - if not add_class_wise_metrics: - metrics_kept = {} - for k, v in valid_metrics.items(): - if "xentropy" in k: - # Only the class-wise xentropy metrics contain the string 'xentropy'; - # the key corresponding to overall xentropy is called 'loss'. - continue - metrics_kept[k] = v - valid_metrics = metrics_kept - - for k, v in valid_metrics.items(): - # Only one of the keys of valid_metrics will contain the string 'cmap', - # and the associated value is a dict that has a 'macro' key as well as - # a key per class. To disable class-wise metrics, we keep only 'macro'. - if "_cmap" in k: - valid_metrics[k] = v["macro"] - - writer.write_scalars(step, utils.flatten_dict(valid_metrics)) + valid_metrics = utils.flatten_dict(valid_metrics.compute()) + classwise_metrics = { + k: v for k, v in valid_metrics.items() if "individual" in k + } + valid_metrics = { + k: v for k, v in valid_metrics.items() if k not in classwise_metrics + } + + writer.write_scalars(step, valid_metrics) + if add_class_wise_metrics: + writer.write_summaries(step, classwise_metrics) writer.flush() diff --git a/chirp/train/hubert.py b/chirp/train/hubert.py index 09c92b1b..a5d7c121 100644 --- a/chirp/train/hubert.py +++ b/chirp/train/hubert.py @@ -855,16 +855,20 @@ def step(params, model_state): ) if step % log_every_steps == 0: - train_metrics = flax_utils.unreplicate(train_metrics).compute() + train_metrics = utils.flatten_dict( + flax_utils.unreplicate(train_metrics).compute() + ) - metrics_kept = {} - for k, v in train_metrics.items(): - if "xentropy" in k and not add_class_wise_metrics: - continue - metrics_kept[k] = v - train_metrics = metrics_kept + classwise_metrics = { + k: v for k, v in train_metrics.items() if "individual" in k + } + train_metrics = { + k: v for k, v in train_metrics.items() if k not in classwise_metrics + } - writer.write_scalars(step, utils.flatten_dict(train_metrics)) + writer.write_scalars(step, train_metrics) + if add_class_wise_metrics: + writer.write_summaries(step, classwise_metrics) reporter(step) if (step + 1) % checkpoint_every_steps == 0 or step == num_train_steps: @@ -955,7 +959,7 @@ def get_metrics(batch, train_state, mask_key): step = int(flax_utils.unreplicate(train_state.step)) key = model_bundle.key with reporter.timed("eval"): - valid_metrics = flax_utils.replicate(valid_metrics_collection.empty()) + valid_metrics = valid_metrics_collection.empty() for s, batch in enumerate(valid_dataset.as_numpy_iterator()): batch = jax.tree_map(np.asarray, batch) mask_key = None @@ -963,7 +967,9 @@ def get_metrics(batch, train_state, mask_key): mask_key, key = random.split(key) mask_key = random.split(mask_key, num=jax.local_device_count()) new_valid_metrics = get_metrics(batch, train_state, mask_key) - valid_metrics = valid_metrics.merge(new_valid_metrics) + valid_metrics = valid_metrics.merge( + flax_utils.unreplicate(new_valid_metrics) + ) if ( eval_steps_per_checkpoint is not None and s >= eval_steps_per_checkpoint @@ -971,26 +977,17 @@ def get_metrics(batch, train_state, mask_key): break # Log validation loss - valid_metrics = flax_utils.unreplicate(valid_metrics).compute() - - if not add_class_wise_metrics: - metrics_kept = {} - for k, v in valid_metrics.items(): - if "xentropy" in k: - # Only the class-wise xentropy metrics contain the string 'xentropy'; - # the key corresponding to overall xentropy is called 'loss'. - continue - metrics_kept[k] = v - valid_metrics = metrics_kept - - for k, v in valid_metrics.items(): - # Only one of the keys of valid_metrics will contain the string 'cmap', - # and the associated value is a dict that has a 'macro' key as well as - # a key per class. To disable class-wise metrics, we keep only 'macro'. - if "_cmap" in k: - valid_metrics[k] = v["macro"] - - writer.write_scalars(step, utils.flatten_dict(valid_metrics)) + valid_metrics = utils.flatten_dict(valid_metrics.compute()) + classwise_metrics = { + k: v for k, v in valid_metrics.items() if "individual" in k + } + valid_metrics = { + k: v for k, v in valid_metrics.items() if k not in classwise_metrics + } + + writer.write_scalars(step, valid_metrics) + if add_class_wise_metrics: + writer.write_summaries(step, classwise_metrics) writer.flush() diff --git a/chirp/train/separator.py b/chirp/train/separator.py index 60e2c378..776dd79f 100644 --- a/chirp/train/separator.py +++ b/chirp/train/separator.py @@ -322,11 +322,13 @@ def get_metrics(batch, train_state): ) with reporter.timed('eval'): - valid_metrics = flax.jax_utils.replicate(valid_metrics_collection.empty()) + valid_metrics = valid_metrics_collection.empty() for valid_step, batch in enumerate(valid_dataset.as_numpy_iterator()): batch = jax.tree_map(np.asarray, batch) new_valid_metrics = get_metrics(batch, flax_utils.replicate(train_state)) - valid_metrics = valid_metrics.merge(new_valid_metrics) + valid_metrics = valid_metrics.merge( + flax_utils.unreplicate(new_valid_metrics) + ) if ( eval_steps_per_checkpoint > 0 and valid_step >= eval_steps_per_checkpoint @@ -334,18 +336,21 @@ def get_metrics(batch, train_state): break # Log validation loss - valid_metrics = flax_utils.unreplicate(valid_metrics).compute() - - if not add_class_wise_metrics: - metrics_kept = {} - for k, v in valid_metrics.items(): - if '_cmap_' in k and not v.endswith('_cmap_macro'): - # Discard metrics like 'valid_cmap_442' keeping only 'valid_cmap_macro'. - continue - metrics_kept[k] = v - valid_metrics = metrics_kept - valid_metrics = {k.replace('___', '/'): v for k, v in valid_metrics.items()} - writer.write_scalars(int(train_state.step), utils.flatten_dict(valid_metrics)) + valid_metrics = valid_metrics.compute() + + valid_metrics = utils.flatten_dict( + {k.replace('___', '/'): v for k, v in valid_metrics.items()} + ) + classwise_metrics = { + k: v for k, v in valid_metrics.items() if 'individual' in k + } + valid_metrics = { + k: v for k, v in valid_metrics.items() if k not in classwise_metrics + } + + writer.write_scalars(int(train_state.step), valid_metrics) + if add_class_wise_metrics: + writer.write_summaries(int(train_state.step), classwise_metrics) writer.flush() diff --git a/chirp/train/utils.py b/chirp/train/utils.py index ee0a7114..210396bf 100644 --- a/chirp/train/utils.py +++ b/chirp/train/utils.py @@ -102,7 +102,7 @@ def compute(self): averages = self.total / self.count return { "mean": jnp.sum(self.total) / jnp.sum(self.count), - **{str(i): averages[i] for i in range(jnp.size(averages))}, + "individual": averages, }