From c54a628048012fe86dbffdd60f163c04b122031d Mon Sep 17 00:00:00 2001
From: anteju <108555623+anteju@users.noreply.github.com>
Date: Fri, 14 Feb 2025 16:26:39 -0800
Subject: [PATCH] Fix forward for eval pass in FlowMatching models (#12056)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Ante Jukić
---
 nemo/collections/audio/models/enhancement.py | 54 +++++++++++++++++++-
 1 file changed, 52 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py
index 8e2206afcef1..4da6a247d563 100644
--- a/nemo/collections/audio/models/enhancement.py
+++ b/nemo/collections/audio/models/enhancement.py
@@ -703,6 +703,7 @@ def output_types(self) -> Dict[str, NeuralType]:
     @torch.inference_mode()
     def forward(self, input_signal, input_length=None):
         """Forward pass of the model to generate samples from the target distribution.
+        This is used for inference mode only, and it explicitly disables SSL masking of the input.
 
         Args:
             input_signal: Tensor that represents a batch of raw audio signals,
@@ -711,6 +712,51 @@ def forward(self, input_signal, input_length=None):
             input_signal_length: Vector of length B, that contains the individual lengths of the audio
                 sequences.
 
+        Returns:
+            Output signal `output` in the time domain and the length of the output signal `output_length`.
+        """
+        return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=False)
+
+    @typecheck(
+        input_types={
+            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
+            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+        },
+        output_types={
+            "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
+            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
+        },
+    )
+    @torch.inference_mode()
+    def forward_eval(self, input_signal, input_length=None):
+        """Forward pass of the model to generate samples from the target distribution.
+        This is used for eval mode only, and it enables SSL masking of the input.
+
+        Args:
+            input_signal: Tensor that represents a batch of raw audio signals,
+                of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as
+                `self.sample_rate` number of floating point values.
+            input_length: Vector of length B, that contains the individual lengths of the audio
+                sequences.
+
+        Returns:
+            Output signal `output` in the time domain and the length of the output signal `output_length`.
+        """
+        return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=True)
+
+    @torch.inference_mode()
+    def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=False):
+        """Internal forward pass of the model.
+
+        Args:
+            input_signal: Tensor that represents a batch of raw audio signals,
+                of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as
+                `self.sample_rate` number of floating point values.
+            input_length: Vector of length B, that contains the individual lengths of the audio
+                sequences.
+            enable_ssl_masking: Whether to enable SSL masking of the conditional input. Masking is applied only
+                if this is enabled and the model is configured with SSL pretrain masking.
+
         Returns:
             Output signal `output` in the time domain and the length of the output signal `output_length`.
""" @@ -725,11 +771,15 @@ def forward(self, input_signal, input_length=None): # Encoder encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + # Conditional input if self.p_cond == 0: + # The model is trained without the conditional input encoded = torch.zeros_like(encoded) - elif self.ssl_pretrain_masking is not None: + elif enable_ssl_masking and self.ssl_pretrain_masking is not None: + # Masking for self-supervised pretraining encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length) + # Initial process state init_state = torch.randn_like(encoded) * self.flow.sigma_start # Sampler @@ -867,7 +917,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = if update_metrics: # Generate output signal - output_signal, _ = self.forward( + output_signal, _ = self.forward_eval( input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] )