Speaker-Verification / inference_kathbadh.py
KhadgaA's picture
Update inference_kathbadh.py
5b38d3a verified
raw
history blame contribute delete
905 Bytes
import torch
import torch.nn as nn
import math
import torch
import torchaudio
from models.ecapa_tdnn import ECAPA_TDNN_SMALL
import torch.nn.functional as F
# Module-level scorer shared by the inference entry point: cosine similarity
# between two embedding tensors, using nn.CosineSimilarity's defaults
# (comparison along dim=1, one score per row).
score_fn = nn.CosineSimilarity()
def load_model(checkpoint):
    """Build the WavLM-large ECAPA-TDNN network and load weights into it.

    Args:
        checkpoint: Path (or file-like object) handed straight to
            ``torch.load``; it is expected to hold a state dict.

    Returns:
        The ``ECAPA_TDNN_SMALL`` instance after ``load_state_dict``. The
        network is returned as constructed: this function neither moves it
        to a device nor switches it to eval mode.
    """

    def _keep_storage(storage, _location):
        # torch.load calls this once per serialized storage; handing the
        # storage back unchanged leaves it where it was deserialized rather
        # than restoring it to the device it was saved from.
        return storage

    network = ECAPA_TDNN_SMALL(
        feat_dim=1024,
        feat_type="wavlm_large",
        config_path=None,
    )
    weights = torch.load(checkpoint, map_location=_keep_storage)
    # strict=False: keys missing from / unexpected in the checkpoint are
    # tolerated instead of raising.
    network.load_state_dict(weights, strict=False)
    return network
def _load_mono_waveform(path, target_sr):
    """Load an audio file as a ``(1, time)`` tensor sampled at *target_sr* Hz."""
    waveform, sample_rate = torchaudio.load(path)
    # torchaudio.load returns (channels, time). Average the channels so a
    # stereo file is scored as one utterance instead of being treated as a
    # batch of two, which would make the final ``.item()`` call fail.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Feeding audio at a different rate than the model expects silently
    # degrades the embeddings, so resample whenever the file rate differs.
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sr)
    return waveform


def inference_kathbadh(wav1, wav2, *, target_sr=16000):
    """Score how similar the speakers in two audio files are.

    Both files are loaded, downmixed to mono, resampled to *target_sr*,
    embedded with the fine-tuned checkpoint, and compared with the
    module-level cosine-similarity scorer.

    Args:
        wav1: Path (or file-like object) of the first recording, as accepted
            by ``torchaudio.load``.
        wav2: Path (or file-like object) of the second recording.
        target_sr: Sample rate in Hz the audio is converted to before being
            fed to the model. The default of 16000 matches the rate WavLM is
            documented to be trained on -- NOTE(review): confirm it also
            holds for this fine-tuned checkpoint.

    Returns:
        The cosine similarity of the two speaker embeddings as a Python
        float.
    """
    checkpoint = r"./wavlm_large_kathbadh_finetune.pth"
    # NOTE(review): the checkpoint is re-read on every call; callers that
    # score many pairs may want to cache the loaded model.
    model = load_model(checkpoint)
    model.eval()

    waveform1 = _load_mono_waveform(wav1, target_sr)
    waveform2 = _load_mono_waveform(wav2, target_sr)

    with torch.no_grad():
        embedding1 = model(waveform1)
        embedding2 = model(waveform2)
        score = score_fn(embedding1, embedding2)
    return score.item()