Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
import logging | |
import os | |
from pathlib import Path | |
import sys | |
import uuid | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../../")) | |
import librosa | |
import numpy as np | |
import pandas as pd | |
from scipy.io import wavfile | |
import torch | |
import torch.nn as nn | |
import torchaudio | |
from tqdm import tqdm | |
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig | |
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel | |
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft | |
def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: holds ``valid_dataset``, ``model_dir`` and
        ``evaluation_audio_dir`` (parsed as ``str``) plus ``limit``
        (parsed as ``int``), each falling back to its default when the
        flag is not given.
    """
    # (flag, default, type) triples, registered in this order.
    option_specs = (
        ("--valid_dataset", "valid.xlsx", str),
        ("--model_dir", "serialization_dir/best", str),
        ("--evaluation_audio_dir", "evaluation_audio_dir", str),
        ("--limit", 10, int),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback, caster in option_specs:
        cli.add_argument(flag, default=fallback, type=caster)

    # parse_args() with no arguments reads from sys.argv[1:].
    return cli.parse_args()
def logging_config():
    """Configure root logging and return this module's logger.

    ``logging.basicConfig`` is called with level INFO, the message format
    below and an ``MM/DD/YYYY HH:MM:SS`` date format.

    Returns:
        logging.Logger: the logger named after this module (``__name__``).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # No dedicated StreamHandler is attached to the returned logger: its
    # records propagate to the root logger, which basicConfig() equips with a
    # stream handler when it has none. A second stream handler here would
    # print every message twice.
    return logging.getLogger(__name__)
def main():
    """Script entry point; currently a stub that does nothing.

    Returns:
        None
    """
    return None
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()