YuAnthony committed on
Commit
84bee66
·
1 Parent(s): ae5e53f

Delete infer_contrast.py

Browse files
Files changed (1) hide show
  1. infer_contrast.py +0 -51
infer_contrast.py DELETED
@@ -1,51 +0,0 @@
1
- import argparse
2
- import functools
3
-
4
- import numpy as np
5
- import torch
6
-
7
- from utils.reader import load_audio
8
- from utils.utility import add_arguments, print_arguments
9
-
10
# Command-line configuration for the speaker-comparison script.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')
# parse_known_args() is used instead of parse_args() so that unrecognised
# command-line flags are ignored rather than aborting the program.
args = parser.parse_known_args()[0]

print_arguments(args)

# NOTE(review): this device object is not used by the code below; the model
# is never moved onto it.  Kept so the module-level name remains available.
device = torch.device("cuda")

# Load the model.
# infer() feeds the model tensors created on the CPU, so the checkpoint is
# mapped onto the CPU here.  Without map_location a checkpoint saved from a
# GPU cannot be loaded on a CPU-only host, and on a GPU host it would sit on
# a different device than its inputs.
model = torch.load(args.model_path, map_location="cpu")
model.eval()
27
-
28
-
29
# Run the model on one audio file
def infer(audio_path):
    """Compute the model output for a single audio file.

    Args:
        audio_path: Location of the audio file; forwarded unchanged to
            ``load_audio``.

    Returns:
        The model output converted to a NumPy array on the CPU.  The leading
        axis is the one added here before the forward pass (presumably the
        batch axis, of size 1 -- confirm against ``load_audio``).
    """
    import ast

    # ``input_shape`` is supplied as text such as "(1, 257, 257)".  It is
    # parsed as a Python literal rather than with eval(), so a command-line
    # value cannot execute arbitrary code.
    input_shape = ast.literal_eval(args.input_shape)
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    # Add a leading axis in front of the loaded data.
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Run the prediction; gradients are not needed for inference, so skip
    # building the autograd graph.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()
38
-
39
-
40
def run(audio1, audio2):
    """Compare two audio files and report whether they share a speaker.

    Args:
        audio1: Path of the first audio file.
        audio2: Path of the second audio file.

    Returns:
        A message string stating whether the two recordings are judged to be
        the same person, followed by the similarity score.
    """
    # Features of the two recordings to compare (first row of each result).
    feature1 = infer(audio1)[0]
    feature2 = infer(audio2)[0]
    # Cosine similarity between the two feature vectors.
    dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
    if dist > args.threshold:
        result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
    else:
        # Previously this branch repeated the "same person" message, so a
        # score at or below the threshold was still reported as a match.
        result = "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % (dist)

    return result