import os
import pickle
import random
from pathlib import Path

import torchaudio
import tqdm
from torch.utils.data import DataLoader, Dataset

CACHE_PATH = os.path.join(os.path.dirname(__file__), '.cache/')


class SpeakerClassifiDataset(Dataset):
    """Speaker classification dataset for VoxCeleb1.

    The meta_data file lists one utterance per line as
    "<split-index> <relative-wav-path>", where the split index is
    1 (train), 2 (dev) or 3 (test).
    """

    def __init__(self, mode, file_path, meta_data, max_timestep=None):
        self.root = file_path  # pathlib.Path to the VoxCeleb1 root directory
        self.speaker_num = 1251
        self.meta_data = meta_data
        self.max_timestep = max_timestep
        with open(self.meta_data, "r") as f:
            self.usage_list = f.readlines()

        # Cache the resolved wav paths so the glob search only runs once per split.
        cache_path = os.path.join(CACHE_PATH, f'{mode}.pkl')
        if os.path.isfile(cache_path):
            print(f'[SpeakerClassifiDataset] - Loading file paths from {cache_path}')
            with open(cache_path, 'rb') as cache:
                dataset = pickle.load(cache)
        else:
            # mode is one of "train", "dev" or "test"; dispatch to the matching method.
            dataset = getattr(self, mode)()
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as cache:
                pickle.dump(dataset, cache)
        print(f'[SpeakerClassifiDataset] - there are {len(dataset)} files found')

        self.dataset = dataset
        self.label = self.build_label(self.dataset)

    def build_label(self, train_path_list):
        # Paths look like .../id10001/<video-id>/xxxxx.wav, so the speaker folder
        # is two directories above the wav file. Speaker ids start at id10001,
        # so subtracting 10001 yields zero-based class labels.
        y = []
        for path in train_path_list:
            id_string = path.split(os.path.sep)[-3]
            y.append(int(id_string[2:]) - 10001)

        return y

    @classmethod
    def label2speaker(cls, labels):
        return [f"id{label + 10001}" for label in labels]

    def _search_split(self, split_index, split_name):
        # The usage list tags each utterance with a split index:
        # 1 = train, 2 = dev, 3 = test.
        dataset = []
        print(f"searching wav files for the {split_name} set")
        for line in tqdm.tqdm(self.usage_list):
            pair = line.split()
            index = pair[0]
            if int(index) == split_index:
                x = list(self.root.glob("*/wav/" + pair[1]))
                dataset.append(str(x[0]))
        print(f"finished searching the {split_name} set")

        return dataset

    def train(self):
        return self._search_split(1, "training")

    def dev(self):
        return self._search_split(2, "dev")

    def test(self):
        return self._search_split(3, "test")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        wav, sr = torchaudio.load(self.dataset[idx])
        wav = wav.squeeze(0)
        length = wav.shape[0]

        # Randomly crop utterances longer than max_timestep samples.
        if self.max_timestep is not None:
            if length > self.max_timestep:
                start = random.randint(0, int(length - self.max_timestep))
                wav = wav[start:start + self.max_timestep]
                length = self.max_timestep

        def path2name(path):
            # e.g. .../id10001/<video-id>/00001.wav -> "id10001-<video-id>-00001"
            return Path("-".join((Path(path).parts)[-3:])).stem

        path = self.dataset[idx]
        return wav.numpy(), self.label[idx], path2name(path)

    def collate_fn(self, samples):
        # Transpose a list of (wav, label, name) samples into
        # (wavs, labels, names) tuples for the DataLoader.
        return zip(*samples)
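

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): builds the training
    # split and draws one batch. The dataset root and metadata paths below are
    # placeholders; point them at your local VoxCeleb1 copy and split list.
    root = Path("/path/to/VoxCeleb1")    # hypothetical dataset root
    meta = "/path/to/iden_split.txt"     # hypothetical split list
    dataset = SpeakerClassifiDataset("train", root, meta,
                                     max_timestep=16000 * 3)  # ~3 s at 16 kHz
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=dataset.collate_fn)
    wavs, labels, names = next(iter(loader))
    print(len(wavs), labels[:3], names[:3])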