File size: 905 Bytes
d7959a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b38d3a
d7959a1
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
import math
import torch
import torchaudio 
from models.ecapa_tdnn import ECAPA_TDNN_SMALL

import torch.nn.functional as F
# Cosine similarity between embedding rows, computed along dim 1
# (these are nn.CosineSimilarity's defaults, written out explicitly).
score_fn = nn.CosineSimilarity(dim=1, eps=1e-8)
def load_model(checkpoint):
    """Build the WavLM-Large ECAPA-TDNN speaker model and load ``checkpoint``.

    Args:
        checkpoint: Path to a file written by ``torch.save`` containing the
            model's state dict. A dict that wraps the weights under a
            ``"model"`` key is also accepted.

    Returns:
        The ``ECAPA_TDNN_SMALL`` instance with the checkpoint weights loaded.

    Warns:
        UserWarning: if the checkpoint's keys do not match the model's
            (loading is non-strict, so this is the only signal that some
            weights were not restored).

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    from a trusted source.
    """
    import warnings

    model = ECAPA_TDNN_SMALL(
        feat_dim=1024, feat_type="wavlm_large", config_path=None
    )
    # Deserialize every tensor onto the CPU so GPU-saved checkpoints also
    # load on CPU-only hosts.
    state_dict = torch.load(checkpoint, map_location="cpu")
    # Some training scripts save {"model": state_dict, ...}; unwrap that form
    # so the weights are not silently skipped by the non-strict load below.
    if isinstance(state_dict, dict) and isinstance(state_dict.get("model"), dict):
        state_dict = state_dict["model"]

    # strict=False tolerates key mismatches, but a mismatch usually means part
    # of the model kept its random initialization, so report it.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint!r} did not match the model: "
            f"{len(result.missing_keys)} missing key(s) "
            f"(e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"(e.g. {result.unexpected_keys[:3]})."
        )
    return model

def _load_mono(path, target_sample_rate):
    """Load an audio file as a single-channel ``(1, time)`` tensor at ``target_sample_rate``."""
    waveform, sample_rate = torchaudio.load(path)
    # torchaudio returns (channels, time). Without downmixing, a multi-channel
    # file would be embedded as a batch and ``score.item()`` would fail.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_sample_rate
        )
    return waveform


def inference_kathbadh(
    wav1,
    wav2,
    checkpoint=r"./wavlm_large_kathbadh_finetune.pth",
    model=None,
    target_sample_rate=16000,
):
    """Return the cosine similarity between the speaker embeddings of two recordings.

    Each recording is loaded with ``torchaudio``, downmixed to one channel,
    resampled to ``target_sample_rate`` and embedded by the speaker model. The
    two embeddings are then compared with ``score_fn``.

    Args:
        wav1: First audio file, passed to ``torchaudio.load``.
        wav2: Second audio file, passed to ``torchaudio.load``.
        checkpoint: Weights to load via ``load_model`` when ``model`` is not
            given. The default path is relative to the current working directory.
        model: Optional already-loaded model. Passing one avoids re-reading the
            checkpoint from disk on every call.
        target_sample_rate: Rate (Hz) the audio is resampled to before
            embedding. NOTE(review): 16 kHz is assumed to be the rate the
            WavLM front end was trained on. Confirm this for the fine-tuned model.

    Returns:
        The similarity score as a Python float.
    """
    if model is None:
        model = load_model(checkpoint)
    model.eval()

    waveform1 = _load_mono(wav1, target_sample_rate)
    waveform2 = _load_mono(wav2, target_sample_rate)
    with torch.no_grad():
        embedding1 = model(waveform1)
        embedding2 = model(waveform2)
        score = score_fn(embedding1, embedding2)
    return score.item()