#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Evaluation script skeleton for the MPNet model.

Parses command-line options for the validation dataset, model directory,
output audio directory and a sample limit. The evaluation itself is not
implemented yet: ``main`` is currently a no-op.
"""
import argparse
import logging
import os
from pathlib import Path
import sys
import uuid

# Make the repository root importable so the ``toolbox`` package below
# resolves when this script is run directly from its own directory.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

# These imports must come after the sys.path tweak above (hence not at the
# very top). Several are not referenced yet; presumably they are intended for
# the unimplemented evaluation logic in ``main``.
import librosa
import numpy as np
import pandas as pd
from scipy.io import wavfile
import torch
import torch.nn as nn
import torchaudio
from tqdm import tqdm

from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft


def get_args(argv=None) -> argparse.Namespace:
    """Parse the script's command-line options.

    :param argv: argument list to parse; ``None`` (the default) parses
        ``sys.argv[1:]``, exactly as before this parameter existed.
    :return: namespace with ``valid_dataset``, ``model_dir``,
        ``evaluation_audio_dir`` (all ``str``) and ``limit`` (``int``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
    parser.add_argument("--limit", default=10, type=int)
    args = parser.parse_args(argv)
    return args


def logging_config() -> logging.Logger:
    """Configure root logging at INFO level and return this module's logger.

    ``logging.basicConfig`` installs a formatted stream handler on the root
    logger; the returned module logger propagates to it. No extra handler is
    added here, because one would print every record twice.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger


def main():
    """Entry point; currently a no-op stub.

    NOTE(review): the evaluation loop is not implemented; ``get_args`` and
    ``logging_config`` are defined but not yet called here.
    """
    return


if __name__ == '__main__':
    main()