Skip to content

Commit

Permalink
Change naming style of Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou245 committed Jan 24, 2024
1 parent 5e0c2db commit a9a4efa
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 50 deletions.
4 changes: 2 additions & 2 deletions runtime/core/decoder/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ namespace wenet {

FeatureType StringToFeatureType(const std::string& feat_type_str) {
if (feat_type_str == "kaldi")
return FeatureType::KALDI;
return FeatureType::kKaldi;
else if (feat_type_str == "whisper")
return FeatureType::Whisper;
return FeatureType::kWhisper;
else
throw std::invalid_argument("Unsupported feat type!");
}
Expand Down
66 changes: 35 additions & 31 deletions runtime/core/frontend/fbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,43 @@
#include "frontend/fft.h"
#include "utils/log.h"

#define S16_TO_FLOAT_SCALE 32768
#define S16_ABS_MAX (2 << 15)

namespace wenet {

// This code is based on kaldi Fbank implementation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc

enum class WindowType {
Povey,
Hanning,
kPovey = 0,
kHanning,
};

enum class MelType {
HTK,
Slaney,
kHTK = 0,
kSlaney,
};

enum class NormalizationType {
KALDI,
Whisper,
kKaldi = 0,
kWhisper,
};

enum class LogBase {
BaseE,
Base10,
kBaseE = 0,
kBase10,
};

class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
float low_freq = 20, bool pre_emphasis = true,
bool scaled_float_as_input = false,
bool scale_input_to_unit = false,
float log_floor = std::numeric_limits<float>::epsilon(),
LogBase log_base = LogBase::BaseE,
WindowType window_type = WindowType::Povey,
MelType mel_type = MelType::HTK,
NormalizationType norm_type = NormalizationType::KALDI)
LogBase log_base = LogBase::kBaseE,
WindowType window_type = WindowType::kPovey,
MelType mel_type = MelType::kHTK,
NormalizationType norm_type = NormalizationType::kKaldi)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
Expand All @@ -71,7 +71,7 @@ class Fbank {
distribution_(0, 1.0),
dither_(0.0),
pre_emphasis_(pre_emphasis),
scaled_float_as_input_(scaled_float_as_input),
scale_input_to_unit_(scale_input_to_unit),
log_floor_(log_floor),
log_base_(log_base),
norm_type_(norm_type) {
Expand Down Expand Up @@ -105,12 +105,12 @@ class Fbank {
float mel = MelScale(freq, mel_type);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel_type == MelType::HTK) {
if (mel_type == MelType::kHTK) {
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else if (mel > center_mel)
weight = (right_mel - mel) / (right_mel - center_mel);
} else if (mel_type == MelType::Slaney) {
} else if (mel_type == MelType::kSlaney) {
if (mel <= center_mel) {
weight = (InverseMelScale(mel, mel_type) -
InverseMelScale(left_mel, mel_type)) /
Expand Down Expand Up @@ -155,12 +155,12 @@ class Fbank {

void InitWindow(WindowType window_type) {
window_.resize(frame_length_);
if (window_type == WindowType::Povey) {
if (window_type == WindowType::kPovey) {
// povey window
double a = M_2PI / (frame_length_ - 1);
for (int i = 0; i < frame_length_; ++i)
window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
} else if (window_type == WindowType::Hanning) {
} else if (window_type == WindowType::kHanning) {
// periodic hanning window
double a = M_2PI / (frame_length_);
for (int i = 0; i < frame_length_; ++i)
Expand All @@ -169,10 +169,10 @@ class Fbank {
}

static inline float InverseMelScale(float mel_freq,
MelType mel_type = MelType::HTK) {
if (mel_type == MelType::HTK) {
MelType mel_type = MelType::kHTK) {
if (mel_type == MelType::kHTK) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
} else if (mel_type == MelType::Slaney) {
} else if (mel_type == MelType::kSlaney) {
float f_min = 0.0;
float f_sp = 200.0f / 3.0f;
float min_log_hz = 1000.0;
Expand All @@ -184,13 +184,15 @@ class Fbank {
} else {
return freq;
}
} else {
throw std::invalid_argument("Unsupported mel type!");
}
}

static inline float MelScale(float freq, MelType mel_type = MelType::HTK) {
if (mel_type == MelType::HTK) {
static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
if (mel_type == MelType::kHTK) {
return 1127.0f * logf(1.0f + freq / 700.0f);
} else if (mel_type == MelType::Slaney) {
} else if (mel_type == MelType::kSlaney) {
float f_min = 0.0;
float f_sp = 200.0f / 3.0f;
float min_log_hz = 1000.0;
Expand All @@ -202,6 +204,8 @@ class Fbank {
} else {
return mel;
}
} else {
throw std::invalid_argument("Unsupported mel type!");
}
}

Expand Down Expand Up @@ -255,9 +259,9 @@ class Fbank {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);

if (scaled_float_as_input_) {
if (scale_input_to_unit_) {
for (int j = 0; j < frame_length_; ++j) {
data[j] = data[j] / S16_TO_FLOAT_SCALE;
data[j] = data[j] / S16_ABS_MAX;
}
}

Expand Down Expand Up @@ -303,16 +307,16 @@ class Fbank {
if (use_log_) {
if (mel_energy < log_floor_) mel_energy = log_floor_;

if (log_base_ == LogBase::BaseE)
if (log_base_ == LogBase::kBaseE)
mel_energy = logf(mel_energy);
else if (log_base_ == LogBase::Base10)
else if (log_base_ == LogBase::kBase10)
mel_energy = log10(mel_energy);
}
if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
(*feat)[i][j] = mel_energy;
}
}
if (norm_type_ == NormalizationType::Whisper)
if (norm_type_ == NormalizationType::kWhisper)
WhisperNorm(feat, max_mel_engery);

return num_frames;
Expand All @@ -326,7 +330,7 @@ class Fbank {
bool use_log_;
bool remove_dc_offset_;
bool pre_emphasis_;
bool scaled_float_as_input_;
bool scale_input_to_unit_;
float log_floor_;
LogBase log_base_;
NormalizationType norm_type_;
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/frontend/feature_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift, config.low_freq, config.pre_emphasis,
config.scaled_float_as_input, config.log_floor, config.log_base,
config.scale_input_to_unit, config.log_floor, config.log_base,
config.window_type, config.mel_type, config.norm_type),
num_frames_(0),
input_finished_(false) {}
Expand Down
32 changes: 16 additions & 16 deletions runtime/core/frontend/feature_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
namespace wenet {

enum class FeatureType {
KALDI,
Whisper,
kKaldi = 0,
kWhisper,
};

struct FeaturePipelineConfig {
Expand All @@ -39,37 +39,37 @@ struct FeaturePipelineConfig {
int frame_shift;
float low_freq;
bool pre_emphasis;
bool scaled_float_as_input;
bool scale_input_to_unit;
float log_floor;
LogBase log_base;
WindowType window_type;
MelType mel_type;
NormalizationType norm_type;

FeaturePipelineConfig(int num_bins, int sample_rate,
FeatureType feat_type = FeatureType::KALDI)
FeatureType feat_type = FeatureType::kKaldi)
: num_bins(num_bins), // 80 dim fbank
sample_rate(sample_rate) { // 16k sample rate
frame_length = sample_rate / 1000 * 25; // frame length 25ms
frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
if (feat_type == FeatureType::KALDI) {
if (feat_type == FeatureType::kKaldi) {
low_freq = 20.0;
pre_emphasis = true;
log_floor = std::numeric_limits<float>::epsilon();
log_base = LogBase::BaseE;
window_type = WindowType::Povey;
mel_type = MelType::HTK;
norm_type = NormalizationType::KALDI;
scaled_float_as_input = false;
} else if (feat_type == FeatureType::Whisper) {
log_base = LogBase::kBaseE;
window_type = WindowType::kPovey;
mel_type = MelType::kHTK;
norm_type = NormalizationType::kKaldi;
scale_input_to_unit = false;
} else if (feat_type == FeatureType::kWhisper) {
low_freq = 0.0;
pre_emphasis = false;
log_floor = 1e-10;
log_base = LogBase::Base10;
window_type = WindowType::Hanning;
mel_type = MelType::Slaney;
scaled_float_as_input = true;
norm_type = NormalizationType::Whisper;
log_base = LogBase::kBase10;
window_type = WindowType::kHanning;
mel_type = MelType::kSlaney;
scale_input_to_unit = true;
norm_type = NormalizationType::kWhisper;
}
}

Expand Down

0 comments on commit a9a4efa

Please sign in to comment.