|
import logging |
|
import os |
|
import random |
|
import shutil |
|
from glob import glob |
|
|
|
import click |
|
import dill |
|
import numpy as np |
|
import pandas as pd |
|
from natsort import natsorted |
|
|
|
from deep_speaker.constants import TRAIN_TEST_RATIO |
|
|
|
# Module-level logger; load_pickle and load_npy report the files they load through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
def find_files(directory, ext='wav'):
    """Recursively collect every ``*.<ext>`` file under *directory*.

    Args:
        directory: Root directory to search (joined to the pattern with '/').
        ext: File extension to match, without the leading dot.

    Returns:
        A lexicographically sorted list of matching paths.
    """
    # '**' with recursive=True also matches zero directories, so files directly
    # inside *directory* are included as well.
    suffix = f'/**/*.{ext}'
    pattern = directory + suffix
    matches = glob(pattern, recursive=True)
    matches.sort()
    return matches
|
|
|
|
|
def init_pandas():
    """Configure pandas display options for readable console output.

    Floats are printed with three decimals, row/column truncation is disabled,
    and the display width is set to 1000 characters.
    """
    display_options = {
        'display.float_format': lambda value: '%.3f' % value,
        'display.max_rows': None,
        'display.max_columns': None,
        'display.width': 1000,
    }
    # Dicts keep insertion order, so options are applied in the same sequence.
    for option_name, option_value in display_options.items():
        pd.set_option(option_name, option_value)
|
|
|
|
|
def create_new_empty_dir(directory: str):
    """Make *directory* exist and be empty, deleting any previous contents.

    Missing parent directories are created as needed.
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
        return
    # Something is already there: wipe it so the caller starts from a clean slate.
    shutil.rmtree(directory)
    os.makedirs(directory)
|
|
|
|
|
def ensure_dir_for_filename(filename: str):
    """Make sure the directory that will contain *filename* exists.

    A bare file name (no directory component) yields an empty parent, which
    ``ensures_dir`` leaves alone.
    """
    parent_dir = os.path.dirname(filename)
    ensures_dir(parent_dir)
|
|
|
|
|
def ensures_dir(directory: str):
    """Create *directory* (including missing parents) unless it already exists.

    An empty string, which is what ``os.path.dirname`` returns for a bare file
    name, is ignored.

    Raises:
        FileExistsError: if *directory* exists but is not a directory.
    """
    if directory:
        # exist_ok=True removes the check-then-create race: previously, if another
        # process or thread created the directory between os.path.exists() and
        # os.makedirs(), makedirs raised FileExistsError.
        os.makedirs(directory, exist_ok=True)
|
|
|
|
|
class ClickType:
    """Factories for ``click.Path`` parameter types.

    Every factory passes ``readable=True`` and ``resolve_path=True``; they differ
    only in whether the path must exist, whether it is a file or a directory,
    and whether it must be writable.
    """

    @staticmethod
    def _path(*, must_exist, is_file, writable):
        # Each factory accepts either files or directories, never both.
        return click.Path(
            exists=must_exist,
            file_okay=is_file,
            dir_okay=not is_file,
            writable=writable,
            readable=True,
            resolve_path=True,
        )

    @staticmethod
    def input_file(writable=False):
        """Path type for an existing file."""
        return ClickType._path(must_exist=True, is_file=True, writable=writable)

    @staticmethod
    def input_dir(writable=False):
        """Path type for an existing directory."""
        return ClickType._path(must_exist=True, is_file=False, writable=writable)

    @staticmethod
    def output_file():
        """Path type for a writable file that need not exist yet."""
        return ClickType._path(must_exist=False, is_file=True, writable=True)

    @staticmethod
    def output_dir():
        """Path type for a writable directory that need not exist yet."""
        return ClickType._path(must_exist=False, is_file=False, writable=True)
|
|
|
|
|
def parallel_function(f, sequence, num_threads=None):
    """Map *f* over *sequence* in a process pool and drop ``None`` results.

    Args:
        f: Picklable callable; it is executed in worker processes.
        sequence: Iterable of inputs passed one at a time to *f*.
        num_threads: Number of worker *processes*; ``None`` uses the
            ``multiprocessing.Pool`` default (``os.cpu_count()``).

    Returns:
        List of the non-None results, in the order of *sequence*.

    Raises:
        Any exception raised by *f* in a worker is re-raised here by ``pool.map``.
    """
    from multiprocessing import Pool

    # Using the pool as a context manager terminates the workers even when
    # pool.map raises; the old close()/join() calls were skipped on error,
    # which leaked the worker processes.
    with Pool(processes=num_threads) as pool:
        results = pool.map(f, sequence)
        # Success path: let the workers shut down cleanly before __exit__.
        pool.close()
        pool.join()
    return [x for x in results if x is not None]
|
|
|
|
|
def load_best_checkpoint(checkpoint_dir):
    """Return the path of the last ``*.h5`` file in *checkpoint_dir*, or None.

    "Last" means last under natural sort order (e.g. ``ckpt_10`` sorts after
    ``ckpt_9``). NOTE(review): this presumably treats the latest checkpoint as
    the best one; confirm against the checkpoint naming scheme.
    """
    candidates = natsorted(glob(os.path.join(checkpoint_dir, '*.h5')))
    return candidates[-1] if candidates else None
|
|
|
|
|
def delete_older_checkpoints(checkpoint_dir, max_to_keep=5):
    """Remove all but the last *max_to_keep* ``*.h5`` files in *checkpoint_dir*.

    Files are ordered by natural sort; the ones at the end of that order are kept.
    """
    assert max_to_keep > 0
    ordered = natsorted(glob(os.path.join(checkpoint_dir, '*.h5')))
    # With max_to_keep > 0, this slice is exactly the files outside the last
    # max_to_keep entries. It is empty when there are max_to_keep files or fewer.
    for stale_checkpoint in ordered[:-max_to_keep]:
        os.remove(stale_checkpoint)
|
|
|
|
|
def enable_deterministic():
    """Seed Python's ``random`` and NumPy's global RNG with a fixed value (123)."""
    print('Deterministic mode enabled.')
    fixed_seed = 123
    # The two generators are independent, so the seeding order does not matter.
    random.seed(fixed_seed)
    np.random.seed(fixed_seed)
|
|
|
|
|
def load_pickle(file):
    """Deserialize *file* with dill, or return None if the path does not exist.

    NOTE(review): dill, like pickle, can execute arbitrary code while loading;
    only use this on files produced by this project, never on untrusted input.
    """
    if os.path.exists(file):
        logger.info(f'Loading PKL file: {file}.')
        with open(file, 'rb') as handle:
            return dill.load(handle)
    return None
|
|
|
|
|
def load_npy(file):
    """Load a NumPy ``.npy`` file, or return None if the path does not exist."""
    if os.path.exists(file):
        logger.info(f'Loading NPY file: {file}.')
        return np.load(file)
    return None
|
|
|
|
|
def train_test_sp_to_utt(audio, is_test):
    """Split each speaker's utterances into the train or the test subset.

    For every speaker in ``audio.speakers_to_utterances``, the values of that
    speaker's mapping are sorted. The first ``int(n * TRAIN_TEST_RATIO)`` of them
    form the train subset and the rest form the test subset.

    Args:
        audio: Object exposing ``speakers_to_utterances``, a mapping from speaker
            id to a mapping whose values are the utterances (presumably file
            paths; confirm against the caller).
        is_test: If true return the test subset, otherwise the train subset.

    Returns:
        Dict mapping each speaker id to its selected list of utterances.
    """
    def _select(utterances):
        ordered = sorted(utterances.values())
        cut = int(len(ordered) * TRAIN_TEST_RATIO)
        return ordered[cut:] if is_test else ordered[:cut]

    return {
        speaker_id: _select(utterances)
        for speaker_id, utterances in audio.speakers_to_utterances.items()
    }
|
|