"""
Run speaker-verification inference on pre-processed data with a trained model.

Embeds every wav file referenced in a trial list with a Resemblyzer
VoiceEncoder, scores each trial pair by embedding distance, and reports
the equal error rate (EER).
"""
|
import argparse
import logging
import math
import os
import random
import sys
import time

import numpy
import soundfile as sf
import torch
import torch.nn.functional as F
from resemblyzer import VoiceEncoder, preprocess_wav
from sklearn import metrics

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
def tuneThresholdfromScore(scores, labels, target_fa, target_fr=None):
    """Tune decision thresholds at target false-alarm/false-reject rates and compute the EER."""
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr

    # Express both error rates as percentages.
    fnr = fnr * 100
    fpr = fpr * 100

    tunedThreshold = []
    if target_fr:
        for tfr in target_fr:
            idx = numpy.nanargmin(numpy.absolute(tfr - fnr))
            tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])

    for tfa in target_fa:
        idx = numpy.nanargmin(numpy.absolute(tfa - fpr))
        tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])

    # The equal error rate (EER) is the point where the false-alarm and
    # false-reject curves cross.
    idxE = numpy.nanargmin(numpy.absolute(fnr - fpr))
    eer = max(fpr[idxE], fnr[idxE])

    return tunedThreshold, eer, fpr, fnr
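# A minimal sanity check for tuneThresholdfromScore (illustrative values):
#
#   scores = [0.9, 0.8, 0.2, 0.1]   # higher score = more likely same speaker
#   labels = [1, 1, 0, 0]
#   _, eer, _, _ = tuneThresholdfromScore(scores, labels, [1, 0.1])
#   # perfectly separable scores yield an EER of 0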
|
|
def loadWAV(filename, max_frames, evalmode=True, num_eval=10):
    # Maximum number of samples: max_frames hops of 160 samples (10 ms at
    # 16 kHz) plus 240 samples to cover the final analysis window.
    max_audio = max_frames * 160 + 240

    audio, sample_rate = sf.read(filename)
    audiosize = audio.shape[0]

    # Zero-pad short utterances symmetrically up to max_audio samples.
    if audiosize <= max_audio:
        shortage = math.floor((max_audio - audiosize + 1) / 2)
        audio = numpy.pad(audio, (shortage, shortage), 'constant', constant_values=0)
        audiosize = audio.shape[0]

    if evalmode:
        # Evaluation: num_eval evenly spaced crops across the utterance.
        startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    else:
        # Training: a single random crop.
        startframe = numpy.array([numpy.int64(random.random() * (audiosize - max_audio))])

    feats = []
    if evalmode and max_frames == 0:
        feats.append(audio)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])

    feat = numpy.stack(feats, axis=0)
    feat = torch.FloatTensor(feat)
    return feat
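# Example (illustrative): ten evenly spaced ~2 s crops
# (200 frames * 160 samples + 240 = 32240 samples at 16 kHz):
#
#   feat = loadWAV('utterance.wav', max_frames=200, evalmode=True, num_eval=10)
#   # feat.shape == (10, 32240)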
|
|
def evaluateFromList(listfilename, print_interval=100, test_path='', multi=False):

    lines = []
    files = []
    feats = {}
    tstart = time.time()

    # Read the trial list; each line is "<label> <file1> <file2>" or
    # "<file1> <file2>".
    with open(listfilename) as listfile:
        for line in listfile:
            if not line.strip():
                continue

            data = line.split()

            # Unlabeled pairs get a dummy random label so the layout is
            # always [label, file1, file2].
            if len(data) == 2:
                data = [random.randint(0, 1)] + data

            files.append(data[1])
            files.append(data[2])
            lines.append(line)

    # Embed every unique file exactly once.
    setfiles = list(set(files))
    setfiles.sort()

    for idx, file in enumerate(setfiles):

        # voice_encoder is the global Resemblyzer VoiceEncoder created in __main__.
        processed_wav = preprocess_wav(os.path.join(test_path, file))
        embed = voice_encoder.embed_utterance(processed_wav)

        torch.cuda.empty_cache()
        ref_feat = torch.from_numpy(embed).unsqueeze(0)

        feats[file] = ref_feat

        telapsed = time.time() - tstart

        if idx % print_interval == 0:
            sys.stdout.write("\rReading %d of %d: %.2f Hz, embedding size %d" % (idx, len(setfiles), idx / telapsed, ref_feat.size()[1]))

    print('')
    all_scores = []
    all_labels = []
    all_trials = []
    tstart = time.time()

    # Score every trial pair.
    for idx, line in enumerate(lines):

        data = line.split()

        if len(data) == 2:
            data = [random.randint(0, 1)] + data

        ref_feat = feats[data[1]].cuda()
        com_feat = feats[data[2]].cuda()

        # L2-normalise the embeddings so the distance below behaves like a
        # cosine distance.
        ref_feat = F.normalize(ref_feat, p=2, dim=1)
        com_feat = F.normalize(com_feat, p=2, dim=1)

        dist = F.pairwise_distance(ref_feat.unsqueeze(-1), com_feat.unsqueeze(-1)).detach().cpu().numpy()

        # Negate the mean distance so that larger scores mean "same speaker".
        score = -1 * numpy.mean(dist)

        all_scores.append(score)
        all_labels.append(int(data[0]))
        all_trials.append(data[1] + " " + data[2])

        if idx % print_interval == 0:
            telapsed = time.time() - tstart
            sys.stdout.write("\rComputing %d of %d: %.2f Hz" % (idx, len(lines), idx / telapsed))
            sys.stdout.flush()

    print('\n')

    return all_scores, all_labels, all_trials
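# Example trial list contents (file names are illustrative):
#
#   1 id001/utt01.wav id001/utt02.wav
#   0 id001/utt01.wav id002/utt01.wav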
|
|
if __name__ == '__main__':

    parser = argparse.ArgumentParser("baseline")
    parser.add_argument("--data_root", type=str, help="root directory of the evaluation wav files", required=True)
    parser.add_argument("--list", type=str, help="trial list of pairs to score", required=True)
    parser.add_argument("--model_dir", type=str, help="model parameters for AudioEncoder (unused here; resemblyzer loads its own pretrained weights)", required=True)

    args = parser.parse_args()

    print("Preparing the encoder...")
    voice_encoder = VoiceEncoder().cuda()

    sc, lab, trials = evaluateFromList(args.list, print_interval=100, test_path=args.data_root)
    result = tuneThresholdfromScore(sc, lab, [1, 0.1])
    print('EER %2.4f' % result[1])
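# Example invocation (script and path names are illustrative):
#
#   python inference.py --data_root /data/wavs \
#       --list veri_test.txt --model_dir saved_models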