File size: 3,386 Bytes
aed64b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import logging
import os
import random
import shutil
from glob import glob
import click
import dill
import numpy as np
import pandas as pd
from natsort import natsorted
from deep_speaker.constants import TRAIN_TEST_RATIO
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def find_files(directory, ext='wav'):
    """Return the sorted paths of all ``*.<ext>`` files found under ``directory``.

    The search is recursive (``**`` with ``recursive=True``), so files sitting
    directly in ``directory`` and in any of its sub-directories are matched.

    Args:
        directory: Root of the search; it is concatenated as-is with the
            glob pattern.
        ext: File extension to match, without the leading dot.

    Returns:
        A list of matching paths in ascending string order.
    """
    pattern = directory + '/**/*.{}'.format(ext)
    matches = glob(pattern, recursive=True)
    matches.sort()
    return matches
def init_pandas():
    """Configure pandas display options for wide, untruncated output.

    Floats are rendered with three decimals, the row and column limits are
    removed (set to ``None``) and the display width is set to 1000.
    """

    def _three_decimals(value):
        return '%.3f' % value

    display_options = (
        ('display.float_format', _three_decimals),
        ('display.max_rows', None),
        ('display.max_columns', None),
        ('display.width', 1000),
    )
    for option_name, option_value in display_options:
        pd.set_option(option_name, option_value)
def create_new_empty_dir(directory: str):
    """Make ``directory`` exist and be empty, discarding any previous content.

    Args:
        directory: Path of the directory to (re)create. Missing parent
            directories are created as well.
    """
    has_previous_content = os.path.exists(directory)
    if has_previous_content:
        # Remove the whole existing tree so the directory starts out empty.
        shutil.rmtree(directory)
    os.makedirs(directory)
def ensure_dir_for_filename(filename: str):
    """Create the directory that would contain ``filename`` if it is missing.

    Args:
        filename: Path of a file; only its directory part is used.
    """
    parent_dir = os.path.dirname(filename)
    ensures_dir(parent_dir)
def ensures_dir(directory: str):
    """Create ``directory`` (and any missing parents) unless it already exists.

    An empty ``directory`` is a no-op. This lets callers pass the result of
    ``os.path.dirname`` for a bare file name.

    Args:
        directory: Path of the directory to create.
    """
    if not directory:
        return
    try:
        # Attempt the creation directly instead of checking for existence
        # first: a separate check can race with another process creating the
        # same directory and then fail with FileExistsError.
        os.makedirs(directory)
    except FileExistsError:
        # The path is already there; as before, it is left untouched.
        pass
class ClickType:
    """Factories for preconfigured ``click.Path`` parameter types.

    Every factory returns a ``click.Path`` with ``readable=True`` and
    ``resolve_path=True``. They differ only in whether the path must already
    exist, whether it names a file or a directory, and whether it must be
    writable.
    """

    @staticmethod
    def _path(must_exist, is_file, writable):
        # Shared constructor: exactly one of file_okay / dir_okay is enabled.
        return click.Path(
            exists=must_exist,
            file_okay=is_file,
            dir_okay=not is_file,
            writable=writable,
            readable=True,
            resolve_path=True,
        )

    @staticmethod
    def input_file(writable=False):
        """Path type for a file that must already exist."""
        return ClickType._path(True, True, writable)

    @staticmethod
    def input_dir(writable=False):
        """Path type for a directory that must already exist."""
        return ClickType._path(True, False, writable)

    @staticmethod
    def output_file():
        """Path type for a writable file that need not exist yet."""
        return ClickType._path(False, True, True)

    @staticmethod
    def output_dir():
        """Path type for a writable directory that need not exist yet."""
        return ClickType._path(False, False, True)
def parallel_function(f, sequence, num_threads=None):
    """Apply ``f`` to every item of ``sequence`` in a pool of worker processes.

    Args:
        f: Callable applied to each item. It is run by a
            ``multiprocessing.Pool``, so it executes in worker processes.
        sequence: Iterable of inputs for ``f``.
        num_threads: Number of worker processes (despite the name), passed to
            ``Pool(processes=...)``. ``None`` uses the pool's default.

    Returns:
        The results in input order, with every ``None`` result dropped.

    Raises:
        Whatever ``f`` raises in a worker is re-raised here by ``pool.map``.
    """
    from multiprocessing import Pool

    pool = Pool(processes=num_threads)
    try:
        results = pool.map(f, sequence)
    finally:
        # Shut the pool down even when f raised, so worker processes are not
        # left running after a failed call.
        pool.close()
        pool.join()
    return [item for item in results if item is not None]
def load_best_checkpoint(checkpoint_dir):
    """Return the path of the last ``*.h5`` file in ``checkpoint_dir``.

    "Last" means last in natural sort order of the file paths; no other
    quality criterion is applied here.

    Args:
        checkpoint_dir: Directory searched (non-recursively) for ``*.h5`` files.

    Returns:
        The selected checkpoint path, or ``None`` when the directory holds no
        ``*.h5`` file.
    """
    pattern = os.path.join(checkpoint_dir, '*.h5')
    ordered = natsorted(glob(pattern))
    return ordered[-1] if len(ordered) != 0 else None
def delete_older_checkpoints(checkpoint_dir, max_to_keep=5):
    """Delete all but the last ``max_to_keep`` ``*.h5`` files in ``checkpoint_dir``.

    Files are ordered by natural sort of their paths. The final
    ``max_to_keep`` entries in that order are kept and every earlier one is
    removed from disk.

    Args:
        checkpoint_dir: Directory searched (non-recursively) for ``*.h5`` files.
        max_to_keep: Number of checkpoints to retain; must be positive.
    """
    assert max_to_keep > 0
    pattern = os.path.join(checkpoint_dir, '*.h5')
    ordered = natsorted(glob(pattern))
    # Everything before the retained tail is stale; this is empty when there
    # are at most max_to_keep checkpoints.
    for stale_checkpoint in ordered[:-max_to_keep]:
        os.remove(stale_checkpoint)
def enable_deterministic():
    """Seed NumPy's global RNG and the ``random`` module with the fixed value 123.

    Also prints a one-line notice to standard output.
    """
    fixed_seed = 123
    print('Deterministic mode enabled.')
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(fixed_seed)
def load_pickle(file):
    """Load an object serialized with ``dill`` from ``file``.

    Args:
        file: Path of the pickle file.

    Returns:
        The deserialized object, or ``None`` when ``file`` does not exist.
    """
    if os.path.exists(file):
        logger.info(f'Loading PKL file: {file}.')
        # NOTE(review): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        with open(file, 'rb') as handle:
            payload = dill.load(handle)
        return payload
    return None
def load_npy(file):
    """Load a NumPy file from ``file`` with ``np.load``.

    Args:
        file: Path of the ``.npy`` file.

    Returns:
        Whatever ``np.load`` returns for the file, or ``None`` when ``file``
        does not exist.
    """
    if os.path.exists(file):
        logger.info(f'Loading NPY file: {file}.')
        loaded = np.load(file)
        return loaded
    return None
def train_test_sp_to_utt(audio, is_test):
    """Build a per-speaker list of utterances for the train or the test split.

    For each speaker, the values of its utterance mapping are sorted and cut
    at ``int(len(values) * TRAIN_TEST_RATIO)``. The part before the cut is the
    train split and the part from the cut onwards is the test split.

    Args:
        audio: Object exposing ``speakers_to_utterances``, a mapping from
            speaker id to a mapping of utterances (presumably utterance name
            to file path; verify against the caller).
        is_test: ``True`` to select the test part, ``False`` for the train part.

    Returns:
        A dict mapping each speaker id to its selected list of utterances.
    """

    def _select_split(sorted_files):
        cut = int(len(sorted_files) * TRAIN_TEST_RATIO)
        if is_test:
            return sorted_files[cut:]
        return sorted_files[:cut]

    return {
        speaker: _select_split(sorted(speaker_utterances.values()))
        for speaker, speaker_utterances in audio.speakers_to_utterances.items()
    }
|