|
import os
from pathlib import Path

import numpy as np
import torch
|
|
|
|
|
class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
    """Dataset of pre-computed speaker embeddings.

    Embeddings are loaded either from a single serialized tensor file, or
    from a directory that additionally contains Kaldi-style two-column
    mapping files (``utt2idx``/``spk2idx`` and optionally ``utt2spk``).
    """

    # Maps the public mode name to the prefix of the ``*2idx`` file read
    # in directory mode.
    _MODES = {'utterance': 'utt', 'speaker': 'spk'}

    def __init__(self, feature_path, device, mode='utterance'):
        """
        Args:
            feature_path: path to a serialized tensor file, or a directory
                containing the tensor plus ``<mode>2idx`` / ``utt2spk``
                mapping files. Both ``str`` and ``pathlib.Path`` accepted.
            device: torch device (or device string) tensors are mapped onto.
            mode: ``'utterance'`` or ``'speaker'``; selects which ``*2idx``
                file is read in directory mode.

        Raises:
            ValueError: if ``mode`` is not supported.
        """
        super().__init__()

        # Raise instead of assert: asserts are stripped under ``python -O``.
        if mode not in self._MODES:
            raise ValueError(f'mode: {mode} is not supported')
        self.mode = self._MODES[mode]

        self.device = device

        # Also sets self.mean / self.std as a side effect.
        self.x, self.speakers = self._load_features(feature_path)

        # Bug fix: self.y and self.class2spk were never initialized, so
        # get_num_speaker(), get_speaker() and the first set_labels() call
        # all raised AttributeError. Derive integer class labels (and the
        # label -> speaker-id mapping) from the loaded speaker list here.
        if isinstance(self.speakers, torch.Tensor):
            # Single-file mode: speakers is already a dummy label tensor.
            self.class2spk = {}
            self.y = self.speakers.to(self.device)
        else:
            unique_speakers = sorted(set(self.speakers))
            spk2class = {spk: idx for idx, spk in enumerate(unique_speakers)}
            self.class2spk = {idx: spk for spk, idx in spk2class.items()}
            self.y = torch.tensor(
                [spk2class[spk] for spk in self.speakers]
            ).to(self.device)

    def __len__(self):
        """Number of embeddings in the dataset."""
        return len(self.speakers)

    def __getitem__(self, index):
        """Return the mean/std-normalized embedding at ``index``.

        The second element is an empty placeholder tensor (labels are
        exposed through ``self.y`` rather than through ``__getitem__``).
        """
        embedding = self.normalize_embedding(self.x[index])
        return embedding, torch.zeros([0])

    def normalize_embedding(self, vector):
        """Standardize ``vector`` with the dataset-wide scalar mean/std."""
        return torch.sub(vector, self.mean) / self.std

    def get_speaker(self, label):
        """Map an integer class label back to its speaker-id string.

        Raises:
            KeyError: if ``label`` is unknown (always the case in
                single-file mode, where no speaker metadata exists).
        """
        return self.class2spk[label]

    def get_embedding_dim(self):
        """Dimensionality of a single embedding vector."""
        return self.x.shape[-1]

    def get_num_speaker(self):
        """Number of distinct labels currently assigned."""
        return len(torch.unique(self.y))

    def set_labels(self, labels):
        """Overwrite every label with the scalar value ``labels``.

        The previous label tensor is preserved in ``self.y_old`` so a
        caller can undo or compare assignments.
        """
        self.y_old = self.y
        self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device)

    def _load_features(self, feature_path):
        """Load embeddings (and speaker ids) from ``feature_path``.

        Also computes and stores the global scalar mean/std used by
        ``normalize_embedding``.

        Returns:
            (vectors, speakers): the embedding matrix and per-row speaker
            ids — a dummy zero tensor in single-file mode, a list of
            speaker-id strings in directory mode.
        """
        if os.path.isfile(feature_path):
            vectors = torch.load(feature_path, map_location=self.device)
            # Some extractors dump a python list of vectors rather than a
            # stacked 2-D tensor.
            if isinstance(vectors, list):
                vectors = torch.stack(vectors)

            self.mean = torch.mean(vectors)
            self.std = torch.std(vectors)
            # No speaker metadata available: dummy zero labels per row.
            return vectors, torch.zeros(vectors.size(0))

        # Directory mode. Bug fix: the '/' joins below require a Path, so
        # accept plain strings as well.
        feature_path = Path(feature_path)

        # NOTE(review): torch.load on the directory path itself looks
        # suspicious — presumably the extraction pipeline serializes the
        # tensor under exactly this path; confirm against the caller.
        vectors = torch.load(feature_path, map_location=self.device)

        self.mean = torch.mean(vectors)
        self.std = torch.std(vectors)

        # <mode>2idx maps speaker/utterance id -> row index in ``vectors``.
        raw_map = self._read_two_column_map(feature_path / f'{self.mode}2idx')
        spk2idx = {spk: int(idx) for spk, idx in raw_map.items()}

        speakers, indices = zip(*spk2idx.items())

        # Optionally translate utterance ids into their speaker ids.
        utt2spk_path = feature_path / 'utt2spk'
        if utt2spk_path.exists():
            utt2spk = self._read_two_column_map(utt2spk_path)
            speakers = [utt2spk[utt] for utt in speakers]

        return vectors[np.array(indices)], speakers

    @staticmethod
    def _read_two_column_map(path):
        """Parse a whitespace-separated two-column file into a str->str dict.

        Lines that do not have exactly two fields are skipped, matching
        the original lenient parsing behavior.
        """
        mapping = {}
        with open(path, 'r') as f:
            for line in f:
                split_line = line.strip().split()
                if len(split_line) == 2:
                    mapping[split_line[0].strip()] = split_line[1].strip()
        return mapping

    def _reformat_features(self, features):
        """Reshape a 2-D ``(N, D)`` feature matrix to ``(N, 1, 1, D)``.

        Bug fix: previously any non-2-D input silently fell through and
        returned ``None``; it is now returned unchanged.
        """
        if len(features.shape) == 2:
            return features.reshape(features.shape[0], 1, 1, features.shape[1])
        return features
|
|