Kremon96 committed · verified
Commit fe6ebc3 · 1 Parent(s): cd1bb0a

Delete encoder

encoder/__init__.py DELETED
File without changes
encoder/audio.py DELETED
@@ -1,117 +0,0 @@
1
- from scipy.ndimage import binary_dilation
2
- from encoder.params_data import *
3
- from pathlib import Path
4
- from typing import Optional, Union
5
- from warnings import warn
6
- import numpy as np
7
- import librosa
8
- import struct
9
-
10
- try:
11
- import webrtcvad
12
- except ImportError:
13
- warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
14
- webrtcvad = None
15
-
16
- int16_max = (2 ** 15) - 1
17
-
18
-
19
- def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
20
- source_sr: Optional[int] = None,
21
- normalize: Optional[bool] = True,
22
- trim_silence: Optional[bool] = True):
23
- """
24
- Applies the preprocessing operations used in training the Speaker Encoder to a waveform
25
- either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
26
-
27
- :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
28
- just .wav), either the waveform as a numpy array of floats.
29
- :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
30
- preprocessing. After preprocessing, the waveform's sampling rate will match the data
31
- hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
32
- this argument will be ignored.
33
- """
34
- # Load the wav from disk if needed
35
- if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
36
- wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
37
- else:
38
- wav = fpath_or_wav
39
-
40
- # Resample the wav if needed
41
- if source_sr is not None and source_sr != sampling_rate:
42
- wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate)
43
-
44
- # Apply the preprocessing: normalize volume and shorten long silences
45
- if normalize:
46
- wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
47
- if webrtcvad and trim_silence:
48
- wav = trim_long_silences(wav)
49
-
50
- return wav
51
-
52
-
53
- def wav_to_mel_spectrogram(wav):
54
- """
55
- Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
56
- Note: this not a log-mel spectrogram.
57
- """
58
- frames = librosa.feature.melspectrogram(
59
- y=wav,
60
- sr=sampling_rate,
61
- n_fft=int(sampling_rate * mel_window_length / 1000),
62
- hop_length=int(sampling_rate * mel_window_step / 1000),
63
- n_mels=mel_n_channels
64
- )
65
- return frames.astype(np.float32).T
66
-
67
-
68
- def trim_long_silences(wav):
69
- """
70
- Ensures that segments without voice in the waveform remain no longer than a
71
- threshold determined by the VAD parameters in params.py.
72
-
73
- :param wav: the raw waveform as a numpy array of floats
74
- :return: the same waveform with silences trimmed away (length <= original wav length)
75
- """
76
- # Compute the voice detection window size
77
- samples_per_window = (vad_window_length * sampling_rate) // 1000
78
-
79
- # Trim the end of the audio to have a multiple of the window size
80
- wav = wav[:len(wav) - (len(wav) % samples_per_window)]
81
-
82
- # Convert the float waveform to 16-bit mono PCM
83
- pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
84
-
85
- # Perform voice activation detection
86
- voice_flags = []
87
- vad = webrtcvad.Vad(mode=3)
88
- for window_start in range(0, len(wav), samples_per_window):
89
- window_end = window_start + samples_per_window
90
- voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
91
- sample_rate=sampling_rate))
92
- voice_flags = np.array(voice_flags)
93
-
94
- # Smooth the voice detection with a moving average
95
- def moving_average(array, width):
96
- array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
97
- ret = np.cumsum(array_padded, dtype=float)
98
- ret[width:] = ret[width:] - ret[:-width]
99
- return ret[width - 1:] / width
100
-
101
- audio_mask = moving_average(voice_flags, vad_moving_average_width)
102
- audio_mask = np.round(audio_mask).astype(bool)
103
-
104
- # Dilate the voiced regions
105
- audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
106
- audio_mask = np.repeat(audio_mask, samples_per_window)
107
-
108
- return wav[audio_mask]
109
-
110
-
111
- def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
112
- if increase_only and decrease_only:
113
- raise ValueError("Both increase only and decrease only are set")
114
- dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
115
- if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
116
- return wav
117
- return wav * (10 ** (dBFS_change / 20))
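
For reference, a minimal sketch of how encoder/audio.py was driven before its deletion: preprocess an audio file, then derive the mel frames the speaker encoder consumes. The sample path is a placeholder.

    from pathlib import Path
    from encoder import audio

    # Resample to 16 kHz, normalize volume and trim long silences, then compute mel frames.
    wav = audio.preprocess_wav(Path("samples/speaker1/utt1.flac"))   # hypothetical input file
    frames = audio.wav_to_mel_spectrogram(wav)                       # shape (n_frames, mel_n_channels)
    print(wav.shape, frames.shape)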
 
encoder/config.py DELETED
@@ -1,45 +0,0 @@
1
- librispeech_datasets = {
2
- "train": {
3
- "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
- "other": ["LibriSpeech/train-other-500"]
5
- },
6
- "test": {
7
- "clean": ["LibriSpeech/test-clean"],
8
- "other": ["LibriSpeech/test-other"]
9
- },
10
- "dev": {
11
- "clean": ["LibriSpeech/dev-clean"],
12
- "other": ["LibriSpeech/dev-other"]
13
- },
14
- }
15
- libritts_datasets = {
16
- "train": {
17
- "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
- "other": ["LibriTTS/train-other-500"]
19
- },
20
- "test": {
21
- "clean": ["LibriTTS/test-clean"],
22
- "other": ["LibriTTS/test-other"]
23
- },
24
- "dev": {
25
- "clean": ["LibriTTS/dev-clean"],
26
- "other": ["LibriTTS/dev-other"]
27
- },
28
- }
29
- voxceleb_datasets = {
30
- "voxceleb1" : {
31
- "train": ["VoxCeleb1/wav"],
32
- "test": ["VoxCeleb1/test_wav"]
33
- },
34
- "voxceleb2" : {
35
- "train": ["VoxCeleb2/dev/aac"],
36
- "test": ["VoxCeleb2/test_wav"]
37
- }
38
- }
39
-
40
- other_datasets = [
41
- "LJSpeech-1.1",
42
- "VCTK-Corpus/wav48",
43
- ]
44
-
45
- anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
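
These mappings are relative paths; a small sketch of resolving them against a datasets root (the root used here is an assumption):

    from pathlib import Path
    from encoder.config import librispeech_datasets

    datasets_root = Path("/data")                      # hypothetical folder holding LibriSpeech/, VoxCeleb1/, ...
    for rel in librispeech_datasets["train"]["other"]:
        dataset_dir = datasets_root.joinpath(rel)      # e.g. /data/LibriSpeech/train-other-500
        print(dataset_dir, dataset_dir.exists())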
 
encoder/data_objects/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
- from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
 
 
 
encoder/data_objects/random_cycler.py DELETED
@@ -1,37 +0,0 @@
1
- import random
2
-
3
- class RandomCycler:
4
- """
5
- Creates an internal copy of a sequence and allows access to its items in a constrained random
6
- order. For a source sequence of n items and one or several consecutive queries of a total
7
- of m items, the following guarantees hold (one implies the other):
8
- - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
- - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
- """
11
-
12
- def __init__(self, source):
13
- if len(source) == 0:
14
- raise Exception("Can't create RandomCycler from an empty collection")
15
- self.all_items = list(source)
16
- self.next_items = []
17
-
18
- def sample(self, count: int):
19
- shuffle = lambda l: random.sample(l, len(l))
20
-
21
- out = []
22
- while count > 0:
23
- if count >= len(self.all_items):
24
- out.extend(shuffle(list(self.all_items)))
25
- count -= len(self.all_items)
26
- continue
27
- n = min(count, len(self.next_items))
28
- out.extend(self.next_items[:n])
29
- count -= n
30
- self.next_items = self.next_items[n:]
31
- if len(self.next_items) == 0:
32
- self.next_items = shuffle(list(self.all_items))
33
- return out
34
-
35
- def __next__(self):
36
- return self.sample(1)[0]
37
-
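
A quick sketch of the guarantee stated in the docstring: sampling m = 8 items from a source of n = 4 yields each item exactly m // n = 2 times, in a shuffled order.

    from collections import Counter
    from encoder.data_objects.random_cycler import RandomCycler

    cycler = RandomCycler(["a", "b", "c", "d"])
    draws = cycler.sample(8)          # one query of m = 8 over n = 4 items
    print(draws, Counter(draws))      # every item appears exactly twice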
 
encoder/data_objects/speaker.py DELETED
@@ -1,40 +0,0 @@
1
- from encoder.data_objects.random_cycler import RandomCycler
2
- from encoder.data_objects.utterance import Utterance
3
- from pathlib import Path
4
-
5
- # Contains the set of utterances of a single speaker
6
- class Speaker:
7
- def __init__(self, root: Path):
8
- self.root = root
9
- self.name = root.name
10
- self.utterances = None
11
- self.utterance_cycler = None
12
-
13
- def _load_utterances(self):
14
- with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
- sources = [l.split(",") for l in sources_file]
16
- sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
- self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
- self.utterance_cycler = RandomCycler(self.utterances)
19
-
20
- def random_partial(self, count, n_frames):
21
- """
22
- Samples a batch of <count> unique partial utterances from the disk in a way that all
23
- utterances come up at least once every two cycles and in a random order every time.
24
-
25
- :param count: The number of partial utterances to sample from the set of utterances from
26
- that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
- the number of utterances available.
28
- :param n_frames: The number of frames in the partial utterance.
29
- :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
- frames are the frames of the partial utterances and range is the range of the partial
31
- utterance with regard to the complete utterance.
32
- """
33
- if self.utterances is None:
34
- self._load_utterances()
35
-
36
- utterances = self.utterance_cycler.sample(count)
37
-
38
- a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
-
40
- return a
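
Each preprocessed speaker directory holds mel-frame .npy files plus a _sources.txt whose lines read "<frames_fname>,<original audio path>". A sketch of sampling partial utterances from one such directory (the path is hypothetical):

    from pathlib import Path
    from encoder.data_objects.speaker import Speaker

    speaker = Speaker(Path("SV2TTS/encoder/LibriSpeech_train-other-500_25"))   # hypothetical preprocessed dir
    for utterance, frames, (start, end) in speaker.random_partial(count=4, n_frames=160):
        print(utterance.frames_fpath.name, frames.shape, (start, end))         # frames: (160, 40)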
 
encoder/data_objects/speaker_batch.py DELETED
@@ -1,13 +0,0 @@
1
- import numpy as np
2
- from typing import List
3
- from encoder.data_objects.speaker import Speaker
4
-
5
-
6
- class SpeakerBatch:
7
- def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
8
- self.speakers = speakers
9
- self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
10
-
11
- # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
12
- # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
13
- self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
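
A sketch confirming the shape arithmetic from the comment above (speaker directories are placeholders):

    from pathlib import Path
    from encoder.data_objects.speaker import Speaker
    from encoder.data_objects.speaker_batch import SpeakerBatch

    speakers = [Speaker(Path("SV2TTS/encoder") / name) for name in ("spk_a", "spk_b", "spk_c")]
    batch = SpeakerBatch(speakers, utterances_per_speaker=4, n_frames=160)
    print(batch.data.shape)    # (3 * 4, 160, 40) == (12, 160, 40)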
 
encoder/data_objects/speaker_verification_dataset.py DELETED
@@ -1,56 +0,0 @@
1
- from encoder.data_objects.random_cycler import RandomCycler
2
- from encoder.data_objects.speaker_batch import SpeakerBatch
3
- from encoder.data_objects.speaker import Speaker
4
- from encoder.params_data import partials_n_frames
5
- from torch.utils.data import Dataset, DataLoader
6
- from pathlib import Path
7
-
8
- # TODO: improve with a pool of speakers for data efficiency
9
-
10
- class SpeakerVerificationDataset(Dataset):
11
- def __init__(self, datasets_root: Path):
12
- self.root = datasets_root
13
- speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
- if len(speaker_dirs) == 0:
15
- raise Exception("No speakers found. Make sure you are pointing to the directory "
16
- "containing all preprocessed speaker directories.")
17
- self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
- self.speaker_cycler = RandomCycler(self.speakers)
19
-
20
- def __len__(self):
21
- return int(1e10)
22
-
23
- def __getitem__(self, index):
24
- return next(self.speaker_cycler)
25
-
26
- def get_logs(self):
27
- log_string = ""
28
- for log_fpath in self.root.glob("*.txt"):
29
- with log_fpath.open("r") as log_file:
30
- log_string += "".join(log_file.readlines())
31
- return log_string
32
-
33
-
34
- class SpeakerVerificationDataLoader(DataLoader):
35
- def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
- batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
- worker_init_fn=None):
38
- self.utterances_per_speaker = utterances_per_speaker
39
-
40
- super().__init__(
41
- dataset=dataset,
42
- batch_size=speakers_per_batch,
43
- shuffle=False,
44
- sampler=sampler,
45
- batch_sampler=batch_sampler,
46
- num_workers=num_workers,
47
- collate_fn=self.collate,
48
- pin_memory=pin_memory,
49
- drop_last=False,
50
- timeout=timeout,
51
- worker_init_fn=worker_init_fn
52
- )
53
-
54
- def collate(self, speakers):
55
- return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
-
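
Putting dataset and loader together, roughly as encoder/train.py below does (the preprocessed root is a placeholder):

    from pathlib import Path
    from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

    dataset = SpeakerVerificationDataset(Path("SV2TTS/encoder"))   # hypothetical output of encoder/preprocess.py
    loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=64,
                                           utterances_per_speaker=10, num_workers=4)
    speaker_batch = next(iter(loader))
    print(speaker_batch.data.shape)    # (64 * 10, 160, 40)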
 
encoder/data_objects/utterance.py DELETED
@@ -1,26 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- class Utterance:
5
- def __init__(self, frames_fpath, wave_fpath):
6
- self.frames_fpath = frames_fpath
7
- self.wave_fpath = wave_fpath
8
-
9
- def get_frames(self):
10
- return np.load(self.frames_fpath)
11
-
12
- def random_partial(self, n_frames):
13
- """
14
- Crops the frames into a partial utterance of n_frames
15
-
16
- :param n_frames: The number of frames of the partial utterance
17
- :return: the partial utterance frames and a tuple indicating the start and end of the
18
- partial utterance in the complete utterance.
19
- """
20
- frames = self.get_frames()
21
- if frames.shape[0] == n_frames:
22
- start = 0
23
- else:
24
- start = np.random.randint(0, frames.shape[0] - n_frames)
25
- end = start + n_frames
26
- return frames[start:end], (start, end)
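
A short sketch of the cropping behaviour, using a toy array in place of a saved mel spectrogram (file names are hypothetical):

    import numpy as np
    from encoder.data_objects.utterance import Utterance

    np.save("toy_frames.npy", np.zeros((300, 40), dtype=np.float32))    # stand-in 300-frame utterance
    utt = Utterance(frames_fpath="toy_frames.npy", wave_fpath="toy.wav")
    frames, (start, end) = utt.random_partial(n_frames=160)
    print(frames.shape, start, end)    # (160, 40), with 0 <= start < 140 and end = start + 160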
 
encoder/inference.py DELETED
@@ -1,178 +0,0 @@
1
- from encoder.params_data import *
2
- from encoder.model import SpeakerEncoder
3
- from encoder.audio import preprocess_wav # We want to expose this function from here
4
- from matplotlib import cm
5
- from encoder import audio
6
- from pathlib import Path
7
- import numpy as np
8
- import torch
9
-
10
- _model = None # type: SpeakerEncoder
11
- _device = None # type: torch.device
12
-
13
-
14
- def load_model(weights_fpath: Path, device=None):
15
- """
16
- Loads the model in memory. If this function is not explicitly called, it will be run on the
17
- first call to embed_frames() with the default weights file.
18
-
19
- :param weights_fpath: the path to saved model weights.
20
- :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
21
- model will be loaded and will run on this device. Outputs will however always be on the cpu.
22
- If None, will default to your GPU if it's available, otherwise your CPU.
23
- """
24
- # TODO: I think the slow loading of the encoder might have something to do with the device it
25
- # was saved on. Worth investigating.
26
- global _model, _device
27
- if device is None:
28
- _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- elif isinstance(device, str):
30
- _device = torch.device(device)
31
- _model = SpeakerEncoder(_device, torch.device("cpu"))
32
- checkpoint = torch.load(weights_fpath, _device)
33
- _model.load_state_dict(checkpoint["model_state"])
34
- _model.eval()
35
- print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
36
-
37
-
38
- def is_loaded():
39
- return _model is not None
40
-
41
-
42
- def embed_frames_batch(frames_batch):
43
- """
44
- Computes embeddings for a batch of mel spectrograms.
45
-
46
- :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
47
- (batch_size, n_frames, n_channels)
48
- :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
49
- """
50
- if _model is None:
51
- raise Exception("Model was not loaded. Call load_model() before inference.")
52
-
53
- frames = torch.from_numpy(frames_batch).to(_device)
54
- embed = _model.forward(frames).detach().cpu().numpy()
55
- return embed
56
-
57
-
58
- def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
59
- min_pad_coverage=0.75, overlap=0.5):
60
- """
61
- Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
62
- partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
63
- spectrogram slices are returned, so as to make each partial utterance waveform correspond to
64
- its spectrogram. This function assumes that the mel spectrogram parameters used are those
65
- defined in params_data.py.
66
-
67
- The returned ranges may be indexing further than the length of the waveform. It is
68
- recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
69
-
70
- :param n_samples: the number of samples in the waveform
71
- :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
72
- utterance
73
- :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
74
- enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
75
- then the last partial utterance will be considered, as if we padded the audio. Otherwise,
76
- it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
77
- utterance, this parameter is ignored so that the function always returns at least 1 slice.
78
- :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
79
- utterances are entirely disjoint.
80
- :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
81
- respectively the waveform and the mel spectrogram with these slices to obtain the partial
82
- utterances.
83
- """
84
- assert 0 <= overlap < 1
85
- assert 0 < min_pad_coverage <= 1
86
-
87
- samples_per_frame = int((sampling_rate * mel_window_step / 1000))
88
- n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
89
- frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
90
-
91
- # Compute the slices
92
- wav_slices, mel_slices = [], []
93
- steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
94
- for i in range(0, steps, frame_step):
95
- mel_range = np.array([i, i + partial_utterance_n_frames])
96
- wav_range = mel_range * samples_per_frame
97
- mel_slices.append(slice(*mel_range))
98
- wav_slices.append(slice(*wav_range))
99
-
100
- # Evaluate whether extra padding is warranted or not
101
- last_wav_range = wav_slices[-1]
102
- coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
103
- if coverage < min_pad_coverage and len(mel_slices) > 1:
104
- mel_slices = mel_slices[:-1]
105
- wav_slices = wav_slices[:-1]
106
-
107
- return wav_slices, mel_slices
108
-
109
-
110
- def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
111
- """
112
- Computes an embedding for a single utterance.
113
-
114
- # TODO: handle multiple wavs to benefit from batching on GPU
115
- :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
116
- :param using_partials: if True, then the utterance is split in partial utterances of
117
- <partial_utterance_n_frames> frames and the utterance embedding is computed from their
118
- normalized average. If False, the utterance is instead computed from feeding the entire
119
- spectrogram to the network.
120
- :param return_partials: if True, the partial embeddings will also be returned along with the
121
- wav slices that correspond to the partial embeddings.
122
- :param kwargs: additional arguments to compute_partial_slices()
123
- :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
124
- <return_partials> is True, the partial utterances as a numpy array of float32 of shape
125
- (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
126
- returned. If <using_partials> is simultaneously set to False, both these values will be None
127
- instead.
128
- """
129
- # Process the entire utterance if not using partials
130
- if not using_partials:
131
- frames = audio.wav_to_mel_spectrogram(wav)
132
- embed = embed_frames_batch(frames[None, ...])[0]
133
- if return_partials:
134
- return embed, None, None
135
- return embed
136
-
137
- # Compute where to split the utterance into partials and pad if necessary
138
- wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
139
- max_wave_length = wave_slices[-1].stop
140
- if max_wave_length >= len(wav):
141
- wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
142
-
143
- # Split the utterance into partials
144
- frames = audio.wav_to_mel_spectrogram(wav)
145
- frames_batch = np.array([frames[s] for s in mel_slices])
146
- partial_embeds = embed_frames_batch(frames_batch)
147
-
148
- # Compute the utterance embedding from the partial embeddings
149
- raw_embed = np.mean(partial_embeds, axis=0)
150
- embed = raw_embed / np.linalg.norm(raw_embed, 2)
151
-
152
- if return_partials:
153
- return embed, partial_embeds, wave_slices
154
- return embed
155
-
156
-
157
- def embed_speaker(wavs, **kwargs):
158
- raise NotImplementedError()
159
-
160
-
161
- def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
162
- import matplotlib.pyplot as plt
163
- if ax is None:
164
- ax = plt.gca()
165
-
166
- if shape is None:
167
- height = int(np.sqrt(len(embed)))
168
- shape = (height, -1)
169
- embed = embed.reshape(shape)
170
-
171
- cmap = cm.get_cmap()
172
- mappable = ax.imshow(embed, cmap=cmap)
173
- cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
- sm = cm.ScalarMappable(cmap=cmap)
175
- sm.set_clim(*color_range)
176
-
177
- ax.set_xticks([]), ax.set_yticks([])
178
- ax.set_title(title)
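
End-to-end usage as it stood before this commit: load the weights, preprocess a file, embed it. The checkpoint path is a placeholder.

    from pathlib import Path
    import numpy as np
    from encoder import inference as encoder

    encoder.load_model(Path("saved_models/default/encoder.pt"))    # hypothetical checkpoint location
    wav = encoder.preprocess_wav("samples/speaker1/utt1.flac")     # re-exported from encoder.audio
    embed = encoder.embed_utterance(wav)
    print(embed.shape, np.linalg.norm(embed))                      # (256,), unit L2 norm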
 
encoder/model.py DELETED
@@ -1,135 +0,0 @@
1
- from encoder.params_model import *
2
- from encoder.params_data import *
3
- from scipy.interpolate import interp1d
4
- from sklearn.metrics import roc_curve
5
- from torch.nn.utils import clip_grad_norm_
6
- from scipy.optimize import brentq
7
- from torch import nn
8
- import numpy as np
9
- import torch
10
-
11
-
12
- class SpeakerEncoder(nn.Module):
13
- def __init__(self, device, loss_device):
14
- super().__init__()
15
- self.loss_device = loss_device
16
-
17
- # Network definition
18
- self.lstm = nn.LSTM(input_size=mel_n_channels,
19
- hidden_size=model_hidden_size,
20
- num_layers=model_num_layers,
21
- batch_first=True).to(device)
22
- self.linear = nn.Linear(in_features=model_hidden_size,
23
- out_features=model_embedding_size).to(device)
24
- self.relu = torch.nn.ReLU().to(device)
25
-
26
- # Cosine similarity scaling (with fixed initial parameter values)
27
- self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
- self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
-
30
- # Loss
31
- self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
-
33
- def do_gradient_ops(self):
34
- # Gradient scale
35
- self.similarity_weight.grad *= 0.01
36
- self.similarity_bias.grad *= 0.01
37
-
38
- # Gradient clipping
39
- clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
-
41
- def forward(self, utterances, hidden_init=None):
42
- """
43
- Computes the embeddings of a batch of utterance spectrograms.
44
-
45
- :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
- (batch_size, n_frames, n_channels)
47
- :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
- batch_size, hidden_size). Will default to a tensor of zeros if None.
49
- :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
- """
51
- # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
- # and the final cell state.
53
- out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
-
55
- # We take only the hidden state of the last layer
56
- embeds_raw = self.relu(self.linear(hidden[-1]))
57
-
58
- # L2-normalize it
59
- embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
-
61
- return embeds
62
-
63
- def similarity_matrix(self, embeds):
64
- """
65
- Computes the similarity matrix according to section 2.1 of GE2E.
66
-
67
- :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
- utterances_per_speaker, embedding_size)
69
- :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
- utterances_per_speaker, speakers_per_batch)
71
- """
72
- speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
-
74
- # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
- centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
- centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
-
78
- # Exclusive centroids (1 per utterance)
79
- centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
- centroids_excl /= (utterances_per_speaker - 1)
81
- centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
-
83
- # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
- # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
- # We vectorize the computation for efficiency.
86
- sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
- speakers_per_batch).to(self.loss_device)
88
- mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
89
- for j in range(speakers_per_batch):
90
- mask = np.where(mask_matrix[j])[0]
91
- sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
- sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
-
94
- ## Even more vectorized version (slower maybe because of transpose)
95
- # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
- # ).to(self.loss_device)
97
- # eye = np.eye(speakers_per_batch, dtype=np.int)
98
- # mask = np.where(1 - eye)
99
- # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
- # mask = np.where(eye)
101
- # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
- # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
-
104
- sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
- return sim_matrix
106
-
107
- def loss(self, embeds):
108
- """
109
- Computes the softmax loss according to section 2.1 of GE2E.
110
-
111
- :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
- utterances_per_speaker, embedding_size)
113
- :return: the loss and the EER for this batch of embeddings.
114
- """
115
- speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
-
117
- # Loss
118
- sim_matrix = self.similarity_matrix(embeds)
119
- sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
- speakers_per_batch))
121
- ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
- target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
- loss = self.loss_fn(sim_matrix, target)
124
-
125
- # EER (not backpropagated)
126
- with torch.no_grad():
127
- inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
128
- labels = np.array([inv_argmax(i) for i in ground_truth])
129
- preds = sim_matrix.detach().cpu().numpy()
130
-
131
- # Snippet from https://yangcha.github.io/EER-ROC/
132
- fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
- eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
-
135
- return loss, eer
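
A sketch of the shapes involved in the GE2E similarity and loss, using random unit-norm embeddings in place of real LSTM outputs (4 speakers x 5 utterances is an arbitrary choice):

    import torch
    from encoder.model import SpeakerEncoder

    device = torch.device("cpu")
    model = SpeakerEncoder(device, loss_device=device)

    # Fake batch of embeddings: (speakers_per_batch, utterances_per_speaker, embedding_size).
    embeds = torch.nn.functional.normalize(torch.randn(4, 5, 256), dim=2)
    sim = model.similarity_matrix(embeds)    # (4, 5, 4): every utterance vs. every speaker centroid
    loss, eer = model.loss(embeds)
    print(sim.shape, float(loss), eer)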
 
encoder/params_data.py DELETED
@@ -1,29 +0,0 @@
1
-
2
- ## Mel-filterbank
3
- mel_window_length = 25 # In milliseconds
4
- mel_window_step = 10 # In milliseconds
5
- mel_n_channels = 40
6
-
7
-
8
- ## Audio
9
- sampling_rate = 16000
10
- # Number of spectrogram frames in a partial utterance
11
- partials_n_frames = 160 # 1600 ms
12
- # Number of spectrogram frames at inference
13
- inference_n_frames = 80 # 800 ms
14
-
15
-
16
- ## Voice Activation Detection
17
- # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
- # This sets the granularity of the VAD. Should not need to be changed.
19
- vad_window_length = 30 # In milliseconds
20
- # Number of frames to average together when performing the moving average smoothing.
21
- # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
- vad_moving_average_width = 8
23
- # Maximum number of consecutive silent frames a segment can have.
24
- vad_max_silence_length = 6
25
-
26
-
27
- ## Audio volume normalization
28
- audio_norm_target_dBFS = -30
29
-
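
In samples, these settings work out as follows (plain arithmetic over the values above):

    sampling_rate = 16000
    mel_window_length, mel_window_step = 25, 10            # milliseconds
    n_fft = sampling_rate * mel_window_length // 1000      # 400 samples per analysis window
    hop_length = sampling_rate * mel_window_step // 1000   # 160 samples between frames
    partials_n_frames = 160
    print(n_fft, hop_length, partials_n_frames * hop_length / sampling_rate)   # 400 160 1.6 -> ~1.6 s per partial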
 
encoder/params_model.py DELETED
@@ -1,11 +0,0 @@
1
-
2
- ## Model parameters
3
- model_hidden_size = 256
4
- model_embedding_size = 256
5
- model_num_layers = 3
6
-
7
-
8
- ## Training parameters
9
- learning_rate_init = 1e-4
10
- speakers_per_batch = 64
11
- utterances_per_speaker = 10
 
encoder/preprocess.py DELETED
@@ -1,184 +0,0 @@
1
- from datetime import datetime
2
- from functools import partial
3
- from multiprocessing import Pool
4
- from pathlib import Path
5
-
6
- import numpy as np
7
- from tqdm import tqdm
8
-
9
- from encoder import audio
10
- from encoder.config import librispeech_datasets, anglophone_nationalites
11
- from encoder.params_data import *
12
-
13
-
14
- _AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
15
-
16
- class DatasetLog:
17
- """
18
- Registers metadata about the dataset in a text file.
19
- """
20
- def __init__(self, root, name):
21
- self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
22
- self.sample_data = dict()
23
-
24
- start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
25
- self.write_line("Creating dataset %s on %s" % (name, start_time))
26
- self.write_line("-----")
27
- self._log_params()
28
-
29
- def _log_params(self):
30
- from encoder import params_data
31
- self.write_line("Parameter values:")
32
- for param_name in (p for p in dir(params_data) if not p.startswith("__")):
33
- value = getattr(params_data, param_name)
34
- self.write_line("\t%s: %s" % (param_name, value))
35
- self.write_line("-----")
36
-
37
- def write_line(self, line):
38
- self.text_file.write("%s\n" % line)
39
-
40
- def add_sample(self, **kwargs):
41
- for param_name, value in kwargs.items():
42
- if not param_name in self.sample_data:
43
- self.sample_data[param_name] = []
44
- self.sample_data[param_name].append(value)
45
-
46
- def finalize(self):
47
- self.write_line("Statistics:")
48
- for param_name, values in self.sample_data.items():
49
- self.write_line("\t%s:" % param_name)
50
- self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
51
- self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
52
- self.write_line("-----")
53
- end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
54
- self.write_line("Finished on %s" % end_time)
55
- self.text_file.close()
56
-
57
-
58
- def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
59
- dataset_root = datasets_root.joinpath(dataset_name)
60
- if not dataset_root.exists():
61
- print("Couldn\'t find %s, skipping this dataset." % dataset_root)
62
- return None, None
63
- return dataset_root, DatasetLog(out_dir, dataset_name)
64
-
65
-
66
- def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
67
- # Give a name to the speaker that includes its dataset
68
- speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
-
70
- # Create an output directory with that name, as well as a txt file containing a
71
- # reference to each source file.
72
- speaker_out_dir = out_dir.joinpath(speaker_name)
73
- speaker_out_dir.mkdir(exist_ok=True)
74
- sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
-
76
- # There's a possibility that the preprocessing was interrupted earlier, check if
77
- # there already is a sources file.
78
- if sources_fpath.exists():
79
- try:
80
- with sources_fpath.open("r") as sources_file:
81
- existing_fnames = {line.split(",")[0] for line in sources_file}
82
- except Exception:
83
- existing_fnames = {}
84
- else:
85
- existing_fnames = {}
86
-
87
- # Gather all audio files for that speaker recursively
88
- sources_file = sources_fpath.open("a" if skip_existing else "w")
89
- audio_durs = []
90
- for extension in _AUDIO_EXTENSIONS:
91
- for in_fpath in speaker_dir.glob("**/*.%s" % extension):
92
- # Check if the target output file already exists
93
- out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
94
- out_fname = out_fname.replace(".%s" % extension, ".npy")
95
- if skip_existing and out_fname in existing_fnames:
96
- continue
97
-
98
- # Load and preprocess the waveform
99
- wav = audio.preprocess_wav(in_fpath)
100
- if len(wav) == 0:
101
- continue
102
-
103
- # Create the mel spectrogram, discard those that are too short
104
- frames = audio.wav_to_mel_spectrogram(wav)
105
- if len(frames) < partials_n_frames:
106
- continue
107
-
108
- out_fpath = speaker_out_dir.joinpath(out_fname)
109
- np.save(out_fpath, frames)
110
- sources_file.write("%s,%s\n" % (out_fname, in_fpath))
111
- audio_durs.append(len(wav) / sampling_rate)
112
-
113
- sources_file.close()
114
-
115
- return audio_durs
116
-
117
-
118
- def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
119
- print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
120
-
121
- # Process the utterances for each speaker
122
- work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
123
- with Pool(4) as pool:
124
- tasks = pool.imap(work_fn, speaker_dirs)
125
- for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
126
- for sample_dur in sample_durs:
127
- logger.add_sample(duration=sample_dur)
128
-
129
- logger.finalize()
130
- print("Done preprocessing %s.\n" % dataset_name)
131
-
132
-
133
- def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
134
- for dataset_name in librispeech_datasets["train"]["other"]:
135
- # Initialize the preprocessing
136
- dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
137
- if not dataset_root:
138
- return
139
-
140
- # Preprocess all speakers
141
- speaker_dirs = list(dataset_root.glob("*"))
142
- _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
143
-
144
-
145
- def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
146
- # Initialize the preprocessing
147
- dataset_name = "VoxCeleb1"
148
- dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
149
- if not dataset_root:
150
- return
151
-
152
- # Get the contents of the meta file
153
- with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
154
- metadata = [line.split("\t") for line in metafile][1:]
155
-
156
- # Select the ID and the nationality, filter out non-anglophone speakers
157
- nationalities = {line[0]: line[3] for line in metadata}
158
- keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
159
- nationality.lower() in anglophone_nationalites]
160
- print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
161
- (len(keep_speaker_ids), len(nationalities)))
162
-
163
- # Get the speaker directories for anglophone speakers only
164
- speaker_dirs = dataset_root.joinpath("wav").glob("*")
165
- speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
166
- speaker_dir.name in keep_speaker_ids]
167
- print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
168
- (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
169
-
170
- # Preprocess all speakers
171
- _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
172
-
173
-
174
- def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
175
- # Initialize the preprocessing
176
- dataset_name = "VoxCeleb2"
177
- dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
178
- if not dataset_root:
179
- return
180
-
181
- # Get the speaker directories
182
- # Preprocess all speakers
183
- speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
184
- _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
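
These entry points were normally invoked through the repository's preprocessing script; a minimal sketch of calling them directly (all paths are placeholders):

    from pathlib import Path
    from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1

    datasets_root = Path("/data")                     # hypothetical folder containing LibriSpeech/, VoxCeleb1/, ...
    out_dir = datasets_root / "SV2TTS" / "encoder"    # hypothetical output location for the mel frames
    out_dir.mkdir(parents=True, exist_ok=True)

    preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
    preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)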
 
encoder/train.py DELETED
@@ -1,125 +0,0 @@
1
- from pathlib import Path
2
-
3
- import torch
4
-
5
- from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
6
- from encoder.model import SpeakerEncoder
7
- from encoder.params_model import *
8
- from encoder.visualizations import Visualizations
9
- from utils.profiler import Profiler
10
-
11
-
12
- def sync(device: torch.device):
13
- # For correct profiling (cuda operations are async)
14
- if device.type == "cuda":
15
- torch.cuda.synchronize(device)
16
-
17
-
18
- def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
19
- backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
20
- no_visdom: bool):
21
- # Create a dataset and a dataloader
22
- dataset = SpeakerVerificationDataset(clean_data_root)
23
- loader = SpeakerVerificationDataLoader(
24
- dataset,
25
- speakers_per_batch,
26
- utterances_per_speaker,
27
- num_workers=4,
28
- )
29
-
30
- # Setup the device on which to run the forward pass and the loss. These can be different,
31
- # because the forward pass is faster on the GPU whereas the loss is often (depending on your
32
- # hyperparameters) faster on the CPU.
33
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- # FIXME: currently, the gradient is None if loss_device is cuda
35
- loss_device = torch.device("cpu")
36
-
37
- # Create the model and the optimizer
38
- model = SpeakerEncoder(device, loss_device)
39
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
40
- init_step = 1
41
-
42
- # Configure file path for the model
43
- model_dir = models_dir / run_id
44
- model_dir.mkdir(exist_ok=True, parents=True)
45
- state_fpath = model_dir / "encoder.pt"
46
-
47
- # Load any existing model
48
- if not force_restart:
49
- if state_fpath.exists():
50
- print("Found existing model \"%s\", loading it and resuming training." % run_id)
51
- checkpoint = torch.load(state_fpath)
52
- init_step = checkpoint["step"]
53
- model.load_state_dict(checkpoint["model_state"])
54
- optimizer.load_state_dict(checkpoint["optimizer_state"])
55
- optimizer.param_groups[0]["lr"] = learning_rate_init
56
- else:
57
- print("No model \"%s\" found, starting training from scratch." % run_id)
58
- else:
59
- print("Starting the training from scratch.")
60
- model.train()
61
-
62
- # Initialize the visualization environment
63
- vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
64
- vis.log_dataset(dataset)
65
- vis.log_params()
66
- device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
67
- vis.log_implementation({"Device": device_name})
68
-
69
- # Training loop
70
- profiler = Profiler(summarize_every=10, disabled=False)
71
- for step, speaker_batch in enumerate(loader, init_step):
72
- profiler.tick("Blocking, waiting for batch (threaded)")
73
-
74
- # Forward pass
75
- inputs = torch.from_numpy(speaker_batch.data).to(device)
76
- sync(device)
77
- profiler.tick("Data to %s" % device)
78
- embeds = model(inputs)
79
- sync(device)
80
- profiler.tick("Forward pass")
81
- embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
82
- loss, eer = model.loss(embeds_loss)
83
- sync(loss_device)
84
- profiler.tick("Loss")
85
-
86
- # Backward pass
87
- model.zero_grad()
88
- loss.backward()
89
- profiler.tick("Backward pass")
90
- model.do_gradient_ops()
91
- optimizer.step()
92
- profiler.tick("Parameter update")
93
-
94
- # Update visualizations
95
- # learning_rate = optimizer.param_groups[0]["lr"]
96
- vis.update(loss.item(), eer, step)
97
-
98
- # Draw projections and save them to the backup folder
99
- if umap_every != 0 and step % umap_every == 0:
100
- print("Drawing and saving projections (step %d)" % step)
101
- projection_fpath = model_dir / f"umap_{step:06d}.png"
102
- embeds = embeds.detach().cpu().numpy()
103
- vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
104
- vis.save()
105
-
106
- # Overwrite the latest version of the model
107
- if save_every != 0 and step % save_every == 0:
108
- print("Saving the model (step %d)" % step)
109
- torch.save({
110
- "step": step + 1,
111
- "model_state": model.state_dict(),
112
- "optimizer_state": optimizer.state_dict(),
113
- }, state_fpath)
114
-
115
- # Make a backup
116
- if backup_every != 0 and step % backup_every == 0:
117
- print("Making a backup (step %d)" % step)
118
- backup_fpath = model_dir / f"encoder_{step:06d}.bak"
119
- torch.save({
120
- "step": step + 1,
121
- "model_state": model.state_dict(),
122
- "optimizer_state": optimizer.state_dict(),
123
- }, backup_fpath)
124
-
125
- profiler.tick("Extras (visualizations, saving)")
 
encoder/visualizations.py DELETED
@@ -1,179 +0,0 @@
1
- from datetime import datetime
2
- from time import perf_counter as timer
3
-
4
- import numpy as np
5
- import umap
6
- import visdom
7
-
8
- from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
9
-
10
-
11
- colormap = np.array([
12
- [76, 255, 0],
13
- [0, 127, 70],
14
- [255, 0, 0],
15
- [255, 217, 38],
16
- [0, 135, 255],
17
- [165, 0, 165],
18
- [255, 167, 255],
19
- [0, 255, 255],
20
- [255, 96, 38],
21
- [142, 76, 0],
22
- [33, 0, 127],
23
- [0, 0, 0],
24
- [183, 183, 183],
25
- ], dtype=float) / 255
26
-
27
-
28
- class Visualizations:
29
- def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
30
- # Tracking data
31
- self.last_update_timestamp = timer()
32
- self.update_every = update_every
33
- self.step_times = []
34
- self.losses = []
35
- self.eers = []
36
- print("Updating the visualizations every %d steps." % update_every)
37
-
38
- # If visdom is disabled TODO: use a better paradigm for that
39
- self.disabled = disabled
40
- if self.disabled:
41
- return
42
-
43
- # Set the environment name
44
- now = str(datetime.now().strftime("%d-%m %Hh%M"))
45
- if env_name is None:
46
- self.env_name = now
47
- else:
48
- self.env_name = "%s (%s)" % (env_name, now)
49
-
50
- # Connect to visdom and open the corresponding window in the browser
51
- try:
52
- self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
53
- except ConnectionError:
54
- raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
55
- "start it.")
56
- # webbrowser.open("http://localhost:8097/env/" + self.env_name)
57
-
58
- # Create the windows
59
- self.loss_win = None
60
- self.eer_win = None
61
- # self.lr_win = None
62
- self.implementation_win = None
63
- self.projection_win = None
64
- self.implementation_string = ""
65
-
66
- def log_params(self):
67
- if self.disabled:
68
- return
69
- from encoder import params_data
70
- from encoder import params_model
71
- param_string = "<b>Model parameters</b>:<br>"
72
- for param_name in (p for p in dir(params_model) if not p.startswith("__")):
73
- value = getattr(params_model, param_name)
74
- param_string += "\t%s: %s<br>" % (param_name, value)
75
- param_string += "<b>Data parameters</b>:<br>"
76
- for param_name in (p for p in dir(params_data) if not p.startswith("__")):
77
- value = getattr(params_data, param_name)
78
- param_string += "\t%s: %s<br>" % (param_name, value)
79
- self.vis.text(param_string, opts={"title": "Parameters"})
80
-
81
- def log_dataset(self, dataset: SpeakerVerificationDataset):
82
- if self.disabled:
83
- return
84
- dataset_string = ""
85
- dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
86
- dataset_string += "\n" + dataset.get_logs()
87
- dataset_string = dataset_string.replace("\n", "<br>")
88
- self.vis.text(dataset_string, opts={"title": "Dataset"})
89
-
90
- def log_implementation(self, params):
91
- if self.disabled:
92
- return
93
- implementation_string = ""
94
- for param, value in params.items():
95
- implementation_string += "<b>%s</b>: %s\n" % (param, value)
96
- implementation_string = implementation_string.replace("\n", "<br>")
97
- self.implementation_string = implementation_string
98
- self.implementation_win = self.vis.text(
99
- implementation_string,
100
- opts={"title": "Training implementation"}
101
- )
102
-
103
- def update(self, loss, eer, step):
104
- # Update the tracking data
105
- now = timer()
106
- self.step_times.append(1000 * (now - self.last_update_timestamp))
107
- self.last_update_timestamp = now
108
- self.losses.append(loss)
109
- self.eers.append(eer)
110
- print(".", end="")
111
-
112
- # Update the plots every <update_every> steps
113
- if step % self.update_every != 0:
114
- return
115
- time_string = "Step time: mean: %5dms std: %5dms" % \
116
- (int(np.mean(self.step_times)), int(np.std(self.step_times)))
117
- print("\nStep %6d Loss: %.4f EER: %.4f %s" %
118
- (step, np.mean(self.losses), np.mean(self.eers), time_string))
119
- if not self.disabled:
120
- self.loss_win = self.vis.line(
121
- [np.mean(self.losses)],
122
- [step],
123
- win=self.loss_win,
124
- update="append" if self.loss_win else None,
125
- opts=dict(
126
- legend=["Avg. loss"],
127
- xlabel="Step",
128
- ylabel="Loss",
129
- title="Loss",
130
- )
131
- )
132
- self.eer_win = self.vis.line(
133
- [np.mean(self.eers)],
134
- [step],
135
- win=self.eer_win,
136
- update="append" if self.eer_win else None,
137
- opts=dict(
138
- legend=["Avg. EER"],
139
- xlabel="Step",
140
- ylabel="EER",
141
- title="Equal error rate"
142
- )
143
- )
144
- if self.implementation_win is not None:
145
- self.vis.text(
146
- self.implementation_string + ("<b>%s</b>" % time_string),
147
- win=self.implementation_win,
148
- opts={"title": "Training implementation"},
149
- )
150
-
151
- # Reset the tracking
152
- self.losses.clear()
153
- self.eers.clear()
154
- self.step_times.clear()
155
-
156
- def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
157
- import matplotlib.pyplot as plt
158
-
159
- max_speakers = min(max_speakers, len(colormap))
160
- embeds = embeds[:max_speakers * utterances_per_speaker]
161
-
162
- n_speakers = len(embeds) // utterances_per_speaker
163
- ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
164
- colors = [colormap[i] for i in ground_truth]
165
-
166
- reducer = umap.UMAP()
167
- projected = reducer.fit_transform(embeds)
168
- plt.scatter(projected[:, 0], projected[:, 1], c=colors)
169
- plt.gca().set_aspect("equal", "datalim")
170
- plt.title("UMAP projection (step %d)" % step)
171
- if not self.disabled:
172
- self.projection_win = self.vis.matplot(plt, win=self.projection_win)
173
- if out_fpath is not None:
174
- plt.savefig(out_fpath)
175
- plt.clf()
176
-
177
- def save(self):
178
- if not self.disabled:
179
- self.vis.save([self.env_name])
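
draw_projections also works without a running visdom server when the class is constructed with disabled=True; a sketch with random embeddings standing in for model outputs (requires the umap-learn and matplotlib packages):

    import numpy as np
    from encoder.visualizations import Visualizations

    vis = Visualizations(env_name="offline", disabled=True)      # skips the visdom connection entirely
    embeds = np.random.rand(5 * 10, 256).astype(np.float32)      # 5 speakers x 10 utterances, 256-dim
    vis.draw_projections(embeds, utterances_per_speaker=10, step=0, out_fpath="umap_step0.png")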