Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import math | |
import torch | |
import torchaudio | |
from models.ecapa_tdnn import ECAPA_TDNN_SMALL | |
import torch.nn.functional as F | |
# Cosine similarity between two batches of embeddings, compared row by row
# (dim=1); eps guards against division by a zero norm. These are the
# documented defaults of nn.CosineSimilarity, written out explicitly.
score_fn = nn.CosineSimilarity(dim=1, eps=1e-8)
def load_model(checkpoint):
    """Build the WavLM-large ECAPA-TDNN model and load weights from *checkpoint*.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``.
            Either a bare ``state_dict`` or a dict wrapping one under the
            ``"model"`` key is accepted.

    Returns:
        The constructed model with the checkpoint weights applied. This
        function does not call ``.eval()``; callers do that for inference.

    Warns:
        UserWarning: if some model parameters were not found in the
        checkpoint, or the checkpoint has keys the model does not use.
    """
    import warnings

    model = ECAPA_TDNN_SMALL(
        feat_dim=1024, feat_type="wavlm_large", config_path=None
    )
    # Deserialize every tensor onto the CPU regardless of the device it was
    # saved from (same effect as the identity ``map_location`` lambda).
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted checkpoints (``weights_only=True`` is available on newer torch).
    state_dict = torch.load(checkpoint, map_location="cpu")
    # Training scripts (e.g. the UniSpeech speaker-verification recipe) often
    # save ``{"model": state_dict, ...}``. With strict=False such a wrapper
    # would otherwise load *no* weights and raise no error.
    if isinstance(state_dict, dict) and isinstance(state_dict.get("model"), dict):
        state_dict = state_dict["model"]
    # strict=False is kept so partial checkpoints still load, but report what
    # was skipped instead of silently running with untrained parameters.
    result = model.load_state_dict(state_dict, strict=False)
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"load_model({checkpoint!r}): {len(missing)} missing and "
            f"{len(unexpected)} unexpected state_dict keys were ignored "
            f"(strict=False); first missing={missing[:5]}, "
            f"first unexpected={unexpected[:5]}"
        )
    return model
# Loaded models keyed by checkpoint path, so repeated calls do not rebuild the
# network and re-read the large checkpoint from disk on every request.
_MODEL_CACHE = {}

# WavLM front ends operate on 16 kHz audio.
# NOTE(review): confirm the fine-tuned checkpoint was also trained at 16 kHz.
_TARGET_SR = 16000


def _get_model(checkpoint):
    """Return an eval-mode model for *checkpoint*, loading it at most once."""
    model = _MODEL_CACHE.get(checkpoint)
    if model is None:
        model = load_model(checkpoint)
        model.eval()
        _MODEL_CACHE[checkpoint] = model
    return model


def _load_waveform(path, target_sr=_TARGET_SR):
    """Load *path* as a mono ``(1, frames)`` tensor resampled to *target_sr* Hz."""
    wav, sr = torchaudio.load(path)
    # torchaudio returns (channels, frames). Averaging the channels makes a
    # stereo file produce one embedding (and one score) instead of a batch of
    # two, which would make ``score.item()`` fail. Mono input is unchanged.
    wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    return wav


def inference_kathbadh(wav1, wav2):
    """Score how similar two recordings are using the fine-tuned model.

    Args:
        wav1: First audio source (path or file-like accepted by torchaudio.load).
        wav2: Second audio source, same kind as *wav1*.

    Returns:
        The cosine similarity (a float in [-1, 1]) between the model's
        embeddings of the two recordings. These are presumably speaker
        embeddings (ECAPA-TDNN), where higher means more similar.
    """
    checkpoint = r"./wavlm_large_kathbadh_finetune.pth"
    model = _get_model(checkpoint)
    audio1 = _load_waveform(wav1)
    audio2 = _load_waveform(wav2)
    with torch.no_grad():
        embedding1 = model(audio1)
        embedding2 = model(audio2)
        score = score_fn(embedding1, embedding2)
    return score.item()