YuAnthony committed on
Commit
eb3c3fe
·
1 Parent(s): 84bee66

Create infer_contrast.py

Browse files
Files changed (1) hide show
  1. infer_contrast.py +51 -0
infer_contrast.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from utils.reader import load_audio
8
+ from utils.utility import add_arguments, print_arguments
9
+
10
# Command-line configuration. The help strings are user-facing runtime text
# and are intentionally left in their original language.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')
# parse_known_args (rather than parse_args) tolerates unrecognized argv
# entries, so this module can be imported from a host process that has its
# own command-line arguments without aborting at import time.
args = parser.parse_known_args()[0]

print_arguments(args)

# Kept as a module-level name for compatibility; the model itself is loaded
# onto the CPU below because infer() builds its input tensors on the CPU.
device = torch.device("cuda")

# Load the model.
# map_location="cpu" makes the load succeed on hosts without CUDA even when
# the checkpoint was saved from a GPU, and guarantees the weights live on the
# same device as the CPU tensors that infer() feeds to the model.
# NOTE(review): model_path is unpickled as a whole model object; recent
# PyTorch releases may require weights_only=False for that - confirm against
# the installed version. Only load checkpoints from trusted sources.
model = torch.load(args.model_path, map_location="cpu")
model.eval()
27
+
28
+
29
+ # 预测音频
30
def infer(audio_path):
    """Run the loaded model on one audio file and return its output.

    Args:
        audio_path: Path of the audio file, passed through to ``load_audio``.

    Returns:
        The model output as a NumPy array. A leading axis of size 1 is added
        to the loaded data before the forward pass, so callers index ``[0]``
        to get the result for this file (presumably the speaker embedding -
        confirm against the model definition).

    Raises:
        ValueError, SyntaxError: if ``args.input_shape`` is not a valid
            Python literal.
    """
    import ast  # local import: the module header does not import ast

    # Parse the "(1, 257, 257)"-style string as a literal only. Unlike eval,
    # literal_eval cannot execute arbitrary code supplied on the command line.
    input_shape = ast.literal_eval(args.input_shape)
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    # Add a leading axis of size 1 in front of the loaded data.
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Feed the input on whatever device the model's weights live on, so the
    # forward pass works for both CPU- and GPU-resident models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)
    # Run the prediction; gradients are not needed for inference.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()
38
+
39
+
40
def run(audio1, audio2):
    """Compare two audio files and report whether they match as one speaker.

    Args:
        audio1: Path of the first audio file.
        audio2: Path of the second audio file.

    Returns:
        A message string stating whether the two recordings are judged to be
        the same person, followed by the cosine similarity. They are judged
        the same when the similarity is strictly greater than
        ``args.threshold``.
    """
    # Model outputs for the two audio files to compare.
    feature1 = infer(audio1)[0]
    feature2 = infer(audio2)[0]
    # Cosine similarity between the two feature vectors.
    dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
    if dist > args.threshold:
        result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
    else:
        # Below the threshold the speakers are NOT the same person; the
        # original returned the "same person" message in this branch too.
        result = "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % (dist)

    return result