huseinzol05 commited on
Commit
3f5d181
·
1 Parent(s): 57efae9

Upload evaluate-gpu.ipynb

Browse files
Files changed (1) hide show
  1. evaluate-gpu.ipynb +826 -0
evaluate-gpu.ipynb ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "02b2d284",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "\n",
12
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "4966a667",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# !wget https://huggingface.co/huseinzol05/language-model-bahasa-manglish-combined/resolve/main/model.klm\n",
23
+ "# !pip3 install pyctcdecode==0.1.0 pypi-kenlm==0.1.20210121"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 3,
29
+ "id": "42d8d861",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stderr",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/home/ubuntu/.local/lib/python3.8/site-packages/apex/pyprof/__init__.py:5: FutureWarning: pyprof will be removed by the end of June, 2022\n",
37
+ " warnings.warn(\"pyprof will be removed by the end of June, 2022\", FutureWarning)\n"
38
+ ]
39
+ }
40
+ ],
41
+ "source": [
42
+ "import transformers\n",
43
+ "from transformers import (\n",
44
+ " HfArgumentParser,\n",
45
+ " Trainer,\n",
46
+ " TrainingArguments,\n",
47
+ " Wav2Vec2CTCTokenizer,\n",
48
+ " Wav2Vec2FeatureExtractor,\n",
49
+ " Wav2Vec2ForCTC,\n",
50
+ " Wav2Vec2Processor,\n",
51
+ " is_apex_available,\n",
52
+ " set_seed,\n",
53
+ " AutoModelForCTC,\n",
54
+ " TFWav2Vec2ForCTC,\n",
55
+ " TFWav2Vec2PreTrainedModel,\n",
56
+ " Wav2Vec2PreTrainedModel,\n",
57
+ ")\n",
58
+ "from scipy.special import log_softmax"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 4,
64
+ "id": "0d6b421c",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "import torch"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 5,
74
+ "id": "060fb120",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "import string\n",
79
+ "import json\n",
80
+ "\n",
81
+ "CTC_VOCAB = [''] + list(string.ascii_lowercase + string.digits) + [' ']\n",
82
+ "vocab_dict = {v: k for k, v in enumerate(CTC_VOCAB)}\n",
83
+ "vocab_dict[\"|\"] = vocab_dict[\" \"]\n",
84
+ "del vocab_dict[\" \"]\n",
85
+ "vocab_dict[\"[UNK]\"] = len(vocab_dict)\n",
86
+ "vocab_dict[\"[PAD]\"] = len(vocab_dict)\n",
87
+ "\n",
88
+ "with open(\"ctc-vocab.json\", \"w\") as vocab_file:\n",
89
+ " json.dump(vocab_dict, vocab_file)\n",
90
+ "\n",
91
+ "tokenizer = Wav2Vec2CTCTokenizer(\n",
92
+ " \"ctc-vocab.json\",\n",
93
+ " unk_token=\"[UNK]\",\n",
94
+ " pad_token=\"[PAD]\",\n",
95
+ " word_delimiter_token=\"|\",\n",
96
+ ")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 6,
102
+ "id": "c16b890f",
103
+ "metadata": {},
104
+ "outputs": [
105
+ {
106
+ "data": {
107
+ "text/plain": [
108
+ "(765, 3579, 614)"
109
+ ]
110
+ },
111
+ "execution_count": 6,
112
+ "metadata": {},
113
+ "output_type": "execute_result"
114
+ }
115
+ ],
116
+ "source": [
117
+ "from glob import glob\n",
118
+ "malay = sorted(glob('malay-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n",
119
+ "singlish = sorted(glob('singlish-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n",
120
+ "mandarin = sorted(glob('mandarin-test/*.wav'), key = lambda x: int(x.split('/')[1].replace('.wav', '')))\n",
121
+ "len(malay), len(singlish), len(mandarin)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 7,
127
+ "id": "29568a5f",
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "data": {
132
+ "text/plain": [
133
+ "(765, 3579, 614)"
134
+ ]
135
+ },
136
+ "execution_count": 7,
137
+ "metadata": {},
138
+ "output_type": "execute_result"
139
+ }
140
+ ],
141
+ "source": [
142
+ "with open('malay-test.json') as fopen:\n",
143
+ " malay_label = json.load(fopen)\n",
144
+ "with open('singlish-test.json') as fopen:\n",
145
+ " singlish_label = json.load(fopen)\n",
146
+ "with open('mandarin-test.json') as fopen:\n",
147
+ " mandarin_label = json.load(fopen)\n",
148
+ " \n",
149
+ "len(malay_label), len(singlish_label), len(mandarin_label)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 8,
155
+ "id": "bdac1296",
156
+ "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "data": {
160
+ "text/plain": [
161
+ "[('mandarin-test/460.wav', 'ting qi lai you dian xiang zai chang de na zhong'),\n",
162
+ " ('mandarin-test/256.wav', 'zai jia hao wu liao a'),\n",
163
+ " ('singlish-test/2169.wav', 'controlling our environment is important'),\n",
164
+ " ('mandarin-test/400.wav', 'bo fang gu zheng de ge qu'),\n",
165
+ " ('singlish-test/1001.wav', 'because they are the one that badly need it'),\n",
166
+ " ('singlish-test/4.wav',\n",
167
+ " 'rescuers who used what appeared to be makeshift stretchers to carry the injured'),\n",
168
+ " ('singlish-test/392.wav', 'i attached a mirror to my closet door'),\n",
169
+ " ('singlish-test/2563.wav', 'do you know the answer'),\n",
170
+ " ('singlish-test/799.wav',\n",
171
+ " 'this kind of packaging can pose a danger to animals'),\n",
172
+ " ('singlish-test/1165.wav',\n",
173
+ " 'a lot of parents ive spoken to say they dont have the luxury to do that')]"
174
+ ]
175
+ },
176
+ "execution_count": 8,
177
+ "metadata": {},
178
+ "output_type": "execute_result"
179
+ }
180
+ ],
181
+ "source": [
182
+ "from sklearn.utils import shuffle\n",
183
+ "\n",
184
+ "audio = malay + singlish + mandarin\n",
185
+ "labels = malay_label + singlish_label + mandarin_label\n",
186
+ "audio, labels = shuffle(audio, labels)\n",
187
+ "test_set = list(zip(audio, labels))\n",
188
+ "test_set[:10]"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 9,
194
+ "id": "69cb17cc",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "import soundfile as sf\n",
199
+ "import numpy as np\n",
200
+ "\n",
201
+ "def norm_audio(x):\n",
202
+ " return (x - x.mean()) / np.sqrt(x.var() + 1e-7)\n",
203
+ "\n",
204
+ "def sequence_1d(\n",
205
+ " seq, maxlen=None, padding: str = 'post', pad_int=0, return_len=False\n",
206
+ "):\n",
207
+ " if padding not in ['post', 'pre']:\n",
208
+ " raise ValueError('padding only supported [`post`, `pre`]')\n",
209
+ "\n",
210
+ " if not maxlen:\n",
211
+ " maxlen = max([len(s) for s in seq])\n",
212
+ "\n",
213
+ " padded_seqs, length = [], []\n",
214
+ " for s in seq:\n",
215
+ " if isinstance(s, np.ndarray):\n",
216
+ " s = s.tolist()\n",
217
+ " if padding == 'post':\n",
218
+ " padded_seqs.append(s + [pad_int] * (maxlen - len(s)))\n",
219
+ " if padding == 'pre':\n",
220
+ " padded_seqs.append([pad_int] * (maxlen - len(s)) + s)\n",
221
+ " length.append(len(s))\n",
222
+ " if return_len:\n",
223
+ " return np.array(padded_seqs), length\n",
224
+ " return np.array(padded_seqs)\n",
225
+ "\n",
226
+ "def batching(audios):\n",
227
+ " audios = [sf.read(a)[0] for a in audios]\n",
228
+ " batch, lens = sequence_1d(audios,return_len=True)\n",
229
+ " attentions = [[1] * l for l in lens]\n",
230
+ " attentions = sequence_1d(attentions)\n",
231
+ " normed_input_values = []\n",
232
+ "\n",
233
+ " for vector, length in zip(batch, attentions.sum(-1)):\n",
234
+ " normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)\n",
235
+ " if length < normed_slice.shape[0]:\n",
236
+ " normed_slice[length:] = 0.0\n",
237
+ "\n",
238
+ " normed_input_values.append(normed_slice)\n",
239
+ "\n",
240
+ " normed_input_values = np.array(normed_input_values)\n",
241
+ " return normed_input_values.astype(np.float32), attentions"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 10,
247
+ "id": "f97f22e4",
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "model = AutoModelForCTC.from_pretrained(\n",
252
+ " './wav2vec2-mixed-v3/checkpoint-55000',\n",
253
+ " ctc_loss_reduction=\"mean\",\n",
254
+ " pad_token_id=tokenizer.pad_token_id,\n",
255
+ " vocab_size=len(tokenizer),\n",
256
+ ").cuda()"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 11,
262
+ "id": "20fee479",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "_ = model.eval()"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": 12,
272
+ "id": "51703510",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "batch_size = 4\n",
277
+ "batch_x = audio[:batch_size]\n",
278
+ "normed_input_values, attentions = batching(batch_x)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 13,
284
+ "id": "065fce75",
285
+ "metadata": {},
286
+ "outputs": [],
287
+ "source": [
288
+ "o_pt = model(torch.from_numpy(normed_input_values.astype(np.float32)).cuda(), \n",
289
+ " attention_mask = torch.from_numpy(attentions).cuda())\n",
290
+ "o_pt = o_pt.logits.detach().cpu().numpy()\n",
291
+ "o_pt = log_softmax(o_pt, axis = -1)"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 14,
297
+ "id": "b7851fc9",
298
+ "metadata": {},
299
+ "outputs": [
300
+ {
301
+ "data": {
302
+ "text/plain": [
303
+ "['ting qi lai you dian xiang zai chang de na zhong',\n",
304
+ " 'zai jia hao wu liao wa',\n",
305
+ " 'controlling our environment is important',\n",
306
+ " 'bo fang gu zheng de ge qu']"
307
+ ]
308
+ },
309
+ "execution_count": 14,
310
+ "metadata": {},
311
+ "output_type": "execute_result"
312
+ }
313
+ ],
314
+ "source": [
315
+ "pred_ids = np.argmax(o_pt, axis = -1)\n",
316
+ "tokenizer.batch_decode(pred_ids)"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 15,
322
+ "id": "3efd715e",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "unique_vocab = list(vocab_dict.keys())\n",
327
+ "unique_vocab[-3] = ' ' \n",
328
+ "unique_vocab[-2] = '?'\n",
329
+ "unique_vocab[-1] = '_'"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": 16,
335
+ "id": "3024298f",
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "from pyctcdecode import build_ctcdecoder\n",
340
+ "import kenlm\n",
341
+ "\n",
342
+ "kenlm_model = kenlm.Model('model.klm')\n",
343
+ "decoder = build_ctcdecoder(\n",
344
+ " unique_vocab,\n",
345
+ " kenlm_model,\n",
346
+ " alpha=0.2,\n",
347
+ " beta=1.0,\n",
348
+ " ctc_token_idx=tokenizer.pad_token_id\n",
349
+ ")"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": 17,
355
+ "id": "6100ea60",
356
+ "metadata": {},
357
+ "outputs": [
358
+ {
359
+ "name": "stdout",
360
+ "output_type": "stream",
361
+ "text": [
362
+ "0 ting qi lai you dian xiang zai chang de na zhong\n",
363
+ "1 zai jia hao wu liao wa\n",
364
+ "2 controlling our environment is important\n",
365
+ "3 bo fang gu zheng de ge qu\n"
366
+ ]
367
+ }
368
+ ],
369
+ "source": [
370
+ "for k in range(len(o_pt)):\n",
371
+ " out = decoder.decode_beams(o_pt[k], prune_history=True)\n",
372
+ " d_lm2, lm_state, timesteps, logit_score, lm_score = out[0]\n",
373
+ " print(k, d_lm2)"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": 18,
379
+ "id": "4672ac73",
380
+ "metadata": {},
381
+ "outputs": [
382
+ {
383
+ "data": {
384
+ "text/plain": [
385
+ "['ting qi lai you dian xiang zai chang de na zhong',\n",
386
+ " 'zai jia hao wu liao a',\n",
387
+ " 'controlling our environment is important',\n",
388
+ " 'bo fang gu zheng de ge qu']"
389
+ ]
390
+ },
391
+ "execution_count": 18,
392
+ "metadata": {},
393
+ "output_type": "execute_result"
394
+ }
395
+ ],
396
+ "source": [
397
+ "labels[:batch_size]"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 19,
403
+ "id": "5d47692d",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "def calculate_cer(actual, hyp):\n",
408
+ " \"\"\"\n",
409
+ " Calculate CER using `python-Levenshtein`.\n",
410
+ " \"\"\"\n",
411
+ " import Levenshtein as Lev\n",
412
+ "\n",
413
+ " actual = actual.replace(' ', '')\n",
414
+ " hyp = hyp.replace(' ', '')\n",
415
+ " return Lev.distance(actual, hyp) / len(actual)\n",
416
+ "\n",
417
+ "\n",
418
+ "def calculate_wer(actual, hyp):\n",
419
+ " \"\"\"\n",
420
+ " Calculate WER using `python-Levenshtein`.\n",
421
+ " \"\"\"\n",
422
+ " import Levenshtein as Lev\n",
423
+ "\n",
424
+ " b = set(actual.split() + hyp.split())\n",
425
+ " word2char = dict(zip(b, range(len(b))))\n",
426
+ "\n",
427
+ " w1 = [chr(word2char[w]) for w in actual.split()]\n",
428
+ " w2 = [chr(word2char[w]) for w in hyp.split()]\n",
429
+ "\n",
430
+ " return Lev.distance(''.join(w1), ''.join(w2)) / len(actual.split())"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": 20,
436
+ "id": "c01ea2e4",
437
+ "metadata": {},
438
+ "outputs": [
439
+ {
440
+ "name": "stderr",
441
+ "output_type": "stream",
442
+ "text": [
443
+ "100%|██████████| 1240/1240 [04:27<00:00, 4.63it/s]\n"
444
+ ]
445
+ }
446
+ ],
447
+ "source": [
448
+ "from tqdm import tqdm\n",
449
+ "\n",
450
+ "wer, cer = [], []\n",
451
+ "wer_lm, cer_lm = [], []\n",
452
+ "\n",
453
+ "for i in tqdm(range(0, len(audio), batch_size)):\n",
454
+ " torch.cuda.empty_cache()\n",
455
+ " \n",
456
+ " batch_x = audio[i: i + batch_size]\n",
457
+ " batch_y = labels[i: i + batch_size]\n",
458
+ " normed_input_values, attentions = batching(batch_x)\n",
459
+ " inputs = torch.from_numpy(normed_input_values.astype(np.float32)).cuda()\n",
460
+ " attention_mask = torch.from_numpy(attentions).cuda()\n",
461
+ " o_pt = model(inputs, attention_mask = attention_mask)\n",
462
+ " o_pt = o_pt.logits.detach().cpu().numpy()\n",
463
+ " o_pt = log_softmax(o_pt, axis = -1)\n",
464
+ " pred_ids = np.argmax(o_pt, axis = -1)\n",
465
+ " pred = tokenizer.batch_decode(pred_ids)\n",
466
+ " for k in range(len(o_pt)):\n",
467
+ " out = decoder.decode_beams(o_pt[k], prune_history=True)\n",
468
+ " d_lm2, lm_state, timesteps, logit_score, lm_score = out[0]\n",
469
+ " \n",
470
+ " wer.append(calculate_wer(batch_y[k], pred[k]))\n",
471
+ " cer.append(calculate_cer(batch_y[k], pred[k]))\n",
472
+ " \n",
473
+ " wer_lm.append(calculate_wer(batch_y[k], d_lm2))\n",
474
+ " cer_lm.append(calculate_cer(batch_y[k], d_lm2))\n",
475
+ " \n",
476
+ " "
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": 21,
482
+ "id": "6c6ce8ef",
483
+ "metadata": {},
484
+ "outputs": [
485
+ {
486
+ "data": {
487
+ "text/plain": [
488
+ "(0.14251665517797765,\n",
489
+ " 0.05082346216269688,\n",
490
+ " 0.10380217528405207,\n",
491
+ " 0.042868860764264445)"
492
+ ]
493
+ },
494
+ "execution_count": 21,
495
+ "metadata": {},
496
+ "output_type": "execute_result"
497
+ }
498
+ ],
499
+ "source": [
500
+ "np.mean(wer), np.mean(cer), np.mean(wer_lm), np.mean(cer_lm)"
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "execution_count": 22,
506
+ "id": "cf53914e",
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "index_malay = [no for no, i in enumerate(audio) if 'malay-test/' in i]\n",
511
+ "index_singlish = [no for no, i in enumerate(audio) if 'singlish-test/' in i]\n",
512
+ "index_mandarin = [no for no, i in enumerate(audio) if 'mandarin-test/' in i]"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 23,
518
+ "id": "b1558987",
519
+ "metadata": {},
520
+ "outputs": [
521
+ {
522
+ "data": {
523
+ "text/plain": [
524
+ "(0.21723938552369926,\n",
525
+ " 0.05027226867066105,\n",
526
+ " 0.13593624603428525,\n",
527
+ " 0.03601546154013878)"
528
+ ]
529
+ },
530
+ "execution_count": 23,
531
+ "metadata": {},
532
+ "output_type": "execute_result"
533
+ }
534
+ ],
535
+ "source": [
536
+ "np.mean(np.array(wer)[index_malay]), np.mean(np.array(cer)[index_malay]), np.mean(np.array(wer_lm)[index_malay]), np.mean(np.array(cer_lm)[index_malay])"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "code",
541
+ "execution_count": 24,
542
+ "id": "f340cde7",
543
+ "metadata": {},
544
+ "outputs": [
545
+ {
546
+ "data": {
547
+ "text/plain": [
548
+ "(0.1331819722523124,\n",
549
+ " 0.05161275767676772,\n",
550
+ " 0.09859626021111582,\n",
551
+ " 0.04419848182804781)"
552
+ ]
553
+ },
554
+ "execution_count": 24,
555
+ "metadata": {},
556
+ "output_type": "execute_result"
557
+ }
558
+ ],
559
+ "source": [
560
+ "np.mean(np.array(wer)[index_singlish]), np.mean(np.array(cer)[index_singlish]), np.mean(np.array(wer_lm)[index_singlish]), np.mean(np.array(cer_lm)[index_singlish])"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": 25,
566
+ "id": "cbc2539f",
567
+ "metadata": {},
568
+ "outputs": [
569
+ {
570
+ "data": {
571
+ "text/plain": [
572
+ "(0.10382926344585862,\n",
573
+ " 0.04690941391603209,\n",
574
+ " 0.09411065398455744,\n",
575
+ " 0.0436573568867001)"
576
+ ]
577
+ },
578
+ "execution_count": 25,
579
+ "metadata": {},
580
+ "output_type": "execute_result"
581
+ }
582
+ ],
583
+ "source": [
584
+ "np.mean(np.array(wer)[index_mandarin]), np.mean(np.array(cer)[index_mandarin]), np.mean(np.array(wer_lm)[index_mandarin]), np.mean(np.array(cer_lm)[index_mandarin])"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 26,
590
+ "id": "4c543d0c",
591
+ "metadata": {},
592
+ "outputs": [
593
+ {
594
+ "name": "stderr",
595
+ "output_type": "stream",
596
+ "text": [
597
+ "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/utils/_deprecation.py:39: FutureWarning: Pass token='wav2vec2-xls-r-300m-mixed' as keyword args. From version 0.7 passing these as positional arguments will result in an error\n",
598
+ " warnings.warn(\n",
599
+ "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/hf_api.py:79: FutureWarning: `name` and `organization` input arguments are deprecated and will be removed in v0.7. Pass `repo_id` instead.\n",
600
+ " warnings.warn(\n",
601
+ "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/hf_api.py:596: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n",
602
+ " warnings.warn(\n",
603
+ "Cloning https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed into local empty directory.\n"
604
+ ]
605
+ },
606
+ {
607
+ "data": {
608
+ "application/vnd.jupyter.widget-view+json": {
609
+ "model_id": "10c88b815c83447b9d04f297f54fe1d9",
610
+ "version_major": 2,
611
+ "version_minor": 0
612
+ },
613
+ "text/plain": [
614
+ "Upload file pytorch_model.bin: 0%| | 4.00k/1.18G [00:00<?, ?B/s]"
615
+ ]
616
+ },
617
+ "metadata": {},
618
+ "output_type": "display_data"
619
+ },
620
+ {
621
+ "name": "stderr",
622
+ "output_type": "stream",
623
+ "text": [
624
+ "remote: Enforcing permissions... \n",
625
+ "remote: Allowed refs: all \n",
626
+ "To https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed\n",
627
+ " 33df917..7044629 main -> main\n",
628
+ "\n"
629
+ ]
630
+ },
631
+ {
632
+ "data": {
633
+ "text/plain": [
634
+ "'https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed/commit/7044629625df853dec50f463f6b794afe61d391f'"
635
+ ]
636
+ },
637
+ "execution_count": 26,
638
+ "metadata": {},
639
+ "output_type": "execute_result"
640
+ }
641
+ ],
642
+ "source": [
643
+ "model.push_to_hub('wav2vec2-xls-r-300m-mixed', organization='mesolitica')"
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "execution_count": 27,
649
+ "id": "05ec385e",
650
+ "metadata": {},
651
+ "outputs": [
652
+ {
653
+ "name": "stderr",
654
+ "output_type": "stream",
655
+ "text": [
656
+ "2022-06-01 09:29:07.148431: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
657
+ "2022-06-01 09:29:07.191068: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
658
+ "2022-06-01 09:29:07.192882: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
659
+ "2022-06-01 09:29:07.194967: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n",
660
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
661
+ "2022-06-01 09:29:07.196435: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
662
+ "2022-06-01 09:29:07.197071: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
663
+ "2022-06-01 09:29:07.197672: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
664
+ "2022-06-01 09:29:07.199082: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
665
+ "2022-06-01 09:29:07.199700: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
666
+ "2022-06-01 09:29:07.200318: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n",
667
+ "2022-06-01 09:29:07.201032: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n",
668
+ "2022-06-01 09:29:07.201159: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 17325 MB memory: -> device: 0, name: NVIDIA GeForce RTX 3090 Ti, pci bus id: 0000:01:00.0, compute capability: 8.6\n",
669
+ "\n",
670
+ "TFWav2Vec2ForCTC has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tine this model, you need a GPU or a TPU\n",
671
+ "2022-06-01 09:29:09.085113: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8100\n",
672
+ "2022-06-01 09:29:09.930887: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory\n",
673
+ "2022-06-01 09:29:10.708302: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.\n",
674
+ "All PyTorch model weights were used when initializing TFWav2Vec2ForCTC.\n",
675
+ "\n",
676
+ "All the weights of TFWav2Vec2ForCTC were initialized from the PyTorch model.\n",
677
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFWav2Vec2ForCTC for predictions without further training.\n"
678
+ ]
679
+ }
680
+ ],
681
+ "source": [
682
+ "model_tf = TFWav2Vec2ForCTC.from_pretrained(\n",
683
+ " './wav2vec2-mixed-v3/checkpoint-55000',\n",
684
+ " ctc_loss_reduction=\"mean\",\n",
685
+ " pad_token_id=tokenizer.pad_token_id,\n",
686
+ " vocab_size=len(tokenizer),\n",
687
+ " from_pt=True,\n",
688
+ ")"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": 28,
694
+ "id": "e0f3f749",
695
+ "metadata": {},
696
+ "outputs": [
697
+ {
698
+ "name": "stderr",
699
+ "output_type": "stream",
700
+ "text": [
701
+ "2022-06-01 09:29:38.885075: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 33554432 exceeds 10% of free system memory.\n"
702
+ ]
703
+ },
704
+ {
705
+ "data": {
706
+ "application/vnd.jupyter.widget-view+json": {
707
+ "model_id": "f70276237419473091977e7dcd4da591",
708
+ "version_major": 2,
709
+ "version_minor": 0
710
+ },
711
+ "text/plain": [
712
+ "Upload file tf_model.h5: 0%| | 4.00k/1.18G [00:00<?, ?B/s]"
713
+ ]
714
+ },
715
+ "metadata": {},
716
+ "output_type": "display_data"
717
+ },
718
+ {
719
+ "name": "stderr",
720
+ "output_type": "stream",
721
+ "text": [
722
+ "remote: Enforcing permissions... \n",
723
+ "remote: Allowed refs: all \n",
724
+ "To https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed\n",
725
+ " 7044629..86e9f45 main -> main\n",
726
+ "\n"
727
+ ]
728
+ },
729
+ {
730
+ "data": {
731
+ "text/plain": [
732
+ "'https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed/commit/86e9f450fa80b3f51175f04f694b35f342a6a09e'"
733
+ ]
734
+ },
735
+ "execution_count": 28,
736
+ "metadata": {},
737
+ "output_type": "execute_result"
738
+ }
739
+ ],
740
+ "source": [
741
+ "model_tf.push_to_hub('wav2vec2-xls-r-300m-mixed', organization='mesolitica')"
742
+ ]
743
+ },
744
+ {
745
+ "cell_type": "code",
746
+ "execution_count": 30,
747
+ "id": "999b8b28",
748
+ "metadata": {},
749
+ "outputs": [],
750
+ "source": [
751
+ "tokenizer = Wav2Vec2CTCTokenizer(\n",
752
+ " \"ctc-vocab.json\",\n",
753
+ " unk_token=\"[UNK]\",\n",
754
+ " pad_token=\"[PAD]\",\n",
755
+ " word_delimiter_token=\"|\",\n",
756
+ ")"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 31,
762
+ "id": "54a3285e",
763
+ "metadata": {},
764
+ "outputs": [],
765
+ "source": [
766
+ "feature_extractor = Wav2Vec2FeatureExtractor(\n",
767
+ " feature_size=1, sampling_rate=16_000, padding_value=0.0, do_normalize=True, return_attention_mask=True\n",
768
+ ")\n",
769
+ "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": 32,
775
+ "id": "b4bf1a21",
776
+ "metadata": {},
777
+ "outputs": [
778
+ {
779
+ "name": "stderr",
780
+ "output_type": "stream",
781
+ "text": [
782
+ "remote: Enforcing permissions... \n",
783
+ "remote: Allowed refs: all \n",
784
+ "To https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed\n",
785
+ " 86e9f45..adf6534 main -> main\n",
786
+ "\n"
787
+ ]
788
+ },
789
+ {
790
+ "data": {
791
+ "text/plain": [
792
+ "'https://huggingface.co/mesolitica/wav2vec2-xls-r-300m-mixed/commit/adf65347379e5902f7488753aef24d4e9d16daff'"
793
+ ]
794
+ },
795
+ "execution_count": 32,
796
+ "metadata": {},
797
+ "output_type": "execute_result"
798
+ }
799
+ ],
800
+ "source": [
801
+ "processor.push_to_hub('wav2vec2-xls-r-300m-mixed', organization='mesolitica')"
802
+ ]
803
+ }
804
+ ],
805
+ "metadata": {
806
+ "kernelspec": {
807
+ "display_name": "Python 3 (ipykernel)",
808
+ "language": "python",
809
+ "name": "python3"
810
+ },
811
+ "language_info": {
812
+ "codemirror_mode": {
813
+ "name": "ipython",
814
+ "version": 3
815
+ },
816
+ "file_extension": ".py",
817
+ "mimetype": "text/x-python",
818
+ "name": "python",
819
+ "nbconvert_exporter": "python",
820
+ "pygments_lexer": "ipython3",
821
+ "version": "3.8.10"
822
+ }
823
+ },
824
+ "nbformat": 4,
825
+ "nbformat_minor": 5
826
+ }