# File: nx_denoise/toolbox/torchaudio/models/dfnet/configuration_dfnet.py
# (Provenance: HoneyTian — "add microphone audio input", commit 85a1b16, 4.87 kB)
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig
class DfNetConfig(PretrainedConfig):
    """Configuration for the DfNet (deep-filtering) speech-enhancement model.

    Groups every hyperparameter used by the model and its training loop:
    STFT/transform settings, ERB spectrum layout, encoder/decoder topology,
    deep-filter (DF) decoder settings, local-SNR estimation, data SNR
    augmentation range, and optimizer/trainer options.

    All arguments are stored verbatim as instance attributes; extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 200,
                 hop_size: int = 80,
                 win_type: str = "hann",
                 spec_bins: int = 256,
                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,
                 conv_channels: int = 64,
                 conv_kernel_size_input: Tuple[int, int] = (3, 3),
                 conv_kernel_size_inner: Tuple[int, int] = (1, 3),
                 conv_lookahead: int = 0,
                 convt_kernel_size_inner: Tuple[int, int] = (1, 3),
                 embedding_hidden_size: int = 256,
                 encoder_combine_op: str = "concat",
                 encoder_emb_skip_op: str = "none",
                 encoder_emb_linear_groups: int = 16,
                 encoder_emb_hidden_size: int = 256,
                 encoder_linear_groups: int = 32,
                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
                 decoder_emb_hidden_size: int = 256,
                 df_decoder_hidden_size: int = 256,
                 df_num_layers: int = 2,
                 df_order: int = 5,
                 df_bins: int = 96,
                 df_gru_skip: str = "grouped_linear",
                 df_decoder_linear_groups: int = 16,
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,
                 n_frame: int = 3,
                 max_local_snr: int = 30,
                 min_local_snr: int = -15,
                 norm_tau: float = 1.,
                 min_snr_db: float = -10,
                 max_snr_db: float = 20,
                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 # NOTE: annotated Optional because the default is None; an
                 # empty dict is substituted below (avoids a mutable default).
                 lr_scheduler_kwargs: Optional[dict] = None,
                 max_epochs: int = 100,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,
                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,
                 use_post_filter: bool = False,
                 **kwargs
                 ):
        super().__init__(**kwargs)
        # transform (STFT) parameters
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # spectrum layout: linear spectrogram bins and ERB band count
        self.spec_bins = spec_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb

        # convolutional front-end
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.embedding_hidden_size = embedding_hidden_size

        # encoder
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op

        # (ERB) decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # deep-filter (DF) decoder
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # local SNR (lsnr) estimation range
        self.n_frame = n_frame
        self.max_local_snr = max_local_snr
        self.min_local_snr = min_local_snr
        self.norm_tau = norm_tau

        # data augmentation: SNR range (dB) for noise mixing
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        # training / optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
        self.max_epochs = max_epochs
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps

        # runtime / inference option
        self.use_post_filter = use_post_filter
# No standalone behavior: this module only defines DfNetConfig.
if __name__ == "__main__":
    pass