import logging
import sys
import numpy as np
from tqdm import tqdm
from deep_speaker.audio import Audio
from deep_speaker.batcher import LazyTripletBatcher
from deep_speaker.constants import NUM_FBANKS, NUM_FRAMES, CHECKPOINTS_TRIPLET_DIR, BATCH_SIZE
from deep_speaker.conv_models import DeepSpeakerModel
from deep_speaker.eval_metrics import evaluate
from deep_speaker.utils import load_best_checkpoint, enable_deterministic

logger = logging.getLogger(__name__)


def batch_cosine_similarity(x1, x2):
    """Row-wise cosine similarity: 1 = same direction, -1 = opposite direction.

    See https://en.wikipedia.org/wiki/Cosine_similarity.
    The embeddings are assumed to already be L2-normalized, so the dot product
    alone is the cosine similarity and dividing by the norms is unnecessary.
    """
    mul = np.multiply(x1, x2)
    s = np.sum(mul, axis=1)
    return s
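
# Minimal usage sketch (an assumption, not part of the original file): with
# unit-norm rows, the plain dot product equals the cosine similarity.
#   a = np.array([[1.0, 0.0], [0.6, 0.8]])  # already L2-normalized
#   b = np.array([[0.0, 1.0], [0.6, 0.8]])
#   batch_cosine_similarity(a, b)           # -> array([0., 1.])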


def eval_model(working_dir: str, model: DeepSpeakerModel):
    """Score every speaker against one positive and num_negative_speakers negative
    utterances, then return f-measure, true positive rate, accuracy and equal error rate."""
    enable_deterministic()
    audio = Audio(working_dir)
    batcher = LazyTripletBatcher(working_dir, NUM_FRAMES, model)
    speakers_list = list(audio.speakers_to_utterances.keys())
    num_negative_speakers = 99
    num_speakers = len(speakers_list)
    y_pred = np.zeros(shape=(num_speakers, num_negative_speakers + 1))  # negatives + positive.
    for i, positive_speaker in tqdm(enumerate(speakers_list), desc='test', total=num_speakers):
        # Convention: index 0 is the anchor speaker, index 1 the positive, indices 2: the negatives.
        input_data = batcher.get_speaker_verification_data(positive_speaker, num_negative_speakers)
        # The batch size is not relevant here; it just keeps GPU memory usage bounded.
        predictions = model.m.predict(input_data, batch_size=BATCH_SIZE)
        anchor_embedding = predictions[0]
        for j, other_than_anchor_embedding in enumerate(predictions[1:]):  # positive + negatives.
            y_pred[i][j] = batch_cosine_similarity([anchor_embedding], [other_than_anchor_embedding])[0]
        # A softmax could be applied here, e.g. y_pred[i] = softmax(y_pred[i]).
    y_true = np.zeros_like(y_pred)  # The positive is at index 0.
    y_true[:, 0] = 1.0
    print(np.matrix(y_true))
    print(np.matrix(y_pred))
    print(np.min(y_pred), np.max(y_pred))
    fm, tpr, acc, eer = evaluate(y_pred, y_true)
    return fm, tpr, acc, eer


def test(working_dir, checkpoint_file=None):
    """Load the given checkpoint (or the best triplet checkpoint) and evaluate it."""
    batch_input_shape = [None, NUM_FRAMES, NUM_FBANKS, 1]
    dsm = DeepSpeakerModel(batch_input_shape)
    if checkpoint_file is None:
        checkpoint_file = load_best_checkpoint(CHECKPOINTS_TRIPLET_DIR)
    if checkpoint_file is not None:
        logger.info(f'Found checkpoint [{checkpoint_file}]. Loading weights...')
        dsm.m.load_weights(checkpoint_file, by_name=True)
    else:
        logger.info(f'Could not find any checkpoint in {CHECKPOINTS_TRIPLET_DIR}.')
        sys.exit(1)
    fm, tpr, acc, eer = eval_model(working_dir, model=dsm)
    logger.info(f'f-measure = {fm:.3f}, true positive rate = {tpr:.3f}, '
                f'accuracy = {acc:.3f}, equal error rate = {eer:.3f}')
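

# Hedged usage sketch (not part of the original file): run the evaluation
# directly. The working directory below is a hypothetical placeholder for the
# location of the preprocessed dataset.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    test(working_dir='/path/to/working_dir')  # Optionally pass checkpoint_file=... as well.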