Skip to content

Commit

Permalink
Internal only change for handling pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609621790
  • Loading branch information
Chirp Team authored and copybara-github committed Feb 23, 2024
1 parent bae62f8 commit 209b3f3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions chirp/models/separation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def time_reduce_logits(self, reduction: str = 'AVG') -> 'SeparatorOutput':
elif reduction == 'MAX':
reduce_fn = lambda x: jnp.max(x, axis=1)
elif reduction == 'MIDPOINT':
if self.label is None:
raise ValueError('SeperatorOutput.label is None')
midpt = self.label.shape[1] // 2
reduce_fn = lambda x: x[:, midpt, :]
else:
Expand Down Expand Up @@ -169,6 +171,8 @@ def bottleneck_classifier(self, bottleneck, train: bool):
)(classify_hiddens)
classify_hiddens = nn.swish(classify_hiddens)
classify_outputs = {}
if self.num_classes is None:
raise ValueError('SeparatorModel.num_classes is None')
for k, n in self.num_classes.items():
classify_outputs[k] = nn.Conv(n, (1,), (1,), 'SAME')(classify_hiddens)
classify_outputs['embedding'] = classify_hiddens
Expand Down
2 changes: 2 additions & 0 deletions chirp/train/hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,8 @@ def train(
"reload_quantizer being True."
)

if train_dataset is None:
raise ValueError("train_dataset is None.")
train_iterator = train_dataset.as_numpy_iterator()
taxonomy_keys = ["label"]
taxonomy_loss_weight = model_bundle.model.taxonomy_loss_weight
Expand Down
4 changes: 4 additions & 0 deletions chirp/train/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def train(
] = optax.sigmoid_binary_cross_entropy,
) -> None:
"""Train a model."""
if train_dataset is None:
raise ValueError('train_dataset is None')
train_iterator = train_dataset.as_numpy_iterator()
train_metrics_collection = train_utils.NestedCollection.create(
**TRAIN_METRICS
Expand Down Expand Up @@ -318,6 +320,8 @@ def update_step(params, model_state):
train_state.params, train_state.model_state
)
grads = jax.lax.pmean(grads, axis_name='batch')
if model_bundle.optimizer is None:
raise ValueError('model_bundle.optimizer is None')
updates, opt_state = model_bundle.optimizer.update(
grads, train_state.opt_state
)
Expand Down

0 comments on commit 209b3f3

Please sign in to comment.