Skip to content

Commit

Permalink
Fix forward for eval pass in FlowMatching models (#12056)
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju authored Feb 15, 2025
1 parent 0eb9e5d commit c54a628
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions nemo/collections/audio/models/enhancement.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def output_types(self) -> Dict[str, NeuralType]:
@torch.inference_mode()
def forward(self, input_signal, input_length=None):
"""Forward pass of the model to generate samples from the target distribution.
This is used for inference mode only, and it explicitly disables SSL masking of the input.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
Expand All @@ -711,6 +712,51 @@ def forward(self, input_signal, input_length=None):
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=False)

@typecheck(
input_types={
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
},
output_types={
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
},
)
@torch.inference_mode()
def forward_eval(self, input_signal, input_length=None):
"""Forward pass of the model to generate samples from the target distribution.
This is used for eval mode only, and it enables SSL masking of the input by delegating to
`forward_internal` with `enable_ssl_masking=True`.
Args:
input_signal: Tensor that represents a batch of raw audio signals. The `typecheck` decorator
declares it as an `AudioSignal` with axes ('B', 'C', 'T'). T here represents timesteps, with
1 second of audio represented as `self.sample_rate` number of floating point values.
NOTE(review): this docstring previously gave the shape as [B, T] or [B, T, C], which does
not match the declared axes -- confirm which layout callers actually pass.
input_length: Optional vector of length B, that contains the individual lengths of the audio
sequences. Defaults to None.
Returns:
Output signal in the time domain and the length of the output signal, exactly as returned by
`forward_internal` (declared as `output_signal` and `output_length` in the `typecheck` decorator).
"""
return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=True)

@torch.inference_mode()
def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=False):
"""Internal forward pass of the model.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
enable_ssl_masking: Whether to enable SSL masking of the input. If using SSL pretraining, masking
is applied to the input signal. If not using SSL pretraining, masking is not applied.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
Expand All @@ -725,11 +771,15 @@ def forward(self, input_signal, input_length=None):
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)

# Conditional input
if self.p_cond == 0:
# The model is trained without the conditional input
encoded = torch.zeros_like(encoded)
elif self.ssl_pretrain_masking is not None:
elif enable_ssl_masking and self.ssl_pretrain_masking is not None:
# Masking for self-supervised pretraining
encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length)

# Initial process state
init_state = torch.randn_like(encoded) * self.flow.sigma_start

# Sampler
Expand Down Expand Up @@ -867,7 +917,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =

if update_metrics:
# Generate output signal
output_signal, _ = self.forward(
output_signal, _ = self.forward_eval(
input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
)

Expand Down

0 comments on commit c54a628

Please sign in to comment.