Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -324,9 +324,6 @@ def main():
|
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
326 |
|
327 |
-
# if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
-
# raise AssertionError("nan or inf in snr_loss")
|
329 |
-
|
330 |
loss = speech_loss + irm_loss + snr_loss
|
331 |
|
332 |
total_loss += loss.item()
|
|
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
326 |
|
|
|
|
|
|
|
327 |
loss = speech_loss + irm_loss + snr_loss
|
328 |
|
329 |
total_loss += loss.item()
|
examples/spectrum_dfnet_aishell/step_3_evaluation.py
CHANGED
@@ -19,7 +19,7 @@ import torch.nn as nn
|
|
19 |
import torchaudio
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
-
from toolbox.torchaudio.models.
|
23 |
|
24 |
|
25 |
def get_args():
|
@@ -152,7 +152,7 @@ def main():
|
|
152 |
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
153 |
|
154 |
logger.info("prepare model")
|
155 |
-
model =
|
156 |
pretrained_model_name_or_path=args.model_dir,
|
157 |
)
|
158 |
model.to(device)
|
|
|
19 |
import torchaudio
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
+
from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
|
23 |
|
24 |
|
25 |
def get_args():
|
|
|
152 |
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
153 |
|
154 |
logger.info("prepare model")
|
155 |
+
model = SpectrumDfNetPretrainedModel.from_pretrained(
|
156 |
pretrained_model_name_or_path=args.model_dir,
|
157 |
)
|
158 |
model.to(device)
|