|
import logging |
|
|
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
from deep_speaker.audio import Audio |
|
from deep_speaker.batcher import LazyTripletBatcher |
|
from deep_speaker.constants import NUM_FBANKS, NUM_FRAMES, CHECKPOINTS_TRIPLET_DIR, BATCH_SIZE |
|
from deep_speaker.conv_models import DeepSpeakerModel |
|
from deep_speaker.eval_metrics import evaluate |
|
from deep_speaker.utils import load_best_checkpoint, enable_deterministic |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def batch_cosine_similarity(x1, x2): |
|
|
|
|
|
mul = np.multiply(x1, x2) |
|
s = np.sum(mul, axis=1) |
|
|
|
|
|
|
|
|
|
return s |
|
|
|
|
|
def eval_model(working_dir: str, model: DeepSpeakerModel): |
|
enable_deterministic() |
|
audio = Audio(working_dir) |
|
batcher = LazyTripletBatcher(working_dir, NUM_FRAMES, model) |
|
speakers_list = list(audio.speakers_to_utterances.keys()) |
|
num_negative_speakers = 99 |
|
num_speakers = len(speakers_list) |
|
y_pred = np.zeros(shape=(num_speakers, num_negative_speakers + 1)) |
|
for i, positive_speaker in tqdm(enumerate(speakers_list), desc='test', total=num_speakers): |
|
|
|
input_data = batcher.get_speaker_verification_data(positive_speaker, num_negative_speakers) |
|
|
|
predictions = model.m.predict(input_data, batch_size=BATCH_SIZE) |
|
anchor_embedding = predictions[0] |
|
for j, other_than_anchor_embedding in enumerate(predictions[1:]): |
|
y_pred[i][j] = batch_cosine_similarity([anchor_embedding], [other_than_anchor_embedding])[0] |
|
|
|
|
|
y_true = np.zeros_like(y_pred) |
|
y_true[:, 0] = 1.0 |
|
print(np.matrix(y_true)) |
|
print(np.matrix(y_pred)) |
|
print(np.min(y_pred), np.max(y_pred)) |
|
fm, tpr, acc, eer = evaluate(y_pred, y_true) |
|
return fm, tpr, acc, eer |
|
|
|
|
|
def test(working_dir, checkpoint_file=None): |
|
batch_input_shape = [None, NUM_FRAMES, NUM_FBANKS, 1] |
|
dsm = DeepSpeakerModel(batch_input_shape) |
|
if checkpoint_file is None: |
|
checkpoint_file = load_best_checkpoint(CHECKPOINTS_TRIPLET_DIR) |
|
if checkpoint_file is not None: |
|
logger.info(f'Found checkpoint [{checkpoint_file}]. Loading weights...') |
|
dsm.m.load_weights(checkpoint_file, by_name=True) |
|
else: |
|
logger.info(f'Could not find any checkpoint in {checkpoint_file}.') |
|
exit(1) |
|
|
|
fm, tpr, acc, eer = eval_model(working_dir, model=dsm) |
|
logger.info(f'f-measure = {fm:.3f}, true positive rate = {tpr:.3f}, ' |
|
f'accuracy = {acc:.3f}, equal error rate = {eer:.3f}') |
|
|