KhadgaA committed
Commit d7959a1 · verified · 1 Parent(s): 92dbcae

Init commit
inference_kathbadh.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio
+ from models.ecapa_tdnn import ECAPA_TDNN_SMALL
+
+ # Cosine similarity between a pair of speaker embeddings
+ score_fn = nn.CosineSimilarity()
+
+ def load_model(checkpoint):
+     # WavLM-Large front-end (1024-dim features) with an ECAPA-TDNN back-end
+     model = ECAPA_TDNN_SMALL(
+         feat_dim=1024, feat_type="wavlm_large", config_path=None
+     )
+     state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
+     model.load_state_dict(state_dict, strict=False)
+     return model
+
+ def inference_kathbadh(wav1, wav2):
+     # Fine-tuned checkpoint (absolute path on the author's machine)
+     checkpoint = r"C:\Users\KHADGA JYOTH ALLI\Desktop\programming\Class Work\IITJ\Speech Understanding\Speaker-verification\wavlm_large_kathbadh_finetune.pth"
+     model = load_model(checkpoint)
+     model.eval()
+     wav1, sr = torchaudio.load(wav1)
+     wav2, sr = torchaudio.load(wav2)
+     # input = torch.cat([wav1, wav2], dim=0)
+     with torch.no_grad():
+         embedding1 = model(wav1)
+         embedding2 = model(wav2)
+     score = score_fn(embedding1, embedding2)
+     return score.item()
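A minimal usage sketch for the script above (the two WAV paths are placeholders, not files from this commit; the model expects 16 kHz mono audio, matching the sr=16000 default of ECAPA_TDNN):

    from inference_kathbadh import inference_kathbadh

    # Cosine similarity between the speaker embeddings of the two recordings;
    # values closer to 1 suggest the same speaker.
    score = inference_kathbadh("speaker_a.wav", "speaker_b.wav")
    print(f"similarity: {score:.3f}")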
models/__pycache__/ecapa_tdnn.cpython-310.pyc ADDED
Binary file (9.2 kB)
models/__pycache__/ecapa_tdnn.cpython-39.pyc ADDED
Binary file (9.13 kB)
models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.04 kB)
models/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.02 kB)
models/ecapa_tdnn.py ADDED
@@ -0,0 +1,446 @@
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio.transforms as trans
+ from .utils import UpstreamExpert
+ import s3prl.hub as hub
+
+
+ """ Res2Conv1d + BatchNorm1d + ReLU
+ """
+
+
+ class Res2Conv1dReluBn(nn.Module):
+     """
+     in_channels == out_channels == channels
+     """
+
+     def __init__(
+         self,
+         channels,
+         kernel_size=1,
+         stride=1,
+         padding=0,
+         dilation=1,
+         bias=True,
+         scale=4,
+     ):
+         super().__init__()
+         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+         self.scale = scale
+         self.width = channels // scale
+         self.nums = scale if scale == 1 else scale - 1
+
+         self.convs = []
+         self.bns = []
+         for i in range(self.nums):
+             self.convs.append(
+                 nn.Conv1d(
+                     self.width,
+                     self.width,
+                     kernel_size,
+                     stride,
+                     padding,
+                     dilation,
+                     bias=bias,
+                 )
+             )
+             self.bns.append(nn.BatchNorm1d(self.width))
+         self.convs = nn.ModuleList(self.convs)
+         self.bns = nn.ModuleList(self.bns)
+
+     def forward(self, x):
+         out = []
+         spx = torch.split(x, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             # Order: conv -> relu -> bn
+             sp = self.convs[i](sp)
+             sp = self.bns[i](F.relu(sp))
+             out.append(sp)
+         if self.scale != 1:
+             out.append(spx[self.nums])
+         out = torch.cat(out, dim=1)
+
+         return out
+
+
+ """ Conv1d + BatchNorm1d + ReLU
+ """
+
+
+ class Conv1dReluBn(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size=1,
+         stride=1,
+         padding=0,
+         dilation=1,
+         bias=True,
+     ):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
+         )
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x):
+         return self.bn(F.relu(self.conv(x)))
+
+
+ """ The SE connection of 1D case.
+ """
+
+
+ class SE_Connect(nn.Module):
+     def __init__(self, channels, se_bottleneck_dim=128):
+         super().__init__()
+         self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+         self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+     def forward(self, x):
+         out = x.mean(dim=2)
+         out = F.relu(self.linear1(out))
+         out = torch.sigmoid(self.linear2(out))
+         out = x * out.unsqueeze(2)
+
+         return out
+
+
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
+ """
+
+
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+ #     return nn.Sequential(
+ #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+ #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+ #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+ #         SE_Connect(channels)
+ #     )
+
+
+ class SE_Res2Block(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride,
+         padding,
+         dilation,
+         scale,
+         se_bottleneck_dim,
+     ):
+         super().__init__()
+         self.Conv1dReluBn1 = Conv1dReluBn(
+             in_channels, out_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+             out_channels, kernel_size, stride, padding, dilation, scale=scale
+         )
+         self.Conv1dReluBn2 = Conv1dReluBn(
+             out_channels, out_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+         self.shortcut = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+             )
+
+     def forward(self, x):
+         residual = x
+         if self.shortcut:
+             residual = self.shortcut(x)
+
+         x = self.Conv1dReluBn1(x)
+         x = self.Res2Conv1dReluBn(x)
+         x = self.Conv1dReluBn2(x)
+         x = self.SE_Connect(x)
+
+         return x + residual
+
+
+ """ Attentive weighted mean and standard deviation pooling.
+ """
+
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+         super().__init__()
+         self.global_context_att = global_context_att
+
+         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
+         if global_context_att:
+             self.linear1 = nn.Conv1d(
+                 in_dim * 3, attention_channels, kernel_size=1
+             )  # equals W and b in the paper
+         else:
+             self.linear1 = nn.Conv1d(
+                 in_dim, attention_channels, kernel_size=1
+             )  # equals W and b in the paper
+         self.linear2 = nn.Conv1d(
+             attention_channels, in_dim, kernel_size=1
+         )  # equals V and k in the paper
+
+     def forward(self, x):
+
+         if self.global_context_att:
+             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+             context_std = torch.sqrt(
+                 torch.var(x, dim=-1, keepdim=True) + 1e-10
+             ).expand_as(x)
+             x_in = torch.cat((x, context_mean, context_std), dim=1)
+         else:
+             x_in = x
+
+         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+         alpha = torch.tanh(self.linear1(x_in))
+         # alpha = F.relu(self.linear1(x_in))
+         alpha = torch.softmax(self.linear2(alpha), dim=2)
+         mean = torch.sum(alpha * x, dim=2)
+         residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+         std = torch.sqrt(residuals.clamp(min=1e-9))
+         return torch.cat([mean, std], dim=1)
+
+
+ class ECAPA_TDNN(nn.Module):
+     def __init__(
+         self,
+         feat_dim=80,
+         channels=512,
+         emb_dim=192,
+         global_context_att=False,
+         feat_type="fbank",
+         sr=16000,
+         feature_selection="hidden_states",
+         update_extract=False,
+         config_path=None,
+     ):
+         super().__init__()
+
+         self.feat_type = feat_type
+         self.feature_selection = feature_selection
+         self.update_extract = update_extract
+         self.sr = sr
+
+         if feat_type == "fbank" or feat_type == "mfcc":
+             self.update_extract = False
+
+         win_len = int(sr * 0.025)
+         hop_len = int(sr * 0.01)
+
+         if feat_type == "fbank":
+             self.feature_extract = trans.MelSpectrogram(
+                 sample_rate=sr,
+                 n_fft=512,
+                 win_length=win_len,
+                 hop_length=hop_len,
+                 f_min=0.0,
+                 f_max=sr // 2,
+                 pad=0,
+                 n_mels=feat_dim,
+             )
+         elif feat_type == "mfcc":
+             melkwargs = {
+                 "n_fft": 512,
+                 "win_length": win_len,
+                 "hop_length": hop_len,
+                 "f_min": 0.0,
+                 "f_max": sr // 2,
+                 "pad": 0,
+             }
+             self.feature_extract = trans.MFCC(
+                 sample_rate=sr, n_mfcc=feat_dim, log_mels=False, melkwargs=melkwargs
+             )
+         else:
+             if config_path is None:
+                 self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
+             else:
+                 self.feature_extract = UpstreamExpert(config_path)
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+                 self.feature_extract.model.encoder.layers[23].self_attn,
+                 "fp32_attention",
+             ):
+                 self.feature_extract.model.encoder.layers[
+                     23
+                 ].self_attn.fp32_attention = False
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+                 self.feature_extract.model.encoder.layers[11].self_attn,
+                 "fp32_attention",
+             ):
+                 self.feature_extract.model.encoder.layers[
+                     11
+                 ].self_attn.fp32_attention = False
+
+             self.feat_num = self.get_feat_num()
+             self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+         if feat_type != "fbank" and feat_type != "mfcc":
+             freeze_list = [
+                 "final_proj",
+                 "label_embs_concat",
+                 "mask_emb",
+                 "project_q",
+                 "quantizer",
+             ]
+             for name, param in self.feature_extract.named_parameters():
+                 for freeze_val in freeze_list:
+                     if freeze_val in name:
+                         param.requires_grad = False
+                         break
+
+         if not self.update_extract:
+             for param in self.feature_extract.parameters():
+                 param.requires_grad = False
+
+         self.instance_norm = nn.InstanceNorm1d(feat_dim)
+         # self.channels = [channels] * 4 + [channels * 3]
+         self.channels = [channels] * 4 + [1536]
+
+         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(
+             self.channels[0],
+             self.channels[1],
+             kernel_size=3,
+             stride=1,
+             padding=2,
+             dilation=2,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+         self.layer3 = SE_Res2Block(
+             self.channels[1],
+             self.channels[2],
+             kernel_size=3,
+             stride=1,
+             padding=3,
+             dilation=3,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+         self.layer4 = SE_Res2Block(
+             self.channels[2],
+             self.channels[3],
+             kernel_size=3,
+             stride=1,
+             padding=4,
+             dilation=4,
+             scale=8,
+             se_bottleneck_dim=128,
+         )
+
+         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+         self.pooling = AttentiveStatsPool(
+             self.channels[-1],
+             attention_channels=128,
+             global_context_att=global_context_att,
+         )
+         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+     def get_feat_num(self):
+         self.feature_extract.eval()
+         wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+         with torch.no_grad():
+             features = self.feature_extract(wav)
+         select_feature = features[self.feature_selection]
+         if isinstance(select_feature, (list, tuple)):
+             return len(select_feature)
+         else:
+             return 1
+
+     def get_feat(self, x):
+         if self.update_extract:
+             x = self.feature_extract([sample for sample in x])
+         else:
+             with torch.no_grad():
+                 if self.feat_type == "fbank" or self.feat_type == "mfcc":
+                     x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
+                 else:
+                     x = self.feature_extract([sample for sample in x])
+
+         if self.feat_type == "fbank":
+             x = x.log()
+
+         if self.feat_type != "fbank" and self.feat_type != "mfcc":
+             x = x[self.feature_selection]
+             if isinstance(x, (list, tuple)):
+                 x = torch.stack(x, dim=0)
+             else:
+                 x = x.unsqueeze(0)
+             norm_weights = (
+                 F.softmax(self.feature_weight, dim=-1)
+                 .unsqueeze(-1)
+                 .unsqueeze(-1)
+                 .unsqueeze(-1)
+             )
+             x = (norm_weights * x).sum(dim=0)
+             x = torch.transpose(x, 1, 2) + 1e-6
+
+         x = self.instance_norm(x)
+         return x
+
+     def forward(self, x):
+         x = self.get_feat(x)
+
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1)
+         out3 = self.layer3(out2)
+         out4 = self.layer4(out3)
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn(self.pooling(out))
+         out = self.linear(out)
+
+         return out
+
+
+ def ECAPA_TDNN_SMALL(
+     feat_dim,
+     emb_dim=256,
+     feat_type="fbank",
+     sr=16000,
+     feature_selection="hidden_states",
+     update_extract=False,
+     config_path=None,
+ ):
+     return ECAPA_TDNN(
+         feat_dim=feat_dim,
+         channels=512,
+         emb_dim=emb_dim,
+         feat_type=feat_type,
+         sr=sr,
+         feature_selection=feature_selection,
+         update_extract=update_extract,
+         config_path=config_path,
+     )
+
+
+ if __name__ == "__main__":
+     x = torch.zeros(2, 32000)
+     model = ECAPA_TDNN_SMALL(
+         feat_dim=768,
+         emb_dim=256,
+         feat_type="hubert_base",
+         feature_selection="hidden_states",
+         update_extract=False,
+     )
+
+     out = model(x)
+     # print(model)
+     print(out.shape)
models/utils.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+ from s3prl.upstream.interfaces import UpstreamBase
+ from omegaconf import OmegaConf
+
+ import torch.nn.functional as F
+
+ def load_model(filepath):
+     state = torch.load(filepath, map_location=lambda storage, loc: storage)
+     cfg = state["cfg"]
+
+     task = cfg.task
+     model = cfg.model
+
+     return model, cfg, task
+
+
+ ###################
+ # UPSTREAM EXPERT #
+ ###################
+ class UpstreamExpert(UpstreamBase):
+     def __init__(self, ckpt, **kwargs):
+         super().__init__(**kwargs)
+
+         model, cfg, task = load_model(ckpt)
+         self.model = model
+         self.task = task
+
+     def forward(self, wavs):
+         if self.task.normalize:
+             wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
+
+         device = wavs[0].device
+         wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
+         wav_padding_mask = ~torch.lt(
+             torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
+             wav_lengths.unsqueeze(1),
+         )
+         padded_wav = pad_sequence(wavs, batch_first=True)
+
+         features, feat_padding_mask = self.model.extract_features(
+             padded_wav,
+             padding_mask=wav_padding_mask,
+             mask=None,
+         )
+         return {
+             "default": features,
+         }
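For reference, ECAPA_TDNN only instantiates this UpstreamExpert when a config_path is supplied; with config_path=None, as in inference_kathbadh.py, the upstream is fetched through torch.hub instead. A minimal sketch of that default branch (downloads the s3prl "wavlm_large" upstream on first use):

    from models.ecapa_tdnn import ECAPA_TDNN_SMALL

    # config_path=None -> feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)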
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy
+ torch
+ torchaudio
+ s3prl
+ soundfile