darshanmakwana committed on
Commit 2cddd11 · verified · 1 Parent(s): 03d39ff

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +7 -0
  2. prompting/.ipynb_checkpoints/dataset_generation-checkpoint.ipynb +332 -0
  3. prompting/.ipynb_checkpoints/generate_rare_words-checkpoint.py +62 -0
  4. prompting/.ipynb_checkpoints/generate_transcripts-checkpoint.py +60 -0
  5. prompting/.ipynb_checkpoints/get_error_word_count-checkpoint.py +106 -0
  6. prompting/.ipynb_checkpoints/model-checkpoint.py +53 -0
  7. prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json +3 -0
  8. prompting/.ipynb_checkpoints/train_lora-checkpoint.py +137 -0
  9. prompting/.ipynb_checkpoints/train_phi-checkpoint.py +86 -0
  10. prompting/.ipynb_checkpoints/training-checkpoint.ipynb +278 -0
  11. prompting/RepCodec/.gitignore +160 -0
  12. prompting/RepCodec/.ipynb_checkpoints/tinker-checkpoint.ipynb +267 -0
  13. prompting/RepCodec/LICENSE +428 -0
  14. prompting/RepCodec/README.md +273 -0
  15. prompting/RepCodec/dataloader/__init__.py +2 -0
  16. prompting/RepCodec/dataloader/collater.py +22 -0
  17. prompting/RepCodec/dataloader/dataset.py +90 -0
  18. prompting/RepCodec/examples/.ipynb_checkpoints/Untitled-checkpoint.ipynb +334 -0
  19. prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_audio-checkpoint.py +541 -0
  20. prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_feature_reader-checkpoint.py +87 -0
  21. prompting/RepCodec/examples/.ipynb_checkpoints/dump_feature-checkpoint.py +142 -0
  22. prompting/RepCodec/examples/.ipynb_checkpoints/feature_utils-checkpoint.py +70 -0
  23. prompting/RepCodec/examples/.ipynb_checkpoints/some_run-checkpoint.py +66 -0
  24. prompting/RepCodec/examples/Untitled.ipynb +214 -0
  25. prompting/RepCodec/examples/__pycache__/data2vec_audio.cpython-38.pyc +0 -0
  26. prompting/RepCodec/examples/__pycache__/data2vec_feature_reader.cpython-38.pyc +0 -0
  27. prompting/RepCodec/examples/__pycache__/feature_utils.cpython-38.pyc +0 -0
  28. prompting/RepCodec/examples/__pycache__/hubert_feature_reader.cpython-38.pyc +0 -0
  29. prompting/RepCodec/examples/__pycache__/tokenize.cpython-38.pyc +0 -0
  30. prompting/RepCodec/examples/data2vec_audio.py +541 -0
  31. prompting/RepCodec/examples/data2vec_feature_reader.py +87 -0
  32. prompting/RepCodec/examples/dump_feature.py +142 -0
  33. prompting/RepCodec/examples/feature_utils.py +70 -0
  34. prompting/RepCodec/examples/hubert_feature_reader.py +64 -0
  35. prompting/RepCodec/examples/some_run.py +66 -0
  36. prompting/RepCodec/examples/tkns/test.clean.npz +3 -0
  37. prompting/RepCodec/examples/tkns/test.other.npz +3 -0
  38. prompting/RepCodec/examples/tkns/train.clean.100.npz +3 -0
  39. prompting/RepCodec/examples/tkns/train.clean.360.npz +3 -0
  40. prompting/RepCodec/examples/tkns/train.other.500.npz +3 -0
  41. prompting/RepCodec/examples/tkns/validation.clean.npz +3 -0
  42. prompting/RepCodec/examples/tkns/validation.other.npz +3 -0
  43. prompting/RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens +0 -0
  44. prompting/RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens +0 -0
  45. prompting/RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens +0 -0
  46. prompting/RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens +0 -0
  47. prompting/RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens +0 -0
  48. prompting/RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens +0 -0
  49. prompting/RepCodec/examples/whisper_feature_reader.py +110 -0
  50. prompting/RepCodec/examples/whisper_model.py +58 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json filter=lfs diff=lfs merge=lfs -text
37
+ prompting/train_clean_100_error.json filter=lfs diff=lfs merge=lfs -text
38
+ prompting/train_data/train.clean.360.json filter=lfs diff=lfs merge=lfs -text
39
+ prompting/train_data/train.other.500.json filter=lfs diff=lfs merge=lfs -text
40
+ prompting/transcripts/train.clean.360.txt filter=lfs diff=lfs merge=lfs -text
41
+ prompting/transcripts/train.other.500.txt filter=lfs diff=lfs merge=lfs -text
42
+ prompting/wandb/run-20240615_114519-wfpe2teb/run-wfpe2teb.wandb filter=lfs diff=lfs merge=lfs -text
prompting/.ipynb_checkpoints/dataset_generation-checkpoint.ipynb ADDED
@@ -0,0 +1,332 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "59efc3d7-a57f-43cc-8aa3-34bb57de0251",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Librispeech"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "327243de-fd0f-449d-998a-63282a1c67a2",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from datasets import load_dataset\n",
19
+ "\n",
20
+ "cache_dir = \"./../cache\"\n",
21
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "id": "456889e1-f8cc-440b-bf6b-f6fbfafc367d",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "from torchmetrics import WordErrorRate, CharErrorRate\n",
32
+ "from edit_distance import SequenceMatcher\n",
33
+ "from tqdm import tqdm\n",
34
+ "import jiwer\n",
35
+ "\n",
36
+ "def correct_text(text):\n",
37
+ " transforms = jiwer.Compose(\n",
38
+ " [\n",
39
+ " jiwer.ExpandCommonEnglishContractions(),\n",
40
+ " jiwer.ToLowerCase(),\n",
41
+ " jiwer.RemoveMultipleSpaces(),\n",
42
+ " jiwer.Strip(),\n",
43
+ " jiwer.RemovePunctuation(),\n",
44
+ " jiwer.ReduceToListOfListOfWords(),\n",
45
+ " ]\n",
46
+ " )\n",
47
+ " return transforms(text)\n",
48
+ "\n",
49
+ "def align_gt_asr(gt, asr):\n",
50
+ "\n",
51
+ " sm = SequenceMatcher(a=gt, b=asr)\n",
52
+ " best_path = []\n",
53
+ " opcodes = sm.get_opcodes()\n",
54
+ "\n",
55
+ " for tag, i1, i2, j1, j2 in opcodes:\n",
56
+ "\n",
57
+ " if tag == \"delete\":\n",
58
+ " for i in range(i1, i2):\n",
59
+ " best_path.append([gt[i], \"\"])\n",
60
+ "\n",
61
+ " if tag == \"replace\" or tag == \"equal\":\n",
62
+ " for i, j in zip(range(i1, i2), range(j1, j2)):\n",
63
+ " best_path.append([gt[i], asr[j]])\n",
64
+ "\n",
65
+ " if tag == \"insert\":\n",
66
+ " for j in range(j1, j2):\n",
67
+ " best_path.append([\"\", asr[j]])\n",
68
+ "\n",
69
+ " return best_path\n",
70
+ "\n",
71
+ "import string\n",
72
+ "def process(text):\n",
73
+ "\n",
74
+ " # Lower case every letter\n",
75
+ " text = text.lower()\n",
76
+ "\n",
77
+ " # Remove punctuation\n",
78
+ " punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
79
+ " translation_table = str.maketrans('', '', punctuation_to_remove)\n",
80
+ " text = text.translate(translation_table)\n",
81
+ "\n",
82
+ " # Remove whitespaces from front and behind\n",
83
+ " while text[0] == ' ' or text[-1] == ' ':\n",
84
+ " if text[0] == ' ':\n",
85
+ " text = text[1:]\n",
86
+ " if text[-1] == ' ':\n",
87
+ " text = text[:-1]\n",
88
+ " \n",
89
+ " return text"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "id": "3bc907b0-2ebe-46ac-b6a1-02919e69af88",
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "from tqdm import tqdm\n",
100
+ "\n",
101
+ "gens = []\n",
102
+ "texts = []\n",
103
+ "\n",
104
+ "unmatches = []\n",
105
+ "\n",
106
+ "for split in [\"validation.clean\"]:\n",
107
+ " data = dataset[split]\n",
108
+ " with open(f\"./transcripts/{split}.txt\", \"r\") as f:\n",
109
+ " for idx, line in enumerate(tqdm(f)):\n",
110
+ " preds = process(line.rstrip())\n",
111
+ " text = data[idx][\"text\"]\n",
112
+ "\n",
113
+ " path = align_gt_asr(correct_text(text)[0], correct_text(preds)[0])\n",
114
+ " un = 0\n",
115
+ " for a, b in path:\n",
116
+ " if a!=b:\n",
117
+ " un+=1\n",
118
+ " \n",
119
+ " unmatches.append(un)\n",
120
+ "\n",
121
+ " # texts.append(process(text))\n",
122
+ " # gens.append(preds)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "cac10009-1b47-4e2f-a232-f71b23ee983e",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "import numpy as np\n",
133
+ "\n",
134
+ "np.count_nonzero(unmatches)"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "afdc9f74-c2cf-4d52-8563-1bd827f6d900",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "def align_gt_asr(gt, asr):\n",
145
+ "\n",
146
+ " sm = SequenceMatcher(a=gt, b=asr)\n",
147
+ " best_path = []\n",
148
+ " opcodes = sm.get_opcodes()\n",
149
+ " \n",
150
+ " for tag, i1, i2, j1, j2 in opcodes:\n",
151
+ " \n",
152
+ " if tag == \"delete\":\n",
153
+ " for i in range(i1, i2):\n",
154
+ " best_path.append([gt[i], \"\"])\n",
155
+ " \n",
156
+ " if tag == \"replace\" or tag == \"equal\":\n",
157
+ " for i, j in zip(range(i1, i2), range(j1, j2)):\n",
158
+ " best_path.append([gt[i], asr[j]])\n",
159
+ " \n",
160
+ " if tag == \"insert\":\n",
161
+ " for j in range(j1, j2):\n",
162
+ " best_path.append([\"\", asr[j]])\n",
163
+ " \n",
164
+ " return best_path\n",
165
+ "\n",
166
+ "# align_gt_asr(correct_text(text), correct_text(preds))"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "3cdfd3d9-6c22-4ccd-a22b-df8e79fc20b0",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "correct_text(text)"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "2c33f46a-f3dd-435f-81e3-e7b10ae03470",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "correct_text([\"hello\", \"hey\"])"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "2cfab12a-2b2c-4c00-bd80-ab571c012f29",
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "## Transcript of whisper small WER\n",
197
+ "## validation.clean 4.62\n",
198
+ "## validation.other 8.11\n",
199
+ "## test.clean 4.22\n",
200
+ "## test.other 8.56\n"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "24cb2d8f-9ce2-42f2-bbf0-522106078aac",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
211
+ "from datasets import load_dataset\n",
212
+ "import numpy as np\n",
213
+ "import torch\n",
214
+ "\n",
215
+ "device = \"cuda:0\"\n",
216
+ "dtype = torch.float16\n",
217
+ "cache_dir = \"./../cache\"\n",
218
+ "model_id = \"openai/whisper-small\"\n",
219
+ "\n",
220
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", cache_dir=cache_dir)\n",
221
+ "model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=cache_dir, attn_implementation=\"sdpa\").to(device).to(dtype).eval()"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "markdown",
226
+ "id": "d5fa6f8e-43f2-44ce-b719-2d8fde4067ce",
227
+ "metadata": {},
228
+ "source": [
229
+ "## Biasing List"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "3cc0f934-d208-445e-aecd-31df73be6986",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "import sys, os\n",
240
+ "import json\n",
241
+ "import string\n",
242
+ "from tqdm import tqdm\n",
243
+ "def process(text):\n",
244
+ "\n",
245
+ " # Lower case every letter\n",
246
+ " text = text.lower()\n",
247
+ "\n",
248
+ " # Remove punctuation\n",
249
+ " punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
250
+ " translation_table = str.maketrans('', '', punctuation_to_remove)\n",
251
+ " text = text.translate(translation_table)\n",
252
+ "\n",
253
+ " # Remove whitespaces from front and behind\n",
254
+ " while text[0] == ' ' or text[-1] == ' ':\n",
255
+ " if text[0] == ' ':\n",
256
+ " text = text[1:]\n",
257
+ " if text[-1] == ' ':\n",
258
+ " text = text[:-1]\n",
259
+ " \n",
260
+ " return text\n",
261
+ "\n",
262
+ "split_name = \"train.clean.100\"\n",
263
+ "\n",
264
+ "with open(\"./blist/all_rare_words.txt\") as fin:\n",
265
+ " rarewords = [process(word.strip()) for word in fin]\n",
266
+ "\n",
267
+ "with open(f\"./transcripts/{split_name}.txt\") as fin:\n",
268
+ " transcripts = [line.strip() for line in fin]\n",
269
+ "\n",
270
+ "from datasets import load_dataset\n",
271
+ "\n",
272
+ "cache_dir = \"./../cache\"\n",
273
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
274
+ "\n",
275
+ "train_data = []\n",
276
+ "\n",
277
+ "pbar = tqdm(dataset[split_name])\n",
278
+ "for idx, sample in enumerate(pbar):\n",
279
+ " \n",
280
+ " text = process(sample[\"text\"])\n",
281
+ " transcript = transcripts[idx]\n",
282
+ " \n",
283
+ " bwords = []\n",
284
+ " for word in text.split():\n",
285
+ " if word in rarewords and word not in transcript:\n",
286
+ " bwords.append(word)\n",
287
+ " \n",
288
+ " if len(bwords) > 0:\n",
289
+ " train_data.append({\n",
290
+ " \"split\": split_name,\n",
291
+ " \"idx\": idx,\n",
292
+ " \"text\": text,\n",
293
+ " \"transcript\": transcript,\n",
294
+ " \"b_words\": bwords,\n",
295
+ " })\n",
296
+ " pbar.set_description(f\"Len of train data: {len(train_data)}\")"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "id": "cac9a909-e1ce-426a-bda3-b65ba3985d06",
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": [
306
+ "with open(f\"./train_data/{split_name}.json\", \"w\") as fout:\n",
307
+ " json.dump(train_data, fout, indent=4)"
308
+ ]
309
+ }
310
+ ],
311
+ "metadata": {
312
+ "kernelspec": {
313
+ "display_name": "Python 3 (ipykernel)",
314
+ "language": "python",
315
+ "name": "python3"
316
+ },
317
+ "language_info": {
318
+ "codemirror_mode": {
319
+ "name": "ipython",
320
+ "version": 3
321
+ },
322
+ "file_extension": ".py",
323
+ "mimetype": "text/x-python",
324
+ "name": "python",
325
+ "nbconvert_exporter": "python",
326
+ "pygments_lexer": "ipython3",
327
+ "version": "3.8.10"
328
+ }
329
+ },
330
+ "nbformat": 4,
331
+ "nbformat_minor": 5
332
+ }
prompting/.ipynb_checkpoints/generate_rare_words-checkpoint.py ADDED
@@ -0,0 +1,62 @@
1
+ import sys, os
2
+ import json
3
+ import string
4
+ from tqdm import tqdm
5
+
6
+ def process(text):
7
+
8
+ # Lower case every letter
9
+ text = text.lower()
10
+
11
+ # Remove punctuation
12
+ punctuation_to_remove = string.punctuation.replace("'", "")
13
+ translation_table = str.maketrans('', '', punctuation_to_remove)
14
+ text = text.translate(translation_table)
15
+
16
+ # Remove whitespaces from front and behind
17
+ while text[0] == ' ' or text[-1] == ' ':
18
+ if text[0] == ' ':
19
+ text = text[1:]
20
+ if text[-1] == ' ':
21
+ text = text[:-1]
22
+
23
+ return text
24
+
25
+ split_name = "train.other.500"
26
+
27
+ with open("./blist/all_rare_words.txt") as fin:
28
+ rarewords = [process(word.strip()) for word in fin]
29
+
30
+ with open(f"./transcripts/{split_name}.txt") as fin:
31
+ transcripts = [line.strip() for line in fin]
32
+
33
+ from datasets import load_dataset
34
+
35
+ cache_dir = "./../cache"
36
+ dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
37
+
38
+ train_data = []
39
+
40
+ pbar = tqdm(dataset[split_name])
41
+ for idx, sample in enumerate(pbar):
42
+
43
+ text = process(sample["text"])
44
+ transcript = transcripts[idx]
45
+
46
+ bwords = []
47
+ for word in text.split():
48
+ if word in rarewords and word not in transcript:
49
+ bwords.append(word)
50
+
51
+ if len(bwords) > 0:
52
+ train_data.append({
53
+ "split": split_name,
54
+ "idx": idx,
55
+ "text": text,
56
+ "transcript": transcript,
57
+ "b_words": bwords,
58
+ })
59
+ pbar.set_description(f"Len of train data: {len(train_data)}")
60
+
61
+ with open(f"./train_data/{split_name}.json", "w") as fout:
62
+ json.dump(train_data, fout, indent=4)
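The script above emits one JSON record per utterance whose reference text contains rare words missing from the Whisper transcript. A minimal sketch of how that output could be inspected and turned into a biasing prompt downstream (the file path and field names come from the script above; the prompt format itself is only an illustrative assumption, not part of this commit):

    import json

    # Load the biasing data written by the script above (path assumed from split_name).
    with open("./train_data/train.other.500.json") as fin:
        train_data = json.load(fin)

    # Each record carries the split name, utterance index, reference text,
    # Whisper transcript, and the rare words the transcript failed to produce.
    sample = train_data[0]
    print(sample["idx"], sample["b_words"])

    # Hypothetical biasing prompt built from the missed rare words of one utterance.
    biasing_prompt = "Pay attention to the following words: " + ", ".join(sample["b_words"])
    print(biasing_prompt)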
prompting/.ipynb_checkpoints/generate_transcripts-checkpoint.py ADDED
@@ -0,0 +1,60 @@
1
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
2
+ from datasets import load_dataset
3
+ from tqdm import tqdm
4
+ from math import ceil
5
+ from model import generate, flush
6
+ import numpy as np
7
+ import os
8
+ import torch
9
+ import string
10
+
11
+ def process(text):
12
+
13
+ # Lower case every letter
14
+ text = text.lower()
15
+
16
+ # Remove punctuation
17
+ punctuation_to_remove = string.punctuation.replace("'", "")
18
+ translation_table = str.maketrans('', '', punctuation_to_remove)
19
+ text = text.translate(translation_table)
20
+
21
+ # Remove whitespaces from front and behind
22
+ while text[0] == ' ' or text[-1] == ' ':
23
+ if text[0] == ' ':
24
+ text = text[1:]
25
+ if text[-1] == ' ':
26
+ text = text[:-1]
27
+
28
+ return text
29
+
30
+ device = "cuda:0"
31
+ dtype = torch.float16
32
+ cache_dir = "./../cache"
33
+ model_id = "openai/whisper-small"
34
+ batch_size = 250
35
+ out_dir = "./transcripts"
36
+
37
+ dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
38
+
39
+ processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
40
+ model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=cache_dir, attn_implementation="sdpa").to(device).to(dtype).eval()
41
+
42
+ for split in dataset.keys():
43
+
44
+ data = dataset[split]
45
+
46
+ os.makedirs(out_dir, exist_ok=True)
47
+
48
+ for idx in tqdm(range(ceil(len(data)/batch_size))):
49
+
50
+ audios = data[idx * batch_size: (idx + 1) * batch_size]["audio"]
51
+
52
+ arrays = [a["array"] for a in audios]
53
+
54
+ transcripts = generate(arrays, model, processor)
55
+
56
+ with open(os.path.join(out_dir, f"{split}.txt"), "a") as disk:
57
+ disk.writelines([process(text) + "\n" for text in transcripts])
58
+ disk.close()
59
+
60
+ flush()
prompting/.ipynb_checkpoints/get_error_word_count-checkpoint.py ADDED
@@ -0,0 +1,106 @@
1
+ import sys, os
2
+ # from normalizers.english import EnglishTextNormalizer
3
+
4
+ error_words_freqs = {}
5
+ infile = sys.argv[1]
6
+ # setname = sys.argv[2]
7
+ insert_error = 0
8
+ insert_rare = 0
9
+ freqlist_test = {}
10
+ # eng_norm = EnglishTextNormalizer()
11
+
12
+ freqlist = {}
13
+ with open("./blist/word_freq.txt") as fin:
14
+ for line in fin:
15
+ word, freq = line.split()
16
+ freqlist[word.upper()] = int(freq)
17
+
18
+ with open("./blist/all_rare_words.txt") as fin:
19
+ rareset = set()
20
+ for line in fin:
21
+ rareset.add(line.strip().upper())
22
+
23
+ project_set = set()
24
+ with open(infile) as fin:
25
+ lines = fin.readlines()
26
+ for i, line in enumerate(lines):
27
+ if line.startswith('id:'):
28
+ project = line.strip(')\n').split('-')[-3:]
29
+ project = '-'.join(project)
30
+ if "REF:" in line:
31
+ nextline = lines[i+1].split()
32
+ for j, word in enumerate(line.split()):
33
+ if '*' in word:
34
+ insert_error += 1
35
+ if nextline[j].upper() in rareset:
36
+ insert_rare += 1
37
+ line = line.replace('*', '')
38
+ line = line.replace('%BCACK', '')
39
+ for word in line.split()[1:]:
40
+ if not word.startswith('('):
41
+ if word.upper() not in freqlist_test:
42
+ freqlist_test[word.upper()] = 1
43
+ else:
44
+ freqlist_test[word.upper()] += 1
45
+
46
+ if word != word.lower() and word.upper() in error_words_freqs:
47
+ error_words_freqs[word.upper()] += 1
48
+ elif word != word.lower() and word.upper() not in error_words_freqs:
49
+ error_words_freqs[word.upper()] = 1
50
+ elif word == word.lower() and word.upper() not in error_words_freqs:
51
+ error_words_freqs[word.upper()] = 0
52
+ print(len(error_words_freqs.keys()))
53
+ print(insert_rare)
54
+
55
+ commonwords = []
56
+ rarewords = []
57
+ oovwords = []
58
+ common_freq = 0
59
+ rare_freq = 0
60
+ oov_freq = 0
61
+ common_error = 0
62
+ rare_error = 0
63
+ oov_error = 0
64
+ partial_error = 0
65
+ partial_freq = 0
66
+ very_common_error = 0
67
+ very_common_words = 0
68
+ words_error_freq = {}
69
+ words_total_freq = {}
70
+ for word, error in error_words_freqs.items():
71
+ if word in rareset:
72
+ rarewords.append(word)
73
+ rare_freq += freqlist_test[word]
74
+ rare_error += error
75
+ elif word not in freqlist:
76
+ oovwords.append(word)
77
+ oov_freq += freqlist_test[word] if word in freqlist_test else 1
78
+ oov_error += error
79
+ else:
80
+ if freqlist[word] <= 10 and freqlist[word] >= 3:
81
+ if freqlist[word] not in words_error_freq:
82
+ words_error_freq[freqlist[word]] = error
83
+ words_total_freq[freqlist[word]] = freqlist_test[word]
84
+ else:
85
+ words_error_freq[freqlist[word]] += error
86
+ words_total_freq[freqlist[word]] += freqlist_test[word]
87
+ if freqlist[word] <= 10 and freqlist[word] >= 3:
88
+ very_common_error += error
89
+ very_common_words += freqlist_test[word]
90
+ commonwords.append(word)
91
+ common_freq += freqlist_test[word]
92
+ common_error += error
93
+
94
+ total_words = common_freq + rare_freq + oov_freq
95
+ total_errors = common_error+rare_error+oov_error + insert_error
96
+ WER = total_errors / total_words
97
+ print('='*89)
98
+ print('Common words error freq: {} / {} = {}'.format(common_error, common_freq, common_error/common_freq))
99
+ print('Rare words error freq: {} / {} = {}'.format(rare_error+insert_rare, rare_freq, (rare_error + insert_rare)/rare_freq))
100
+ print('OOV words error freq: {} / {} = {}'.format(oov_error, oov_freq, oov_error/max(oov_freq, 1)))
101
+ print('WER estimate: {} / {} = {}'.format(total_errors, total_words, WER))
102
+ # print('Partial word count: {} / {}'.format(partial_error, partial_freq))
103
+ print('Insert error: {} / {} = {}'.format(insert_error - insert_rare, total_words, (insert_error - insert_rare)/total_words))
104
+ print('Insertion + OOV error {}'.format((insert_error + oov_error - insert_rare) / total_words))
105
+ # print('Very common words error freq: {} / {} = {}'.format(very_common_error, very_common_words, very_common_error/very_common_words))
106
+ print('='*89)
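For context, this script appears to parse sclite/NIST-style per-utterance alignment output: an "id:" line names the utterance, the "REF:" line marks errors in upper case and insertions with runs of '*', and the hypothesis follows on the next line. A hypothetical fragment in that spirit, plus the counting it would trigger (illustrative assumption only; no such file ships in this commit):

    # Hypothetical sclite-style alignment fragment (not part of this commit).
    sample = """id: (unknown-1089-134686-0000)
    REF:  he hoped there would be STEW   ** for dinner
    HYP:  he hoped there would be stewed TO for dinner
    """

    lines = sample.splitlines()
    for i, line in enumerate(lines):
        if line.strip().startswith("id:"):
            print("utterance:", line.strip())
        if "REF:" in line:
            hyp = lines[i + 1].split()
            # Words kept in upper case on the REF line are substitution/deletion errors;
            # '*' padding marks an insertion in the hypothesis.
            errors = [w for w in line.split()[1:] if "*" not in w and w != w.lower()]
            insertions = sum("*" in w for w in line.split())
            print("error words:", errors, "| insertions:", insertions, "| hyp:", " ".join(hyp[1:]))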
prompting/.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import numpy as np
3
+ import gc
4
+ from typing import List
5
+
6
+ def flush():
7
+ gc.collect()
8
+ torch.cuda.empty_cache()
9
+
10
+ @torch.no_grad()
11
+ def generate(arrays, model, processor, max_new_tokens = 444) -> List[str]:
12
+ """
13
+ arrays: a list of audio arrays
14
+ model: the whisper model to use
15
+ processor: the wisper processor to use
16
+ """
17
+
18
+ inputs = processor(arrays, sampling_rate=16000, return_tensors="pt").input_features
19
+
20
+ # Cache the encoder hidden states
21
+ encoder_hidden_states = model.model.encoder(inputs.to(model.device).to(model.dtype)).last_hidden_state
22
+
23
+ decoder_ids = torch.tensor([[50258, 50259, 50359, 50363] for _ in range(inputs.shape[0])]).to(model.device)
24
+
25
+ # Tensor to keep track of which samples have reached the end of text token
26
+ inference_continues = torch.ones(inputs.shape[0], dtype=torch.bool).to(model.device)
27
+
28
+ while inference_continues.any() and max_new_tokens > 0:
29
+
30
+ last_hidden_state = model.model.decoder(input_ids = decoder_ids, encoder_hidden_states = encoder_hidden_states).last_hidden_state
31
+
32
+ # A small optimization to only project the hidden states of the last token
33
+ last_token_hidden_state = last_hidden_state[:, -1, :]
34
+ logits = model.proj_out(last_token_hidden_state)
35
+
36
+ # Greedy Sampling
37
+ probas = torch.softmax(logits, dim=-1)
38
+ pred_idx = torch.argmax(probas, dim=-1, keepdim=True)
39
+
40
+ # Fill the samples where inference has stopped with <|end of text|> token
41
+ pred_idx[~inference_continues, :] = 50257
42
+
43
+ decoder_ids = torch.cat((decoder_ids, pred_idx), dim=-1)
44
+
45
+ # Check if any sample has reached the end of text token
46
+ reached_end_of_text = (pred_idx.squeeze(-1) == 50257)
47
+ inference_continues &= ~reached_end_of_text
48
+
49
+ max_new_tokens -= 1
50
+
51
+ transcripts = processor.batch_decode(decoder_ids, skip_special_tokens=True)
52
+
53
+ return transcripts
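A minimal sketch of how generate() and flush() above might be driven for one batch. For reference, the hard-coded decoder IDs 50258/50259/50359/50363 correspond to Whisper's <|startoftranscript|>, <|en|>, <|transcribe|> and <|notimestamps|> tokens, and 50257 to <|endoftext|>, in the multilingual whisper-small vocabulary; the silent dummy audio below is only a placeholder (assumption, not part of the commit):

    import numpy as np
    import torch
    from transformers import WhisperProcessor, WhisperForConditionalGeneration
    from model import generate, flush  # the helpers defined above

    device, dtype = "cuda:0", torch.float16
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    model = (
        WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
        .to(device).to(dtype).eval()
    )

    # Two 2-second clips of silence at 16 kHz stand in for real audio arrays.
    arrays = [np.zeros(32000, dtype=np.float32) for _ in range(2)]

    transcripts = generate(arrays, model, processor)
    print(transcripts)

    flush()  # release cached CUDA memory between batches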
prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71787a0de524627122b64b419a50d551ca4a9ddb5ac3888c2e54990850cde26e
3
+ size 12684279
prompting/.ipynb_checkpoints/train_lora-checkpoint.py ADDED
@@ -0,0 +1,137 @@
1
+ from datasets import load_dataset
2
+ from peft import LoraConfig, prepare_model_for_kbit_training, TaskType
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ BitsAndBytesConfig,
7
+ TrainingArguments,
8
+ set_seed,
9
+ pipeline
10
+ )
11
+ from trl import SFTTrainer, SFTConfig
12
+ from random import randrange
13
+ import torch
14
+ import wandb
15
+
16
+ cache_dir = "./../cache"
17
+ model_id = "microsoft/Phi-3-mini-4k-instruct"
18
+ new_model = "python-phi-3-mini-4k-instruct"
19
+ username = "ellipticaloranges"
20
+ device_map = {"": 0}
21
+ hf_model_repo = username + "/" + new_model
22
+
23
+ ## ------------------------LoRA Configs------------------------------------------------------
24
+
25
+ lora_r = 16
26
+ lora_alpha = 16
27
+ lora_dropout = 0.05
28
+ target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]
29
+
30
+ ## ------------------------------------------------------------------------------------------
31
+
32
+ dataset_name = "flytech/python-codes-25k"
33
+ dataset_split= "train"
34
+
35
+ dataset = load_dataset(dataset_name, split=dataset_split, cache_dir=cache_dir)
36
+ print(f"Dataset size: {len(dataset)}")
37
+
38
+
39
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True, add_eos_token=True, use_fast=True)
40
+ # The padding token is set to the unknown token.
41
+ tokenizer.pad_token = tokenizer.unk_token
42
+ # The ID of the padding token is set to the ID of the unknown token.
43
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
44
+ # ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input.
45
+ tokenizer.padding_side = 'left'
46
+
47
+
48
+ def create_message_column(row):
49
+ messages = []
50
+ user = {
51
+ "content": f"{row['instruction']}",
52
+ "role": "user"
53
+ }
54
+ messages.append(user)
55
+ assistant = {
56
+ "content": f"{row['input']}\n{row['output']}",
57
+ "role": "assistant"
58
+ }
59
+ messages.append(assistant)
60
+ return {"messages": messages}
61
+
62
+ def format_dataset_chatml(row):
63
+ return {"text": tokenizer.apply_chat_template(row["messages"], add_generation_prompt=False, tokenize=False)}
64
+
65
+ dataset_chatml = dataset.map(create_message_column)
66
+ dataset_chatml = dataset_chatml.map(format_dataset_chatml)
67
+ dataset_chatml = dataset_chatml.train_test_split(test_size=0.05, seed=1234)
68
+
69
+ # print("Max Seq Length", max(map(lambda x: len(tokenizer.encode(x["text"])), dataset)))
70
+
71
+ if torch.cuda.is_bf16_supported():
72
+ compute_dtype = torch.bfloat16
73
+ attn_implementation = 'flash_attention_2'
74
+ else:
75
+ compute_dtype = torch.float16
76
+ attn_implementation = 'sdpa'
77
+
78
+ print(f"Using {compute_dtype} with {attn_implementation} implementation")
79
+
80
+ model = AutoModelForCausalLM.from_pretrained(
81
+ model_id,
82
+ torch_dtype = compute_dtype,
83
+ trust_remote_code = True,
84
+ device_map = device_map,
85
+ attn_implementation = attn_implementation,
86
+ cache_dir = cache_dir
87
+ )
88
+
89
+ args = SFTConfig(
90
+ output_dir="./phi-3-mini-LoRA",
91
+ eval_strategy="steps",
92
+ do_eval=True,
93
+ optim="adamw_torch",
94
+ per_device_train_batch_size=8,
95
+ gradient_accumulation_steps=4,
96
+ per_device_eval_batch_size=8,
97
+ log_level="debug",
98
+ save_strategy="epoch",
99
+ logging_steps=10,
100
+ learning_rate=1e-4,
101
+ fp16 = not torch.cuda.is_bf16_supported(),
102
+ bf16 = torch.cuda.is_bf16_supported(),
103
+ eval_steps=100,
104
+ dataset_text_field="text",
105
+ max_seq_length=512,
106
+ num_train_epochs=3,
107
+ warmup_ratio=0.1,
108
+ lr_scheduler_type="linear",
109
+ report_to="wandb",
110
+ seed=42,
111
+ )
112
+
113
+ peft_config = LoraConfig(
114
+ r=lora_r,
115
+ lora_alpha=lora_alpha,
116
+ lora_dropout=lora_dropout,
117
+ task_type=TaskType.CAUSAL_LM,
118
+ target_modules=target_modules,
119
+ )
120
+
121
+ model.add_adapter(peft_config)
122
+
123
+ wandb.init(project = "Phi 3", name = "python-phi-3-lora")
124
+
125
+ trainer = SFTTrainer(
126
+ model=model,
127
+ train_dataset=dataset_chatml['train'],
128
+ eval_dataset=dataset_chatml['test'],
129
+ peft_config=peft_config,
130
+ tokenizer=tokenizer,
131
+ args=args,
132
+ )
133
+
134
+ trainer.train()
135
+
136
+ # Save the model to the `output_dir` after training
137
+ model.save_pretrained("./out/")
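After training, the LoRA weights saved to ./out/ could be reattached to the base model for inference roughly as follows, using the standard PEFT API (a hedged sketch: it assumes ./out/ holds the LoRA adapter, and the prompt and generation settings are illustrative, not from this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base_id = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    )

    # Attach the saved LoRA adapter and merge it for plain inference.
    model = PeftModel.from_pretrained(base, "./out/")
    model = model.merge_and_unload()

    messages = [{"role": "user", "content": "Write a bubble sort in Python."}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output = model.generate(inputs, max_new_tokens=200)
    print(tokenizer.decode(output[0], skip_special_tokens=True))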
prompting/.ipynb_checkpoints/train_phi-checkpoint.py ADDED
@@ -0,0 +1,86 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
2
+ from huggingface_hub import ModelCard, ModelCardData, HfApi
3
+ from datasets import load_dataset
4
+ from jinja2 import Template
5
+ from trl import SFTTrainer
6
+ import yaml
7
+ import torch
8
+
9
+ # Model Configs
10
+ MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
11
+ NEW_MODEL_NAME = "opus-phi-3-mini-4k-instruct"
12
+ CACHE_DIR = "./../cache"
13
+
14
+ # Dataset Configs
15
+ DATASET_NAME = ""
16
+ SPLIT = "train"
17
+
18
+ # the maximum length of the sequences that the model will handle
19
+ MAX_SEQ_LENGTH = 4096
20
+ num_train_epochs = 1
21
+ license = "apache-2.0"
22
+ username = "darshanmakwana412"
23
+ learning_rate = 1.41e-5
24
+ per_device_train_batch_size = 4
25
+ gradient_accumulation_steps = 1
26
+
27
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)
28
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)
29
+ dataset = load_dataset(DATASET_NAME, split=SPLIT)
30
+
31
+ # EOS Token is used to mark the end of a sentence
32
+ EOS_TOKEN=tokenizer.eos_token_id
33
+
34
+ def formatting_prompts_func(examples):
35
+ # Extract the conversations from the examples.
36
+ convos = examples["conversations"]
37
+ # Initialize an empty list to store the formatted texts.
38
+ texts = []
39
+ # Define a dictionary to map the 'from' field in the conversation to a prefix.
40
+ mapper = {"system": "system\n", "human": "\nuser\n", "gpt": "\nassistant\n"}
41
+ # Define a dictionary to map the 'from' field in the conversation to a suffix.
42
+ end_mapper = {"system": "", "human": "", "gpt": ""}
43
+ # Iterate over each conversation.
44
+ for convo in convos:
45
+ # Format the conversation by joining each turn with its corresponding prefix and suffix.
46
+ # Append the EOS token to the end of the conversation.
47
+ text = "".join(f"{mapper[(turn := x['from'])]} {x['value']}\n{end_mapper[turn]}" for x in convo)
48
+ texts.append(f"{text}{EOS_TOKEN}")
49
+ # Return the formatted texts.
50
+ return {"text": texts}
51
+
52
+ dataset = dataset.map(formatting_prompts_func, batched=True)
53
+
54
+ args = TrainingArguments(
55
+ evaluation_strategy="steps",
56
+ per_device_train_batch_size=per_device_train_batch_size,
57
+ gradient_accumulation_steps=gradient_accumulation_steps,
58
+ gradient_checkpointing=True,
59
+ learning_rate=learning_rate,
60
+ fp16 = not torch.cuda.is_bf16_supported(),
61
+ bf16 = torch.cuda.is_bf16_supported(),
62
+ max_steps=-1,
63
+ num_train_epochs=num_train_epochs,
64
+ save_strategy="epoch",
65
+ logging_steps=10,
66
+ output_dir=NEW_MODEL_NAME,
67
+ optim="paged_adamw_32bit",
68
+ lr_scheduler_type="linear"
69
+ )
70
+
71
+ trainer = SFTTrainer(
72
+ model=model,
73
+ args=args,
74
+ train_dataset=dataset,
75
+ dataset_text_field="text",
76
+ max_seq_length=MAX_SEQ_LENGTH,
77
+ formatting_func=formatting_prompts_func
78
+ )
79
+
80
+ import gc
81
+ import os
82
+
83
+ gc.collect()
84
+ torch.cuda.empty_cache()
85
+
86
+ trainer.train()
prompting/.ipynb_checkpoints/training-checkpoint.ipynb ADDED
@@ -0,0 +1,278 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "12315053-0630-4d3d-8028-02035c2dbf14",
6
+ "metadata": {
7
+ "jp-MarkdownHeadingCollapsed": true
8
+ },
9
+ "source": [
10
+ "## Slide Speech Dataset"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "f8eed0bf-d822-4091-8762-df6582095ab4",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "\"\"\"\n",
21
+ "Dir Structure:\n",
22
+ " - data\n",
23
+ " - info\n",
24
+ " - test\n",
25
+ " - val\n",
26
+ "\"\"\""
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "id": "e999f437-d756-492d-b873-6ee656279b53",
32
+ "metadata": {
33
+ "jp-MarkdownHeadingCollapsed": true
34
+ },
35
+ "source": [
36
+ "## Phi 3 Tinkering"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "id": "3f5f1033-9118-4106-a7fb-6c3b527fe075",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "\"\"\"\n",
47
+ "Prompt template for Phi-3\n",
48
+ "<|system|>\n",
49
+ "You are a python developer.<|end|>\n",
50
+ "<|user|>\n",
51
+ "Help me generate a bubble sort algorithm<|end|>\n",
52
+ "<|assistant|>\n",
53
+ "\"\"\""
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "1e9e81b0-ae3d-46f1-97ff-138984d07a28",
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "import torch\n",
64
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
65
+ "\n",
66
+ "cache_dir = \"./../cache\"\n",
67
+ "model_id = \"microsoft/Phi-3-mini-4k-instruct\"\n",
68
+ "device = \"cuda:0\"\n",
69
+ "dtype = torch.float16\n",
70
+ "\n",
71
+ "model = AutoModelForCausalLM.from_pretrained(\n",
72
+ " model_id,\n",
73
+ " device_map = device,\n",
74
+ " torch_dtype = dtype,\n",
75
+ " trust_remote_code=True,\n",
76
+ " cache_dir = cache_dir,\n",
77
+ " attn_implementation = \"flash_attention_2\"\n",
78
+ ")\n",
79
+ "tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir)\n",
80
+ "\n",
81
+ "pipe = pipeline(\n",
82
+ " \"text-generation\",\n",
83
+ " model = model,\n",
84
+ " tokenizer = tokenizer\n",
85
+ ")"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "93b9fb26-4661-4a35-ad62-b87834f577bc",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "messages = [\n",
96
+ " {\"role\": \"system\", \"content\": \"You are a python developer\"},\n",
97
+ " {\"role\": \"user\", \"content\": \"Help me generate a bubble sort algorithm\"}\n",
98
+ "]\n",
99
+ "\n",
100
+ "generation_args = {\n",
101
+ " \"max_new_tokens\": 600,\n",
102
+ " \"return_full_text\": False,\n",
103
+ " \"temperature\": 1.0,\n",
104
+ " \"do_sample\": True\n",
105
+ "}\n",
106
+ "\n",
107
+ "output = pipe(messages, **generation_args)\n",
108
+ "print(output[0][\"generated_text\"])"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "id": "62db724a-9e20-422d-a5b3-dd55cae55cc7",
114
+ "metadata": {
115
+ "jp-MarkdownHeadingCollapsed": true
116
+ },
117
+ "source": [
118
+ "## Training Phi 3"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "2036e6b5-c794-4668-9a79-8a53a2736cfa",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig\n",
129
+ "from huggingface_hub import ModelCard, ModelCardData, HfApi\n",
130
+ "from datasets import load_dataset\n",
131
+ "from jinja2 import Template\n",
132
+ "from trl import SFTTrainer\n",
133
+ "import yaml\n",
134
+ "import torch"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "2dcff92e-ab3e-407a-8595-31ffae5f7acd",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "# Model Configs\n",
145
+ "MODEL_ID = \"microsoft/Phi-3-mini-4k-instruct\"\n",
146
+ "NEW_MODEL_NAME = \"opus-phi-3-mini-4k-instruct\"\n",
147
+ "CACHE_DIR = \"./../cache\"\n",
148
+ "\n",
149
+ "# Dataset Configs\n",
150
+ "DATASET_NAME = \"\"\n",
151
+ "SPLIT = \"train\"\n",
152
+ "\n",
153
+ "# the maximum length of the sequences that the model will handle\n",
154
+ "MAX_SEQ_LENGTH = 4096\n",
155
+ "num_train_epochs = 1\n",
156
+ "license = \"apache-2.0\"\n",
157
+ "username = \"darshanmakwana412\"\n",
158
+ "learning_rate = 1.41e-5\n",
159
+ "per_device_train_batch_size = 4\n",
160
+ "gradient_accumulation_steps = 1\n",
161
+ "\n",
162
+ "# If bd16 is supported use bf16 otherwise use f16\n",
163
+ "if torch.cuda.is_bf16_supported():\n",
164
+ " compute_dtype = torch.bfloat16\n",
165
+ "else:\n",
166
+ " compute_dtype = torch.float16"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "0fbcf3c6-dd94-4133-a805-910d57c9f974",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)\n",
177
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)\n",
178
+ "# dataset = load_dataset(DATASET_NAME, split=SPLIT)\n",
179
+ "\n",
180
+ "# EOS Token is used to mark the end of a sentence\n",
181
+ "EOS_TOKEN=tokenizer.eos_token_id"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "e31954ea-1838-4992-9132-32e59c42a128",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "def formatting_prompts_func(examples):\n",
192
+ " # Extract the conversations from the examples.\n",
193
+ " convos = examples[\"conversations\"]\n",
194
+ " # Initialize an empty list to store the formatted texts.\n",
195
+ " texts = []\n",
196
+ " # Define a dictionary to map the 'from' field in the conversation to a prefix.\n",
197
+ " mapper = {\"system\": \"system\\n\", \"human\": \"\\nuser\\n\", \"gpt\": \"\\nassistant\\n\"}\n",
198
+ " # Define a dictionary to map the 'from' field in the conversation to a suffix.\n",
199
+ " end_mapper = {\"system\": \"\", \"human\": \"\", \"gpt\": \"\"}\n",
200
+ " # Iterate over each conversation.\n",
201
+ " for convo in convos:\n",
202
+ " # Format the conversation by joining each turn with its corresponding prefix and suffix.\n",
203
+ " # Append the EOS token to the end of the conversation.\n",
204
+ " text = \"\".join(f\"{mapper[(turn := x['from'])]} {x['value']}\\n{end_mapper[turn]}\" for x in convo)\n",
205
+ " texts.append(f\"{text}{EOS_TOKEN}\")\n",
206
+ " # Return the formatted texts.\n",
207
+ " return {\"text\": texts}\n",
208
+ "\n",
209
+ "dataset = dataset.map(formatting_prompts_func, batched=True)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "3086c2a4-7cca-461e-894c-376046089fab",
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "args = TrainingArguments(\n",
220
+ " evaluation_strategy=\"steps\",\n",
221
+ " per_device_train_batch_size=7,\n",
222
+ " gradient_accumulation_steps=4,\n",
223
+ " gradient_checkpointing=True,\n",
224
+ " learning_rate=1e-4,\n",
225
+ " fp16 = not torch.cuda.is_bf16_supported(),\n",
226
+ " bf16 = torch.cuda.is_bf16_supported(),\n",
227
+ " max_steps=-1,\n",
228
+ " num_train_epochs=3,\n",
229
+ " save_strategy=\"epoch\",\n",
230
+ " logging_steps=10,\n",
231
+ " output_dir=NEW_MODEL_NAME,\n",
232
+ " optim=\"paged_adamw_32bit\",\n",
233
+ " lr_scheduler_type=\"linear\"\n",
234
+ ")\n",
235
+ "\n",
236
+ "trainer = SFTTrainer(\n",
237
+ " model=model,\n",
238
+ " args=args,\n",
239
+ " train_dataset=dataset,\n",
240
+ " dataset_text_field=\"text\",\n",
241
+ " max_seq_length=128,\n",
242
+ " formatting_func=formatting_prompts_func\n",
243
+ ")\n",
244
+ "\n",
245
+ "device = \"cuda:0\"\n",
246
+ "\n",
247
+ "import gc\n",
248
+ "import os\n",
249
+ "\n",
250
+ "gc.collect()\n",
251
+ "torch.cuda.empty_cache()\n",
252
+ "\n",
253
+ "trainer.train()"
254
+ ]
255
+ }
256
+ ],
257
+ "metadata": {
258
+ "kernelspec": {
259
+ "display_name": "Python 3 (ipykernel)",
260
+ "language": "python",
261
+ "name": "python3"
262
+ },
263
+ "language_info": {
264
+ "codemirror_mode": {
265
+ "name": "ipython",
266
+ "version": 3
267
+ },
268
+ "file_extension": ".py",
269
+ "mimetype": "text/x-python",
270
+ "name": "python",
271
+ "nbconvert_exporter": "python",
272
+ "pygments_lexer": "ipython3",
273
+ "version": "3.8.10"
274
+ }
275
+ },
276
+ "nbformat": 4,
277
+ "nbformat_minor": 5
278
+ }
prompting/RepCodec/.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
prompting/RepCodec/.ipynb_checkpoints/tinker-checkpoint.ipynb ADDED
@@ -0,0 +1,267 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "997bed07-1181-4562-962a-cb8aa18e1d16",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from repcodec.RepCodec import RepCodec\n",
11
+ "import torch\n",
12
+ "import yaml\n",
13
+ "\n",
14
+ "config = \"repcodec/configs/repcodec_dim1024.yaml\"\n",
15
+ "with open(config) as fp:\n",
16
+ " conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
17
+ "\n",
18
+ "model = RepCodec(**conf)\n",
19
+ "model.load_state_dict(torch.load(\"./../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
20
+ "model.quantizer.initial()\n",
21
+ "model.eval()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "id": "5c5516f1-3565-4080-8612-d5ce52ea2a4d",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "# input shape: (batch size, hidden dim, sequence length)\n",
32
+ "random_features = torch.randn(size=(1, 1024, 100))\n",
33
+ "with torch.no_grad():\n",
34
+ " x = model.encoder(random_features)\n",
35
+ " z = model.projector(x)\n",
36
+ " _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
37
+ " tokens = idx.cpu().data.numpy().tolist()[0]"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "id": "439ecea7-f0d4-4a61-80c2-729138beee32",
43
+ "metadata": {},
44
+ "source": [
45
+ "## Dump Representations"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "6efa1891-0810-4cfb-9552-764297209e99",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "python3 examples/dump_feature.py --model_type data2vec --tsv_path \"./files/train.clean.100.tsv\" --ckpt_path \"./../models/vox_pretrained.pt\" --layer 18 --feat_dir \"./features/train.clean.100\""
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "id": "cbd1c550-0606-4217-ac65-55ae92843f19",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "from datasets import load_dataset\n",
66
+ "from tqdm import tqdm\n",
67
+ "import pandas as pd\n",
68
+ "\n",
69
+ "cache_dir = \"./../../cache\"\n",
70
+ "\n",
71
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
72
+ "\n",
73
+ "# for split in dataset.keys():\n",
74
+ "# data = dataset[split]\n",
75
+ "# num_frames = []\n",
76
+ "# for idx in tqdm(range(len(data))):\n",
77
+ "# audio = data[idx][\"audio\"]\n",
78
+ "# num_frames.append(int(len(audio[\"array\"]) * 16000 // audio[\"sampling_rate\"]))\n",
79
+ " \n",
80
+ "# df = pd.DataFrame.from_dict({\n",
81
+ "# \"file_path\": list(data[\"file\"]),\n",
82
+ "# \"num_frames\": num_frames\n",
83
+ "# })\n",
84
+ "# df.to_csv(f\"./files/{split}.tsv\", sep=\"\\t\", index=False)"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "5b4af1af-5726-4899-8272-dfe867cb48a8",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "dataset[\"train.clean.100\"][0]"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "id": "ae6a0ef4-8c0a-4f6e-9a81-a9c3350e1266",
100
+ "metadata": {},
101
+ "source": [
102
+ "## Prepare the Dataset"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "id": "b1247988-5eaa-492a-a3ab-2b11505126a6",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "from datasets import Dataset, load_dataset\n",
113
+ "from collections import defaultdict\n",
114
+ "from tqdm import tqdm\n",
115
+ "import numpy as np\n",
116
+ "import string\n",
117
+ "\n",
118
+ "cache_dir = \"./../../cache\"\n",
119
+ "\n",
120
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
121
+ "\n",
122
+ "def process(text):\n",
123
+ "\n",
124
+ " # Lower case every letter\n",
125
+ " text = text.lower()\n",
126
+ "\n",
127
+ " # Remove punctuation\n",
128
+ " punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
129
+ " translation_table = str.maketrans('', '', punctuation_to_remove)\n",
130
+ " text = text.translate(translation_table)\n",
131
+ "\n",
132
+ " # Remove whitespaces from front and behind\n",
133
+ " while text[0] == ' ' or text[-1] == ' ':\n",
134
+ " if text[0] == ' ':\n",
135
+ " text = text[1:]\n",
136
+ " if text[-1] == ' ':\n",
137
+ " text = text[:-1]\n",
138
+ " \n",
139
+ " return text\n",
140
+ "\n",
141
+ "dataset = dataset.remove_columns([\"audio\", \"speaker_id\", \"chapter_id\"])\n",
142
+ "\n",
143
+ "tokenized_ds = defaultdict(lambda: [])\n",
144
+ "\n",
145
+ "for split in dataset.keys():\n",
146
+ "\n",
147
+ " texts = []\n",
148
+ " tokens = []\n",
149
+ " tkns = np.load(f\"./examples/tkns/{split}.npz\")\n",
150
+ "\n",
151
+ " for idx, key in enumerate(tqdm(tkns.files)):\n",
152
+ " tokens.append(list(tkns[key]))\n",
153
+ " texts.append(process(dataset[split][idx][\"text\"]))\n",
154
+ "\n",
155
+ " tokenized_ds[split] = Dataset.from_dict({\n",
156
+ " \"text\": texts,\n",
157
+ " \"audio_tokens\": tokens\n",
158
+ " })"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "bfc82444-3081-4138-aa06-6fb0b7cbc6c3",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "from datasets import dataset_dict, DatasetDict\n",
169
+ "\n",
170
+ "tds = DatasetDict(tokenized_ds)"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "id": "006171b9-d479-4462-9642-d126f77edfc2",
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "tds.save_to_disk(\"librispeech_tokenized.hf\")"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 2,
186
+ "id": "12970376-4f6f-4926-a954-29c32043b64c",
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "ename": "ValueError",
191
+ "evalue": "Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}",
192
+ "output_type": "error",
193
+ "traceback": [
194
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
195
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
196
+ "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset\n\u001b[0;32m----> 3\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./librispeech_tokenized.hf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
197
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2594\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2589\u001b[0m verification_mode \u001b[38;5;241m=\u001b[39m VerificationMode(\n\u001b[1;32m 2590\u001b[0m (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mBASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2591\u001b[0m )\n\u001b[1;32m 2593\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2594\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2595\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2596\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2597\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2598\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2599\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2600\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2601\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2604\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2605\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2606\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2607\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2608\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2609\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2611\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2612\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
198
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2266\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2264\u001b[0m download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 2265\u001b[0m download_config\u001b[38;5;241m.\u001b[39mstorage_options\u001b[38;5;241m.\u001b[39mupdate(storage_options)\n\u001b[0;32m-> 2266\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2267\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2268\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2269\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2270\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2271\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2272\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2273\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2275\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_require_default_config_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2276\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_custom_configs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2277\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2278\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 2279\u001b[0m builder_kwargs \u001b[38;5;241m=\u001b[39m dataset_module\u001b[38;5;241m.\u001b[39mbuilder_kwargs\n",
199
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1825\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m 1818\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m LocalDatasetModuleFactoryWithScript(\n\u001b[1;32m 1819\u001b[0m combined_path,\n\u001b[1;32m 1820\u001b[0m download_mode\u001b[38;5;241m=\u001b[39mdownload_mode,\n\u001b[1;32m 1821\u001b[0m dynamic_modules_path\u001b[38;5;241m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m 1822\u001b[0m trust_remote_code\u001b[38;5;241m=\u001b[39mtrust_remote_code,\n\u001b[1;32m 1823\u001b[0m )\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1824\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[0;32m-> 1825\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mLocalDatasetModuleFactoryWithoutScript\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1826\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\n\u001b[1;32m 1827\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1828\u001b[0m \u001b[38;5;66;03m# Try remotely\u001b[39;00m\n\u001b[1;32m 1829\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_relative_path(path) \u001b[38;5;129;01mand\u001b[39;00m path\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
200
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1040\u001b[0m, in \u001b[0;36mLocalDatasetModuleFactoryWithoutScript.get_module\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1034\u001b[0m patterns \u001b[38;5;241m=\u001b[39m get_data_patterns(base_path)\n\u001b[1;32m 1035\u001b[0m data_files \u001b[38;5;241m=\u001b[39m DataFilesDict\u001b[38;5;241m.\u001b[39mfrom_patterns(\n\u001b[1;32m 1036\u001b[0m patterns,\n\u001b[1;32m 1037\u001b[0m base_path\u001b[38;5;241m=\u001b[39mbase_path,\n\u001b[1;32m 1038\u001b[0m allowed_extensions\u001b[38;5;241m=\u001b[39mALL_ALLOWED_EXTENSIONS,\n\u001b[1;32m 1039\u001b[0m )\n\u001b[0;32m-> 1040\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43minfer_module_for_data_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1041\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1042\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1043\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1044\u001b[0m data_files \u001b[38;5;241m=\u001b[39m data_files\u001b[38;5;241m.\u001b[39mfilter_extensions(_MODULE_TO_EXTENSIONS[module_name])\n\u001b[1;32m 1045\u001b[0m \u001b[38;5;66;03m# Collect metadata files if the module supports them\u001b[39;00m\n",
201
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:596\u001b[0m, in \u001b[0;36minfer_module_for_data_files\u001b[0;34m(data_files, path, download_config)\u001b[0m\n\u001b[1;32m 594\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(split_modules\u001b[38;5;241m.\u001b[39mvalues()))\n\u001b[1;32m 595\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m((module_name, default_builder_kwargs) \u001b[38;5;241m!=\u001b[39m split_module \u001b[38;5;28;01mfor\u001b[39;00m split_module \u001b[38;5;129;01min\u001b[39;00m split_modules\u001b[38;5;241m.\u001b[39mvalues()):\n\u001b[0;32m--> 596\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt infer the same data file format for all splits. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msplit_modules\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m module_name:\n\u001b[1;32m 598\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DataFilesNotFoundError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo (supported) data files found\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m path \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
202
+ "\u001b[0;31mValueError\u001b[0m: Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}"
203
+ ]
204
+ }
205
+ ],
206
+ "source": [
207
+ "from datasets import load_dataset\n",
208
+ "\n",
209
+ "dataset = load_dataset(\"./librispeech_tokenized.hf\")"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 8,
215
+ "id": "b3ba0d58-b788-43b5-87a7-726aaa12dbbd",
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "from datasets import dataset_dict, DatasetDict, Dataset\n",
220
+ "\n",
221
+ "dataset = DatasetDict.load_from_disk(\"./librispeech_tokenized.hf\")"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 13,
227
+ "id": "b7239186-73ae-407a-b9f6-b5a16f3a7ddc",
228
+ "metadata": {},
229
+ "outputs": [
230
+ {
231
+ "data": {
232
+ "text/plain": [
233
+ "726"
234
+ ]
235
+ },
236
+ "execution_count": 13,
237
+ "metadata": {},
238
+ "output_type": "execute_result"
239
+ }
240
+ ],
241
+ "source": [
242
+ "len(dataset[\"train.clean.100\"][0][\"audio_tokens\"])"
243
+ ]
244
+ }
245
+ ],
246
+ "metadata": {
247
+ "kernelspec": {
248
+ "display_name": "Python 3 (ipykernel)",
249
+ "language": "python",
250
+ "name": "python3"
251
+ },
252
+ "language_info": {
253
+ "codemirror_mode": {
254
+ "name": "ipython",
255
+ "version": 3
256
+ },
257
+ "file_extension": ".py",
258
+ "mimetype": "text/x-python",
259
+ "name": "python",
260
+ "nbconvert_exporter": "python",
261
+ "pygments_lexer": "ipython3",
262
+ "version": "3.8.10"
263
+ }
264
+ },
265
+ "nbformat": 4,
266
+ "nbformat_minor": 5
267
+ }
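The traceback in the notebook above comes from `load_dataset`, which scans the directory for raw data files and tries to infer a single file format per split; here the `train` split resolved to Arrow while `validation` and `test` resolved to JSON, so no common format could be inferred. The later cell switches to `DatasetDict.load_from_disk`, which is the matching reader for datasets written with `save_to_disk`. A minimal sketch of that working path, reusing only the notebook's own directory and field names:

```python
from datasets import DatasetDict

# load_from_disk reads the metadata written by save_to_disk, so it does not need
# to infer a per-split file format the way load_dataset does.
dataset = DatasetDict.load_from_disk("./librispeech_tokenized.hf")

# Same inspection as in the notebook: length of one utterance's token sequence.
print(len(dataset["train.clean.100"][0]["audio_tokens"]))
```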
prompting/RepCodec/LICENSE ADDED
@@ -0,0 +1,428 @@
1
+ MIT License
2
+
3
+ Copyright (c) ByteDance, Inc. and its affiliates.
4
+ Copyright (c) Chutong Meng
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+ Attribution-NonCommercial 4.0 International
31
+
32
+ =======================================================================
33
+
34
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
35
+ does not provide legal services or legal advice. Distribution of
36
+ Creative Commons public licenses does not create a lawyer-client or
37
+ other relationship. Creative Commons makes its licenses and related
38
+ information available on an "as-is" basis. Creative Commons gives no
39
+ warranties regarding its licenses, any material licensed under their
40
+ terms and conditions, or any related information. Creative Commons
41
+ disclaims all liability for damages resulting from their use to the
42
+ fullest extent possible.
43
+
44
+ Using Creative Commons Public Licenses
45
+
46
+ Creative Commons public licenses provide a standard set of terms and
47
+ conditions that creators and other rights holders may use to share
48
+ original works of authorship and other material subject to copyright
49
+ and certain other rights specified in the public license below. The
50
+ following considerations are for informational purposes only, are not
51
+ exhaustive, and do not form part of our licenses.
52
+
53
+ Considerations for licensors: Our public licenses are
54
+ intended for use by those authorized to give the public
55
+ permission to use material in ways otherwise restricted by
56
+ copyright and certain other rights. Our licenses are
57
+ irrevocable. Licensors should read and understand the terms
58
+ and conditions of the license they choose before applying it.
59
+ Licensors should also secure all rights necessary before
60
+ applying our licenses so that the public can reuse the
61
+ material as expected. Licensors should clearly mark any
62
+ material not subject to the license. This includes other CC-
63
+ licensed material, or material used under an exception or
64
+ limitation to copyright. More considerations for licensors:
65
+ wiki.creativecommons.org/Considerations_for_licensors
66
+
67
+ Considerations for the public: By using one of our public
68
+ licenses, a licensor grants the public permission to use the
69
+ licensed material under specified terms and conditions. If
70
+ the licensor's permission is not necessary for any reason--for
71
+ example, because of any applicable exception or limitation to
72
+ copyright--then that use is not regulated by the license. Our
73
+ licenses grant only permissions under copyright and certain
74
+ other rights that a licensor has authority to grant. Use of
75
+ the licensed material may still be restricted for other
76
+ reasons, including because others have copyright or other
77
+ rights in the material. A licensor may make special requests,
78
+ such as asking that all changes be marked or described.
79
+ Although not required by our licenses, you are encouraged to
80
+ respect those requests where reasonable. More_considerations
81
+ for the public:
82
+ wiki.creativecommons.org/Considerations_for_licensees
83
+
84
+ =======================================================================
85
+
86
+ Creative Commons Attribution-NonCommercial 4.0 International Public
87
+ License
88
+
89
+ By exercising the Licensed Rights (defined below), You accept and agree
90
+ to be bound by the terms and conditions of this Creative Commons
91
+ Attribution-NonCommercial 4.0 International Public License ("Public
92
+ License"). To the extent this Public License may be interpreted as a
93
+ contract, You are granted the Licensed Rights in consideration of Your
94
+ acceptance of these terms and conditions, and the Licensor grants You
95
+ such rights in consideration of benefits the Licensor receives from
96
+ making the Licensed Material available under these terms and
97
+ conditions.
98
+
99
+ Section 1 -- Definitions.
100
+
101
+ a. Adapted Material means material subject to Copyright and Similar
102
+ Rights that is derived from or based upon the Licensed Material
103
+ and in which the Licensed Material is translated, altered,
104
+ arranged, transformed, or otherwise modified in a manner requiring
105
+ permission under the Copyright and Similar Rights held by the
106
+ Licensor. For purposes of this Public License, where the Licensed
107
+ Material is a musical work, performance, or sound recording,
108
+ Adapted Material is always produced where the Licensed Material is
109
+ synched in timed relation with a moving image.
110
+
111
+ b. Adapter's License means the license You apply to Your Copyright
112
+ and Similar Rights in Your contributions to Adapted Material in
113
+ accordance with the terms and conditions of this Public License.
114
+
115
+ c. Copyright and Similar Rights means copyright and/or similar rights
116
+ closely related to copyright including, without limitation,
117
+ performance, broadcast, sound recording, and Sui Generis Database
118
+ Rights, without regard to how the rights are labeled or
119
+ categorized. For purposes of this Public License, the rights
120
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
121
+ Rights.
122
+ d. Effective Technological Measures means those measures that, in the
123
+ absence of proper authority, may not be circumvented under laws
124
+ fulfilling obligations under Article 11 of the WIPO Copyright
125
+ Treaty adopted on December 20, 1996, and/or similar international
126
+ agreements.
127
+
128
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
129
+ any other exception or limitation to Copyright and Similar Rights
130
+ that applies to Your use of the Licensed Material.
131
+
132
+ f. Licensed Material means the artistic or literary work, database,
133
+ or other material to which the Licensor applied this Public
134
+ License.
135
+
136
+ g. Licensed Rights means the rights granted to You subject to the
137
+ terms and conditions of this Public License, which are limited to
138
+ all Copyright and Similar Rights that apply to Your use of the
139
+ Licensed Material and that the Licensor has authority to license.
140
+
141
+ h. Licensor means the individual(s) or entity(ies) granting rights
142
+ under this Public License.
143
+
144
+ i. NonCommercial means not primarily intended for or directed towards
145
+ commercial advantage or monetary compensation. For purposes of
146
+ this Public License, the exchange of the Licensed Material for
147
+ other material subject to Copyright and Similar Rights by digital
148
+ file-sharing or similar means is NonCommercial provided there is
149
+ no payment of monetary compensation in connection with the
150
+ exchange.
151
+
152
+ j. Share means to provide material to the public by any means or
153
+ process that requires permission under the Licensed Rights, such
154
+ as reproduction, public display, public performance, distribution,
155
+ dissemination, communication, or importation, and to make material
156
+ available to the public including in ways that members of the
157
+ public may access the material from a place and at a time
158
+ individually chosen by them.
159
+
160
+ k. Sui Generis Database Rights means rights other than copyright
161
+ resulting from Directive 96/9/EC of the European Parliament and of
162
+ the Council of 11 March 1996 on the legal protection of databases,
163
+ as amended and/or succeeded, as well as other essentially
164
+ equivalent rights anywhere in the world.
165
+
166
+ l. You means the individual or entity exercising the Licensed Rights
167
+ under this Public License. Your has a corresponding meaning.
168
+
169
+ Section 2 -- Scope.
170
+
171
+ a. License grant.
172
+
173
+ 1. Subject to the terms and conditions of this Public License,
174
+ the Licensor hereby grants You a worldwide, royalty-free,
175
+ non-sublicensable, non-exclusive, irrevocable license to
176
+ exercise the Licensed Rights in the Licensed Material to:
177
+
178
+ a. reproduce and Share the Licensed Material, in whole or
179
+ in part, for NonCommercial purposes only; and
180
+
181
+ b. produce, reproduce, and Share Adapted Material for
182
+ NonCommercial purposes only.
183
+
184
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
185
+ Exceptions and Limitations apply to Your use, this Public
186
+ License does not apply, and You do not need to comply with
187
+ its terms and conditions.
188
+
189
+ 3. Term. The term of this Public License is specified in Section
190
+ 6(a).
191
+
192
+ 4. Media and formats; technical modifications allowed. The
193
+ Licensor authorizes You to exercise the Licensed Rights in
194
+ all media and formats whether now known or hereafter created,
195
+ and to make technical modifications necessary to do so. The
196
+ Licensor waives and/or agrees not to assert any right or
197
+ authority to forbid You from making technical modifications
198
+ necessary to exercise the Licensed Rights, including
199
+ technical modifications necessary to circumvent Effective
200
+ Technological Measures. For purposes of this Public License,
201
+ simply making modifications authorized by this Section 2(a)
202
+ (4) never produces Adapted Material.
203
+
204
+ 5. Downstream recipients.
205
+
206
+ a. Offer from the Licensor -- Licensed Material. Every
207
+ recipient of the Licensed Material automatically
208
+ receives an offer from the Licensor to exercise the
209
+ Licensed Rights under the terms and conditions of this
210
+ Public License.
211
+
212
+ b. No downstream restrictions. You may not offer or impose
213
+ any additional or different terms or conditions on, or
214
+ apply any Effective Technological Measures to, the
215
+ Licensed Material if doing so restricts exercise of the
216
+ Licensed Rights by any recipient of the Licensed
217
+ Material.
218
+
219
+ 6. No endorsement. Nothing in this Public License constitutes or
220
+ may be construed as permission to assert or imply that You
221
+ are, or that Your use of the Licensed Material is, connected
222
+ with, or sponsored, endorsed, or granted official status by,
223
+ the Licensor or others designated to receive attribution as
224
+ provided in Section 3(a)(1)(A)(i).
225
+
226
+ b. Other rights.
227
+
228
+ 1. Moral rights, such as the right of integrity, are not
229
+ licensed under this Public License, nor are publicity,
230
+ privacy, and/or other similar personality rights; however, to
231
+ the extent possible, the Licensor waives and/or agrees not to
232
+ assert any such rights held by the Licensor to the limited
233
+ extent necessary to allow You to exercise the Licensed
234
+ Rights, but not otherwise.
235
+
236
+ 2. Patent and trademark rights are not licensed under this
237
+ Public License.
238
+
239
+ 3. To the extent possible, the Licensor waives any right to
240
+ collect royalties from You for the exercise of the Licensed
241
+ Rights, whether directly or through a collecting society
242
+ under any voluntary or waivable statutory or compulsory
243
+ licensing scheme. In all other cases the Licensor expressly
244
+ reserves any right to collect such royalties, including when
245
+ the Licensed Material is used other than for NonCommercial
246
+ purposes.
247
+
248
+ Section 3 -- License Conditions.
249
+
250
+ Your exercise of the Licensed Rights is expressly made subject to the
251
+ following conditions.
252
+
253
+ a. Attribution.
254
+
255
+ 1. If You Share the Licensed Material (including in modified
256
+ form), You must:
257
+
258
+ a. retain the following if it is supplied by the Licensor
259
+ with the Licensed Material:
260
+
261
+ i. identification of the creator(s) of the Licensed
262
+ Material and any others designated to receive
263
+ attribution, in any reasonable manner requested by
264
+ the Licensor (including by pseudonym if
265
+ designated);
266
+
267
+ ii. a copyright notice;
268
+
269
+ iii. a notice that refers to this Public License;
270
+
271
+ iv. a notice that refers to the disclaimer of
272
+ warranties;
273
+
274
+ v. a URI or hyperlink to the Licensed Material to the
275
+ extent reasonably practicable;
276
+
277
+ b. indicate if You modified the Licensed Material and
278
+ retain an indication of any previous modifications; and
279
+
280
+ c. indicate the Licensed Material is licensed under this
281
+ Public License, and include the text of, or the URI or
282
+ hyperlink to, this Public License.
283
+
284
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
285
+ reasonable manner based on the medium, means, and context in
286
+ which You Share the Licensed Material. For example, it may be
287
+ reasonable to satisfy the conditions by providing a URI or
288
+ hyperlink to a resource that includes the required
289
+ information.
290
+
291
+ 3. If requested by the Licensor, You must remove any of the
292
+ information required by Section 3(a)(1)(A) to the extent
293
+ reasonably practicable.
294
+
295
+ 4. If You Share Adapted Material You produce, the Adapter's
296
+ License You apply must not prevent recipients of the Adapted
297
+ Material from complying with this Public License.
298
+
299
+ Section 4 -- Sui Generis Database Rights.
300
+
301
+ Where the Licensed Rights include Sui Generis Database Rights that
302
+ apply to Your use of the Licensed Material:
303
+
304
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
305
+ to extract, reuse, reproduce, and Share all or a substantial
306
+ portion of the contents of the database for NonCommercial purposes
307
+ only;
308
+
309
+ b. if You include all or a substantial portion of the database
310
+ contents in a database in which You have Sui Generis Database
311
+ Rights, then the database in which You have Sui Generis Database
312
+ Rights (but not its individual contents) is Adapted Material; and
313
+
314
+ c. You must comply with the conditions in Section 3(a) if You Share
315
+ all or a substantial portion of the contents of the database.
316
+
317
+ For the avoidance of doubt, this Section 4 supplements and does not
318
+ replace Your obligations under this Public License where the Licensed
319
+ Rights include other Copyright and Similar Rights.
320
+
321
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
322
+
323
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
324
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
325
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
326
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
327
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
328
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
329
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
330
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
331
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
332
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
333
+
334
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
335
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
336
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
337
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
338
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
339
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
340
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
341
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
342
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
343
+
344
+ c. The disclaimer of warranties and limitation of liability provided
345
+ above shall be interpreted in a manner that, to the extent
346
+ possible, most closely approximates an absolute disclaimer and
347
+ waiver of all liability.
348
+
349
+ Section 6 -- Term and Termination.
350
+
351
+ a. This Public License applies for the term of the Copyright and
352
+ Similar Rights licensed here. However, if You fail to comply with
353
+ this Public License, then Your rights under this Public License
354
+ terminate automatically.
355
+
356
+ b. Where Your right to use the Licensed Material has terminated under
357
+ Section 6(a), it reinstates:
358
+
359
+ 1. automatically as of the date the violation is cured, provided
360
+ it is cured within 30 days of Your discovery of the
361
+ violation; or
362
+
363
+ 2. upon express reinstatement by the Licensor.
364
+
365
+ For the avoidance of doubt, this Section 6(b) does not affect any
366
+ right the Licensor may have to seek remedies for Your violations
367
+ of this Public License.
368
+
369
+ c. For the avoidance of doubt, the Licensor may also offer the
370
+ Licensed Material under separate terms or conditions or stop
371
+ distributing the Licensed Material at any time; however, doing so
372
+ will not terminate this Public License.
373
+
374
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
375
+ License.
376
+
377
+ Section 7 -- Other Terms and Conditions.
378
+
379
+ a. The Licensor shall not be bound by any additional or different
380
+ terms or conditions communicated by You unless expressly agreed.
381
+
382
+ b. Any arrangements, understandings, or agreements regarding the
383
+ Licensed Material not stated herein are separate from and
384
+ independent of the terms and conditions of this Public License.
385
+
386
+ Section 8 -- Interpretation.
387
+
388
+ a. For the avoidance of doubt, this Public License does not, and
389
+ shall not be interpreted to, reduce, limit, restrict, or impose
390
+ conditions on any use of the Licensed Material that could lawfully
391
+ be made without permission under this Public License.
392
+
393
+ b. To the extent possible, if any provision of this Public License is
394
+ deemed unenforceable, it shall be automatically reformed to the
395
+ minimum extent necessary to make it enforceable. If the provision
396
+ cannot be reformed, it shall be severed from this Public License
397
+ without affecting the enforceability of the remaining terms and
398
+ conditions.
399
+
400
+ c. No term or condition of this Public License will be waived and no
401
+ failure to comply consented to unless expressly agreed to by the
402
+ Licensor.
403
+
404
+ d. Nothing in this Public License constitutes or may be interpreted
405
+ as a limitation upon, or waiver of, any privileges and immunities
406
+ that apply to the Licensor or You, including from the legal
407
+ processes of any jurisdiction or authority.
408
+
409
+ =======================================================================
410
+
411
+ Creative Commons is not a party to its public
412
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
413
+ its public licenses to material it publishes and in those instances
414
+ will be considered the “Licensor.” The text of the Creative Commons
415
+ public licenses is dedicated to the public domain under the CC0 Public
416
+ Domain Dedication. Except for the limited purpose of indicating that
417
+ material is shared under a Creative Commons public license or as
418
+ otherwise permitted by the Creative Commons policies published at
419
+ creativecommons.org/policies, Creative Commons does not authorize the
420
+ use of the trademark "Creative Commons" or any other trademark or logo
421
+ of Creative Commons without its prior written consent including,
422
+ without limitation, in connection with any unauthorized modifications
423
+ to any of its public licenses or any other arrangements,
424
+ understandings, or agreements concerning use of licensed material. For
425
+ the avoidance of doubt, this paragraph does not form part of the
426
+ public licenses.
427
+
428
+ Creative Commons may be contacted at creativecommons.org.
prompting/RepCodec/README.md ADDED
@@ -0,0 +1,273 @@
1
+ # RepCodec: A Speech Representation Codec for Speech Tokenization
2
+
3
+ > [**RepCodec: A Speech Representation Codec for Speech Tokenization**](https://arxiv.org/abs/2309.00169)
4
+
5
+ ## Introduction
6
+
7
+ **RepCodec** is a speech tokenization method for converting a speech waveform into a sequence of discrete semantic
8
+ tokens.
9
+ The main idea is to train a representation codec which learns a vector quantization codebook through reconstructing the
10
+ input speech representations from speech encoders like HuBERT or data2vec.
11
+ Extensive experiments show that RepCodec significantly outperforms the widely used k-means clustering approach in both
12
+ speech understanding and generation.
13
+ Also, RepCodec generalizes well across various speech encoders and languages.
14
+
15
+ <img src="images/RepCodec.png" alt="se" width="1000" />
16
+
17
+ ## RepCodec Models
18
+
19
+ | Feature Type | Speech Data | RepCodec Model |
20
+ |-----------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------|----------------------------------------------------------------------------------------------------------|
21
+ | [HuBERT base](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 9 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_base_l9](https://drive.google.com/file/d/1XD0HKl607FFjri2-VJT7lHQeSpxsCCFO/view?usp=sharing) |
22
+ | [HuBERT large](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_large_l18](https://drive.google.com/file/d/1mTbm5GeJ7gp_5L3QLP-JGXdf8RnRw5n6/view?usp=sharing) |
23
+ | [data2vec base](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 6 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_base_l6](https://drive.google.com/file/d/1d8sf3Ko_fYM9zlaiwxK_4xusLRKV5EMd/view?usp=sharing) |
24
+ | [data2vec large](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_large_l18](https://drive.google.com/file/d/1nuRIHaejT-uVi4cluftbT8o_JZqar5SU/view?usp=sharing) |
25
+ | [Whisper medium](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 24 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_medium_l24](https://drive.google.com/file/d/1V6YJSA2V4iywXrecJAN0oqsa3aHowexZ/view?usp=sharing) |
26
+ | [Whisper large-v2](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 32 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_large_l32](https://drive.google.com/file/d/1k_X7ZMPg8iOeDrIJe70v6CHfFygzufXC/view?usp=sharing) |
27
+
28
+ ## Speech Tokenization Using Pre-Trained Models
29
+
30
+ ### Installation
31
+
32
+ Please first install RepCodec by
33
+
34
+ ```
35
+ git clone https://github.com/mct10/RepCodec.git
36
+ cd RepCodec
37
+ pip install .
38
+ ```
39
+
40
+ We used Python 3.9.18 and PyTorch 1.12.1 to test the usage, but the code should be compatible with other recent Python
41
+ and PyTorch versions.
42
+
43
+ ### Representation Preparation
44
+
45
+ We adapt the `dump_hubert_feature.py` script
46
+ from [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans#hubert-feature)
47
+ to support dumping representations from **data2vec**, **HuBERT**, or **Whisper** encoders.
48
+
49
+ If you use our script (`examples/dump_feature.py`), please also install the following packages:
50
+
51
+ ```
52
+ pip install npy_append_array soundfile
53
+ ```
54
+
55
+ Additionally, if you want to dump representations from
56
+
57
+ - **data2vec** or **HuBERT**: please
58
+ follow [fairseq's instruction](https://github.com/facebookresearch/fairseq#requirements-and-installation) to install
59
+ the latest fairseq.
60
+
61
+ - **Whisper**: please follow [Whisper's instructions](https://github.com/openai/whisper/tree/main#setup) to install the
62
+ latest
63
+ Whisper.
64
+
65
+ Then, you can follow the given examples to dump representations:
66
+
67
+ ```
68
+ # Example 1: dump from HuBERT base layer 9
69
+ # (for data2vec, simply change "model_type" to data2vec and "ckpt_path" to the path of data2vec model)
70
+
71
+ layer=9
72
+
73
+ python3 examples/dump_feature.py \
74
+ --model_type hubert \
75
+ --tsv_path /path/to/tsv/file \
76
+ --ckpt_path /path/to/HuBERT/model \
77
+ --layer ${layer} \
78
+ --feat_dir /dir/to/save/representations
79
+
80
+
81
+ # Example 2: dump from Whisper medium layer 24
82
+
83
+ layer=24
84
+
85
+ python3 examples/dump_feature.py \
86
+ --model_type whisper \
87
+ --tsv_path /path/to/tsv/file \
88
+ --whisper_root /directory/to/save/whisper/model \
89
+ --whisper_name medium \
90
+ --layer ${layer} \
91
+ --feat_dir /dir/to/save/representations
92
+ ```
93
+
94
+ Explanations about the args:
95
+
96
+ - **model_type:** choose from `data2vec`, `hubert`, and `whisper`.
97
+
98
+ - **tsv_path:** path of the tsv file.
99
+ Should have the format of
100
+
101
+ ```
102
+ /dir/to/dataset
103
+ path_of_utterance_1 number_of_frames
104
+ path_of_utterance_2 number_of_frames
105
+ ```
106
+
107
+ You can follow [this script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py)
108
+ to generate the tsv file.
109
+
110
+ For example, by running
111
+
112
+ ```
113
+ python wav2vec_manifest.py \
114
+ /dir/to/LibriSpeech/dev-clean \
115
+ --dest /dir/to/manifest \
116
+ --ext flac \
117
+ --valid-percent 0
118
+ ```
119
+
120
+ you can obtain the `dev-clean.tsv` in `/dir/to/manifest` for LibriSpeech. (By default, the output file name
121
+ is `train.tsv`. Remember to rename the file.)
122
+
123
+ It should be similar to:
124
+
125
+ ```
126
+ /dir/to/LibriSpeech/dev-clean
127
+ 2277/149896/2277-149896-0026.flac 78720
128
+ 2277/149896/2277-149896-0005.flac 89600
129
+ 2277/149896/2277-149896-0033.flac 45520
130
+ ```
131
+
132
+ - **ckpt_path**:
133
+ must be provided for data2vec and HuBERT.
134
+ You need to download the model
135
+ from [data2vec website](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2)
136
+ or [HuBERT website](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models)
137
+ yourself.
138
+ `--ckpt_path` is the path of the data2vec/HuBERT model.
139
+ - **whisper_root** and **whisper_name**:
140
+ **both** `--whisper_root` and `--whisper_name` must be provided for Whisper.
141
+ If there is no corresponding model in `--whisper_root`, the script will download it for you.
142
+
143
+ - **layer**:
144
+ the Transformer encoder layer of the model from which the representations should be extracted.
145
+ It is **1-based**.
146
+ For example, if layer=9, then the outputs from the 9<sup>th</sup> Transformer encoder layer are dumped.
147
+ Range: [1, number of Transformer encoder layers]
148
+
149
+ - **feat_dir**: The output representations will be saved to `${feat_dir}/0_1.npy`
150
+ and `${feat_dir}/0_1.len`.
151
+
152
+ For other useful functionalities (e.g., sharding), please check the argument list in `examples/dump_feature.py`.
153
+
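
As a reference for the `--tsv_path` format described above, here is a minimal sketch of building such a manifest directly in Python. The fairseq `wav2vec_manifest.py` script linked above remains the reference implementation; the paths below are placeholders, and `soundfile` is the package already listed in the extra requirements.

```python
# Minimal sketch of a fairseq-style tsv manifest: root directory on the first
# line, then "<relative path>\t<number of frames>" per utterance.
# Paths here are placeholders.
import os
import soundfile as sf

root = "/dir/to/LibriSpeech/dev-clean"
with open("/dir/to/manifest/dev-clean.tsv", "w") as out:
    print(root, file=out)
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            if not name.endswith(".flac"):
                continue
            path = os.path.join(dirpath, name)
            n_frames = sf.info(path).frames  # number of audio samples
            print(f"{os.path.relpath(path, root)}\t{n_frames}", file=out)
```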
154
+ ### Command Line Usage
155
+
156
+ We expect to have `${feat_dir}/0_1.npy` and `${feat_dir}/0_1.len` in the provided
157
+ directory `/dir/to/representations`.
158
+
159
+ Also, the tsv file should be the **same** as the one used in [Representation Preparation](#representation-preparation).
160
+
161
+ ```
162
+ repcodec /dir/to/representations \
163
+ --model /path/to/repcodec/model \
164
+ --tsv_path /path/to/tsv/file \
165
+ [--model_config_path /path/to/train/config] \
166
+ [--use_gpu] \
167
+ [--out_dir /path/to/output]
168
+ ```
169
+
170
+ If you trained the model yourself following [Training New RepCodec Models](#training-new-repcodec-models),
171
+ please provide the training config file using `--model_config_path`.
172
+ If you use one of the models we provide [here](#repcodec-models), you do not need to provide it.
173
+
174
+ This command will tokenize the representations and the output discrete tokens will be saved to `${out_dir}/tokens`.
175
+ The tokens are in the same order as the provided tsv file.
176
+
177
+ An example of the output file:
178
+
179
+ ```
180
+ /dir/to/LibriSpeech/dev-clean
181
+ 2277/149896/2277-149896-0026.flac 696 696 198 198 198 498 ...
182
+ 2277/149896/2277-149896-0005.flac 696 696 198 198 198 907 ...
183
+ 2277/149896/2277-149896-0033.flac 696 696 198 198 198 696 ...
184
+ ```
185
+
186
+ Under `examples/tokens`, we provide some token files as references. They are obtained from LibriSpeech dev-clean subset
187
+ using the 6 types of representations and corresponding [RepCodec Models](#repcodec-models).
188
+ Your results should be very similar to ours.
189
+
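
If you want to sanity-check your output against those reference token files, a rough sketch is given below. Both files follow the format shown above (a root directory on the first line, then one `<utterance> <token> <token> ...` line per utterance); the reference file name is a placeholder.

```python
# Compare two token files and report the fraction of matching tokens.
# File names are placeholders; "/path/to/output/tokens" is the output of the repcodec command.
def load_tokens(path):
    with open(path) as fp:
        lines = fp.read().strip().split("\n")
    _root, utterances = lines[0], lines[1:]  # first line is the dataset root dir
    return {line.split()[0]: line.split()[1:] for line in utterances}

reference = load_tokens("examples/tokens/reference_dev-clean.tokens")  # placeholder name
produced = load_tokens("/path/to/output/tokens")

matches, total = 0, 0
for utt, ref_tokens in reference.items():
    out_tokens = produced.get(utt, [])
    matches += sum(r == o for r, o in zip(ref_tokens, out_tokens))
    total += len(ref_tokens)
print(f"token match rate: {matches / total:.4f}")
```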
190
+ ### Python Usage
191
+
192
+ ```python
193
+ import torch
194
+ import yaml
195
+
196
+ from repcodec.RepCodec import RepCodec
197
+
198
+ # for feature types of HuBERT base & data2vec base, please use repcodec_dim768.yaml;
199
+ # for feature types of HuBERT large & data2vec large & Whisper medium, please use repcodec_dim1024.yaml;
200
+ # for feature types of Whisper large-v2, please use repcodec_dim1280.yaml
201
+ config = "repcodec/configs/repcodec_dim768.yaml"
202
+ with open(config) as fp:
203
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
204
+
205
+ model = RepCodec(**conf)
206
+ model.load_state_dict(torch.load("./hubert_base_l9.pkl", map_location="cpu")["model"]["repcodec"])
207
+ model.quantizer.initial()
208
+ model.eval()
209
+
210
+ # input shape: (batch size, hidden dim, sequence length)
211
+ random_features = torch.randn(size=(1, 768, 100))
212
+ with torch.no_grad():
213
+ x = model.encoder(random_features)
214
+ z = model.projector(x)
215
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
216
+ tokens = idx.cpu().data.numpy().tolist()[0]
217
+ ```
218
+
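
To connect the snippet above with the feature dumping step, the following sketch tokenizes every utterance in a dumped `0_1.npy`/`0_1.len` pair (the layout described in [Representation Preparation](#representation-preparation)). It reuses the `model` built in the snippet above; the directory path is a placeholder, and the 768-dimensional features are assumed as in that snippet.

```python
# Sketch of tokenizing a dumped feature file with the model built above.
# ${feat_dir}/0_1.npy holds (total_frames, 768) features; 0_1.len holds one
# per-utterance frame count per line. Paths are placeholders.
import numpy as np
import torch

feats = np.load("/dir/to/representations/0_1.npy")
with open("/dir/to/representations/0_1.len") as fp:
    lens = [int(l) for l in fp.read().strip().split("\n")]

all_tokens = []
offset = 0
for n in lens:
    utt = torch.from_numpy(feats[offset:offset + n]).float()  # (T, 768)
    x = utt.T.unsqueeze(0)                                     # (1, 768, T)
    with torch.no_grad():
        z = model.projector(model.encoder(x))
        _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
    all_tokens.append(idx.cpu().data.numpy().tolist()[0])
    offset += n
```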
219
+ ## Training New RepCodec Models
220
+
221
+ We use a config file to set up all the training configurations, e.g., data, model architecture,
222
+ optimizer, scheduler.
223
+ We provide an example [here](./train_configs/ex_dim768_mse.yaml).
224
+
225
+ Please first install required packages following [Installation](#installation)
226
+ and prepare the representations following [Representation Preparation](#representation-preparation).
227
+
228
+ The input data directory is expected to have the following structure
229
+ ```
230
+ /dir/to/representations/
231
+ train_set_name/
232
+ 0_1.npy
233
+ 0_1.len
234
+ valid_set_name/
235
+ 0_1.npy
236
+ 0_1.len
237
+ test_set_name/
238
+ 0_1.npy
239
+ 0_1.len
240
+ ```
241
+
242
+ The names of subsets should be the same as the fields in the config file.
243
+
244
+ Then, you can run training by
245
+ ```
246
+ python train.py \
247
+ -c /path/to/config/file \
248
+ --tag $tag \
249
+ --exp_root exp
250
+ ```
251
+
252
+ `tag` is the name of the output folder.
253
+ All outputs will be saved to `exp_root/tag/`.
254
+
255
+ ## Acknowledgement
256
+
257
+ Our implementation is based on [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec).
258
+ We thank them for open-sourcing their code!
259
+
260
+ ## Citation
261
+
262
+ If you find our work useful, please cite the following article.
263
+
264
+ ```
265
+ @misc{huang2023repcodec,
266
+ title={RepCodec: A Speech Representation Codec for Speech Tokenization},
267
+ author={Zhichao Huang and Chutong Meng and Tom Ko},
268
+ year={2023},
269
+ eprint={2309.00169},
270
+ archivePrefix={arXiv},
271
+ primaryClass={eess.AS}
272
+ }
273
+ ```
prompting/RepCodec/dataloader/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .collater import *
2
+ from .dataset import *
prompting/RepCodec/dataloader/collater.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class ReprCollater(object):
13
+ def __call__(self, batch):
14
+ xs = []
15
+ for b in batch:
16
+ if b is not None:
17
+ xs.append(b)
18
+
19
+ x_batch = np.stack(xs, axis=0)
20
+ x_batch = torch.tensor(x_batch, dtype=torch.float).transpose(1, 2) # (B, T, C) -> (B, C, T)
21
+
22
+ return x_batch
prompting/RepCodec/dataloader/dataset.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import glob
9
+ import logging
10
+ import os
11
+ from typing import List
12
+
13
+ import numpy as np
14
+ from torch.utils.data import Dataset
15
+
16
+ logging.basicConfig(
17
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
18
+ datefmt="%Y-%m-%d %H:%M:%S",
19
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
20
+ )
21
+ logger = logging.getLogger("dataset")
22
+
23
+
24
+ class ReprDataset(Dataset):
25
+ def __init__(
26
+ self,
27
+ data_dir: str,
28
+ batch_len: int,
29
+ ):
30
+ self.batch_len = batch_len
31
+
32
+ self.blocks = self._load_blocks(data_dir)
33
+ self.offsets = self._load_offsets(data_dir)
34
+ assert len(self.blocks) == len(self.offsets)
35
+ # check len
36
+ for i in range(len(self.blocks)):
37
+ assert self.blocks[i].shape[0] == self.offsets[i][-1]
38
+
39
+ self.n_examples = np.cumsum([0] + [offset.shape[0] - 1 for offset in self.offsets])
40
+
41
+ def __len__(self):
42
+ return self.n_examples[-1]
43
+
44
+ def __getitem__(self, idx):
45
+ # find which block
46
+ block_id = -1
47
+ for n in range(len(self.n_examples) - 1):
48
+ if self.n_examples[n] <= idx < self.n_examples[n + 1]:
49
+ block_id = n
50
+ break
51
+ assert 0 <= block_id < len(self.blocks), f"Failed to find {idx}"
52
+ block_offset = idx - self.n_examples[block_id]
53
+ start = self.offsets[block_id][block_offset]
54
+ end = self.offsets[block_id][block_offset + 1]
55
+
56
+ # randomly choose a slice
57
+ if end - start < self.batch_len:
58
+ return None
59
+ elif end - start == self.batch_len:
60
+ return self.blocks[block_id][start:end]
61
+ else:
62
+ start_offset = np.random.randint(low=start, high=end - self.batch_len)
63
+ return self.blocks[block_id][start_offset:start_offset + self.batch_len]
64
+
65
+ @staticmethod
66
+ def _load_blocks(feat_dir: str) -> List[np.ndarray]:
67
+ # e.g., 0_2.npy, 1_2.npy
68
+ file_names = glob.glob(os.path.join(feat_dir, "*.npy"), recursive=False)
69
+ # sort by index
70
+ file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
71
+ logger.info(f"Found following blocks: {file_names}")
72
+ blocks = [np.load(name, mmap_mode="r") for name in file_names]
73
+ return blocks
74
+
75
+ @staticmethod
76
+ def _load_offsets(feat_dir: str):
77
+ def load_lens(file_name: str):
78
+ with open(file_name, mode="r") as fp:
79
+ res = fp.read().strip().split("\n")
80
+ # for easy use. [res[i], res[i+1]) denotes the range for ith element
81
+ res = [0] + [int(r) for r in res]
82
+ return np.cumsum(res, dtype=int)
83
+
84
+ # e.g., 0_2.len, 1_2.len
85
+ file_names = glob.glob(os.path.join(feat_dir, "*.len"), recursive=False)
86
+ file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
87
+ file_lens = []
88
+ for name in file_names:
89
+ file_lens.append(load_lens(name))
90
+ return file_lens
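
`ReprDataset` and `ReprCollater` are designed to be used together; a sketch of wiring them into a standard PyTorch `DataLoader` is shown below. It assumes you run from the RepCodec repo root so the `dataloader` package is importable; the feature directory follows the README layout (`0_1.npy` / `0_1.len`), and `batch_len`/`batch_size` are placeholder values.

```python
# Sketch: feed dumped representations into training batches.
from torch.utils.data import DataLoader

from dataloader import ReprCollater, ReprDataset

dataset = ReprDataset("/dir/to/representations/train_set_name", batch_len=96)
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=ReprCollater(),  # stacks slices and returns (B, C, T) float tensors
)

for batch in loader:
    # Utterances shorter than batch_len are returned as None by the dataset and
    # dropped by the collater, so a batch may contain fewer than 16 items.
    print(batch.shape)  # e.g. torch.Size([16, 1024, 96]) for 1024-dim features
    break
```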
prompting/RepCodec/examples/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,334 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "72bf1b45-66fd-450d-8d5c-bec9e0b3d08f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from data2vec_feature_reader import Data2vecFeatureReader\n",
11
+ "\n",
12
+ "reader = Data2vecFeatureReader(\"./../../models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "84a9d238-048a-4772-a47b-5aadc50f36df",
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "data": {
23
+ "application/vnd.jupyter.widget-view+json": {
24
+ "model_id": "490421d1c2f54cca9855f1a5397185f8",
25
+ "version_major": 2,
26
+ "version_minor": 0
27
+ },
28
+ "text/plain": [
29
+ "Loading dataset shards: 0%| | 0/45 [00:00<?, ?it/s]"
30
+ ]
31
+ },
32
+ "metadata": {},
33
+ "output_type": "display_data"
34
+ },
35
+ {
36
+ "data": {
37
+ "application/vnd.jupyter.widget-view+json": {
38
+ "model_id": "be44942581b34d5388b0264e7b40d472",
39
+ "version_major": 2,
40
+ "version_minor": 0
41
+ },
42
+ "text/plain": [
43
+ "Loading dataset shards: 0%| | 0/60 [00:00<?, ?it/s]"
44
+ ]
45
+ },
46
+ "metadata": {},
47
+ "output_type": "display_data"
48
+ }
49
+ ],
50
+ "source": [
51
+ "from datasets import load_dataset\n",
52
+ "from tqdm import tqdm\n",
53
+ "import pandas as pd\n",
54
+ "\n",
55
+ "cache_dir = \"./../../../cache\"\n",
56
+ "\n",
57
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "id": "cffd49ca-3524-4ac4-8ba5-bc4fcc9e0f53",
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "data": {
68
+ "text/plain": [
69
+ "RepCodec(\n",
70
+ " (encoder): Encoder(\n",
71
+ " (conv): Conv1d(\n",
72
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
73
+ " )\n",
74
+ " (conv_blocks): ModuleList(\n",
75
+ " (0-1): 2 x EncoderBlock(\n",
76
+ " (res_units): ModuleList(\n",
77
+ " (0-1): 2 x ResidualUnit(\n",
78
+ " (activation): ELU(alpha=1.0)\n",
79
+ " (conv1): Conv1d(\n",
80
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
81
+ " )\n",
82
+ " (conv2): Conv1d1x1(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n",
83
+ " )\n",
84
+ " )\n",
85
+ " (conv): Conv1d(\n",
86
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
87
+ " )\n",
88
+ " )\n",
89
+ " )\n",
90
+ " )\n",
91
+ " (decoder): Decoder(\n",
92
+ " (conv1): Conv1d(\n",
93
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
94
+ " )\n",
95
+ " (conv_blocks): ModuleList(\n",
96
+ " (0-1): 2 x DecoderBlock(\n",
97
+ " (conv): Conv1d(\n",
98
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
99
+ " )\n",
100
+ " (res_units): ModuleList(\n",
101
+ " (0-1): 2 x ResidualUnit(\n",
102
+ " (activation): ELU(alpha=1.0)\n",
103
+ " (conv1): Conv1d(\n",
104
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
105
+ " )\n",
106
+ " (conv2): Conv1d1x1(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n",
107
+ " )\n",
108
+ " )\n",
109
+ " )\n",
110
+ " )\n",
111
+ " (conv2): Conv1d(\n",
112
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
113
+ " )\n",
114
+ " )\n",
115
+ " (projector): Projector(\n",
116
+ " (project): Conv1d(\n",
117
+ " (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
118
+ " )\n",
119
+ " )\n",
120
+ " (quantizer): Quantizer(\n",
121
+ " (codebook): ResidualVQ(\n",
122
+ " (layers): ModuleList(\n",
123
+ " (0): VectorQuantize()\n",
124
+ " )\n",
125
+ " )\n",
126
+ " )\n",
127
+ ")"
128
+ ]
129
+ },
130
+ "execution_count": 3,
131
+ "metadata": {},
132
+ "output_type": "execute_result"
133
+ }
134
+ ],
135
+ "source": [
136
+ "from repcodec.RepCodec import RepCodec\n",
137
+ "import torch\n",
138
+ "import yaml\n",
139
+ "\n",
140
+ "config = \"./../repcodec/configs/repcodec_dim1024.yaml\"\n",
141
+ "with open(config) as fp:\n",
142
+ " conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
143
+ "\n",
144
+ "model = RepCodec(**conf)\n",
145
+ "model.load_state_dict(torch.load(\"./../../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
146
+ "model.quantizer.initial()\n",
147
+ "model.eval()"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 22,
153
+ "id": "a9a1731e-052c-4af0-a29c-b171a988b300",
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "ename": "RuntimeError",
158
+ "evalue": "Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same",
159
+ "output_type": "error",
160
+ "traceback": [
161
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
162
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
163
+ "Cell \u001b[0;32mIn[22], line 27\u001b[0m\n\u001b[1;32m 23\u001b[0m feat\u001b[38;5;241m.\u001b[39mappend(feat_chunk)\n\u001b[1;32m 25\u001b[0m features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(feat, \u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 27\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m z \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mprojector(x)\n\u001b[1;32m 29\u001b[0m _, idx \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mquantizer\u001b[38;5;241m.\u001b[39mcodebook\u001b[38;5;241m.\u001b[39mforward_index(z\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m))\n",
164
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
165
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
166
+ "File \u001b[0;32m/jupyter_workspace/users/Darshan/RepCodec/repcodec/modules/encoder.py:86\u001b[0m, in \u001b[0;36mEncoder.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 86\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_blocks):\n\u001b[1;32m 88\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv_blocks[i](x)\n",
167
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
168
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
169
+ "File \u001b[0;32m/jupyter_workspace/users/Darshan/RepCodec/repcodec/layers/conv_layer.py:55\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 49\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124;03m x (Tensor): Float tensor variable with the shape (B, C, T).\u001b[39;00m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;124;03m Returns:\u001b[39;00m\n\u001b[1;32m 53\u001b[0m \u001b[38;5;124;03m Tensor: Float tensor variable with the shape (B, C, T).\u001b[39;00m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 55\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
170
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
171
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
172
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:310\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
173
+ "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:306\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 304\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 305\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 307\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
174
+ "\u001b[0;31mRuntimeError\u001b[0m: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same"
175
+ ]
176
+ }
177
+ ],
178
+ "source": [
179
+ "import torch.nn.functional as F\n",
180
+ "\n",
181
+ "sample = dataset[\"train.clean.100\"][1]\n",
182
+ "\n",
183
+ "x = sample[\"audio\"][\"array\"]\n",
184
+ "\n",
185
+ "with torch.no_grad():\n",
186
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
187
+ " if reader.task.cfg.normalize:\n",
188
+ " x = F.layer_norm(x, x.shape)\n",
189
+ " x = x.view(1, -1)\n",
190
+ "\n",
191
+ " feat = []\n",
192
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
193
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
194
+ " res = reader.model.extract_features(\n",
195
+ " source=x_chunk,\n",
196
+ " padding_mask=None,\n",
197
+ " mask=False,\n",
198
+ " layer=reader.layer,\n",
199
+ " )\n",
200
+ " feat_chunk = res[\"x\"]\n",
201
+ " feat.append(feat_chunk)\n",
202
+ " \n",
203
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
204
+ "\n",
205
+ " x = model.encoder(features)\n",
206
+ " z = model.projector(x)\n",
207
+ " _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
208
+ " tokens = idx.cpu().data.numpy().tolist()[0]"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 14,
214
+ "id": "d51709a9-6fb3-450b-a517-005367095663",
215
+ "metadata": {},
216
+ "outputs": [
217
+ {
218
+ "data": {
219
+ "text/plain": [
220
+ "torch.Size([1, 804, 1024])"
221
+ ]
222
+ },
223
+ "execution_count": 14,
224
+ "metadata": {},
225
+ "output_type": "execute_result"
226
+ }
227
+ ],
228
+ "source": [
229
+ "features.shape"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": 8,
235
+ "id": "dfc977d7-f27c-40d7-b545-fbdf26728cbe",
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "data": {
240
+ "text/plain": [
241
+ "torch.Size([726, 1024])"
242
+ ]
243
+ },
244
+ "execution_count": 8,
245
+ "metadata": {},
246
+ "output_type": "execute_result"
247
+ }
248
+ ],
249
+ "source": [
250
+ "feat.shape"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "id": "1810e6dc-2ece-4aca-a29a-e1933b8ce82a",
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "import logging\n",
261
+ "import os\n",
262
+ "import sys\n",
263
+ "\n",
264
+ "import tqdm\n",
265
+ "from npy_append_array import NpyAppendArray\n",
266
+ "\n",
267
+ "def get_shard_range(tot, nshard, rank):\n",
268
+ " assert rank < nshard and rank >= 0, f\"invaid rank/nshard {rank}/{nshard}\"\n",
269
+ " start = round(tot / nshard * rank)\n",
270
+ " end = round(tot / nshard * (rank + 1))\n",
271
+ " assert start < end, f\"start={start}, end={end}\"\n",
272
+ " logger.info(\n",
273
+ " f\"rank {rank} of {nshard}, process {end-start} \"\n",
274
+ " f\"({start}-{end}) out of {tot}\"\n",
275
+ " )\n",
276
+ " return start, end\n",
277
+ "\n",
278
+ "def get_path_iterator(tsv, nshard, rank):\n",
279
+ " with open(tsv, \"r\") as f:\n",
280
+ " root = f.readline().rstrip()\n",
281
+ " lines = [line.rstrip() for line in f]\n",
282
+ " start, end = get_shard_range(len(lines), nshard, rank)\n",
283
+ " lines = lines[start:end]\n",
284
+ " def iterate():\n",
285
+ " for line in lines:\n",
286
+ " subpath, nsample = line.split(\"\\t\")\n",
287
+ " yield f\"{root}/{subpath}\", int(nsample)\n",
288
+ " return iterate, len(lines)\n",
289
+ "\n",
290
+ "def dump_feature(reader, generator, num, nshard, rank, feat_dir):\n",
291
+ " iterator = generator()\n",
292
+ "\n",
293
+ " feat_path = f\"{feat_dir}/{rank}_{nshard}.npy\"\n",
294
+ " leng_path = f\"{feat_dir}/{rank}_{nshard}.len\"\n",
295
+ "\n",
296
+ " os.makedirs(feat_dir, exist_ok=True)\n",
297
+ " if os.path.exists(feat_path):\n",
298
+ " os.remove(feat_path)\n",
299
+ "\n",
300
+ " feat_f = NpyAppendArray(feat_path)\n",
301
+ " with open(leng_path, \"w\") as leng_f:\n",
302
+ " for path, nsample in tqdm.tqdm(iterator, total=num):\n",
303
+ " feat = reader.get_feats(path, nsample)\n",
304
+ " feat_f.append(feat.cpu().numpy())\n",
305
+ " leng_f.write(f\"{len(feat)}\\n\")\n",
306
+ " logger.info(\"finished successfully\")\n",
307
+ "\n",
308
+ "generator, num = get_path_iterator(tsv_path, nshard, rank)\n",
309
+ "dump_feature(reader, generator, num, nshard, rank, feat_dir)"
310
+ ]
311
+ }
312
+ ],
313
+ "metadata": {
314
+ "kernelspec": {
315
+ "display_name": "Python 3 (ipykernel)",
316
+ "language": "python",
317
+ "name": "python3"
318
+ },
319
+ "language_info": {
320
+ "codemirror_mode": {
321
+ "name": "ipython",
322
+ "version": 3
323
+ },
324
+ "file_extension": ".py",
325
+ "mimetype": "text/x-python",
326
+ "name": "python",
327
+ "nbconvert_exporter": "python",
328
+ "pygments_lexer": "ipython3",
329
+ "version": "3.8.10"
330
+ }
331
+ },
332
+ "nbformat": 4,
333
+ "nbformat_minor": 5
334
+ }
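The RuntimeError recorded in this notebook (cell In[22]) is a device mismatch: the data2vec features are produced on reader.device ("cuda:0"), but the loading cell never moves the RepCodec weights off the CPU. A minimal sketch of the corrected loading cell, mirroring some_run.py from this same upload (config and checkpoint paths are taken from the notebook, not re-verified here):

    from repcodec.RepCodec import RepCodec
    import torch
    import yaml

    config = "./../repcodec/configs/repcodec_dim1024.yaml"
    with open(config) as fp:
        conf = yaml.load(fp, Loader=yaml.FullLoader)

    model = RepCodec(**conf)
    model.load_state_dict(
        torch.load("./../../models/data2vec_large_l18.pkl", map_location="cuda:0")["model"]["repcodec"]
    )
    model.quantizer.initial()
    model.eval()
    model.to("cuda:0")  # keep encoder/projector/quantizer on the same device as the incoming features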
prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_audio-checkpoint.py ADDED
@@ -0,0 +1,541 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
9
+
10
+ import logging
11
+ import math
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+
15
+ from omegaconf import II
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.distributed as dist
21
+
22
+ from fairseq.modules import EMAModule, EMAModuleConfig
23
+ from fairseq.data.data_utils import compute_mask_indices
24
+ from fairseq.models import BaseFairseqModel, register_model
25
+ from fairseq.models.wav2vec import (
26
+ ConvFeatureExtractionModel,
27
+ Wav2Vec2Config,
28
+ TransformerEncoder,
29
+ )
30
+ from fairseq.modules import (
31
+ GradMultiply,
32
+ LayerNorm,
33
+ )
34
+ from fairseq.utils import index_put
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ @dataclass
41
+ class Data2VecAudioConfig(Wav2Vec2Config):
42
+
43
+ loss_beta: float = field(
44
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
45
+ )
46
+ loss_scale: Optional[float] = field(
47
+ default=None,
48
+ metadata={
49
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
50
+ },
51
+ )
52
+ average_top_k_layers: int = field(
53
+ default=8, metadata={"help": "how many layers to average"}
54
+ )
55
+
56
+ layer_norm_target_layer: bool = False
57
+ instance_norm_target_layer: bool = False
58
+ instance_norm_targets: bool = False
59
+ layer_norm_targets: bool = False
60
+ batch_norm_target_layer: bool = False
61
+ group_norm_target_layer: bool = False
62
+
63
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
64
+ ema_end_decay: float = field(
65
+ default=0.9999, metadata={"help": "final ema decay rate"}
66
+ )
67
+
68
+ # when to finish annealing ema decay rate
69
+ ema_anneal_end_step: int = II("optimization.max_update")
70
+
71
+ ema_transformer_only: bool = field(
72
+ default=True,
73
+ metadata={"help": "whether to momentum update only the transformer"},
74
+ )
75
+ ema_layers_only: bool = field(
76
+ default=True,
77
+ metadata={"help": "whether to momentum update only the transformer layers"},
78
+ )
79
+
80
+ max_update: int = II("optimization.max_update")
81
+
82
+ min_target_var: float = field(
83
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
84
+ )
85
+ min_pred_var: float = field(
86
+ default=0.01,
87
+ metadata={"help": "stop training if prediction var falls below this"},
88
+ )
89
+
90
+
91
+ def get_annealed_rate(start, end, curr_step, total_steps):
92
+ r = end - start
93
+ pct_remaining = 1 - curr_step / total_steps
94
+ return end - r * pct_remaining
95
+
96
+
97
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
98
+ class Data2VecAudioModel(BaseFairseqModel):
99
+ def __init__(self, cfg: Data2VecAudioConfig):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ feature_enc_layers = eval(cfg.conv_feature_layers)
104
+ self.extractor_embed = feature_enc_layers[-1][0]
105
+
106
+ self.ema = None
107
+ self.embed = cfg.encoder_embed_dim
108
+
109
+ self.average_top_k_layers = cfg.average_top_k_layers
110
+ self.loss_beta = cfg.loss_beta
111
+ self.loss_scale = cfg.loss_scale
112
+
113
+ self.feature_extractor = ConvFeatureExtractionModel(
114
+ conv_layers=feature_enc_layers,
115
+ dropout=0.0,
116
+ mode=cfg.extractor_mode,
117
+ conv_bias=cfg.conv_bias,
118
+ )
119
+
120
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
121
+
122
+ self.mask_prob = cfg.mask_prob
123
+ self.mask_selection = cfg.mask_selection
124
+ self.mask_other = cfg.mask_other
125
+ self.mask_length = cfg.mask_length
126
+ self.no_mask_overlap = cfg.no_mask_overlap
127
+ self.mask_min_space = cfg.mask_min_space
128
+
129
+ self.mask_channel_prob = cfg.mask_channel_prob
130
+ self.mask_channel_before = cfg.mask_channel_before
131
+ self.mask_channel_selection = cfg.mask_channel_selection
132
+ self.mask_channel_other = cfg.mask_channel_other
133
+ self.mask_channel_length = cfg.mask_channel_length
134
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
135
+ self.mask_channel_min_space = cfg.mask_channel_min_space
136
+
137
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
138
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
139
+
140
+ self.feature_grad_mult = cfg.feature_grad_mult
141
+
142
+ self.mask_emb = nn.Parameter(
143
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
144
+ )
145
+
146
+ self.encoder = TransformerEncoder(cfg)
147
+ self.layer_norm = LayerNorm(self.extractor_embed)
148
+
149
+ self.final_proj = nn.Linear(self.embed, self.embed)
150
+
151
+ self.num_updates = 0
152
+
153
+ def make_ema_teacher(self):
154
+ ema_config = EMAModuleConfig(
155
+ ema_decay=self.cfg.ema_decay,
156
+ ema_fp32=True,
157
+ )
158
+ skip_keys = set()
159
+ if self.cfg.ema_layers_only:
160
+ self.cfg.ema_transformer_only = True
161
+ for k, _ in self.encoder.pos_conv.named_parameters():
162
+ skip_keys.add(f"pos_conv.{k}")
163
+
164
+ self.ema = EMAModule(
165
+ self.encoder if self.cfg.ema_transformer_only else self,
166
+ ema_config,
167
+ skip_keys=skip_keys,
168
+ )
169
+
170
+ def set_num_updates(self, num_updates):
171
+ super().set_num_updates(num_updates)
172
+
173
+ if self.ema is None and self.final_proj is not None:
174
+ logger.info(f"making ema teacher")
175
+ self.make_ema_teacher()
176
+ elif self.training and self.ema is not None:
177
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
178
+ if num_updates >= self.cfg.ema_anneal_end_step:
179
+ decay = self.cfg.ema_end_decay
180
+ else:
181
+ decay = get_annealed_rate(
182
+ self.cfg.ema_decay,
183
+ self.cfg.ema_end_decay,
184
+ num_updates,
185
+ self.cfg.ema_anneal_end_step,
186
+ )
187
+ self.ema.set_decay(decay)
188
+ if self.ema.get_decay() < 1:
189
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
190
+
191
+ self.num_updates = num_updates
192
+
193
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
194
+ state = super().state_dict(destination, prefix, keep_vars)
195
+
196
+ if self.ema is not None:
197
+ state[prefix + "_ema"] = self.ema.fp32_params
198
+
199
+ return state
200
+
201
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
202
+ if self.ema is not None:
203
+ k = prefix + "_ema"
204
+ assert k in state_dict
205
+ self.ema.restore(state_dict[k], True)
206
+ del state_dict[k]
207
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
208
+
209
+ @classmethod
210
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
211
+ """Build a new model instance."""
212
+
213
+ return cls(cfg)
214
+
215
+ def apply_mask(
216
+ self,
217
+ x,
218
+ padding_mask,
219
+ mask_indices=None,
220
+ mask_channel_indices=None,
221
+ ):
222
+ B, T, C = x.shape
223
+
224
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
225
+ mask_channel_indices = compute_mask_indices(
226
+ (B, C),
227
+ None,
228
+ self.mask_channel_prob,
229
+ self.mask_channel_length,
230
+ self.mask_channel_selection,
231
+ self.mask_channel_other,
232
+ no_overlap=self.no_mask_channel_overlap,
233
+ min_space=self.mask_channel_min_space,
234
+ )
235
+ mask_channel_indices = (
236
+ torch.from_numpy(mask_channel_indices)
237
+ .to(x.device)
238
+ .unsqueeze(1)
239
+ .expand(-1, T, -1)
240
+ )
241
+ x[mask_channel_indices] = 0
242
+
243
+ if self.mask_prob > 0:
244
+ if mask_indices is None:
245
+ mask_indices = compute_mask_indices(
246
+ (B, T),
247
+ padding_mask,
248
+ self.mask_prob,
249
+ self.mask_length,
250
+ self.mask_selection,
251
+ self.mask_other,
252
+ min_masks=1,
253
+ no_overlap=self.no_mask_overlap,
254
+ min_space=self.mask_min_space,
255
+ require_same_masks=self.cfg.require_same_masks,
256
+ mask_dropout=self.cfg.mask_dropout,
257
+ )
258
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
259
+ x = index_put(x, mask_indices, self.mask_emb)
260
+ else:
261
+ mask_indices = None
262
+
263
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
264
+ if mask_channel_indices is None:
265
+ mask_channel_indices = compute_mask_indices(
266
+ (B, C),
267
+ None,
268
+ self.mask_channel_prob,
269
+ self.mask_channel_length,
270
+ self.mask_channel_selection,
271
+ self.mask_channel_other,
272
+ no_overlap=self.no_mask_channel_overlap,
273
+ min_space=self.mask_channel_min_space,
274
+ )
275
+ mask_channel_indices = (
276
+ torch.from_numpy(mask_channel_indices)
277
+ .to(x.device)
278
+ .unsqueeze(1)
279
+ .expand(-1, T, -1)
280
+ )
281
+ x = index_put(x, mask_channel_indices, 0)
282
+
283
+ return x, mask_indices
284
+
285
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
286
+ """
287
+ Computes the output length of the convolutional layers
288
+ """
289
+
290
+ def _conv_out_length(input_length, kernel_size, stride):
291
+ return torch.floor((input_length - kernel_size) / stride + 1)
292
+
293
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
294
+
295
+ for i in range(len(conv_cfg_list)):
296
+ input_lengths = _conv_out_length(
297
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
298
+ )
299
+
300
+ return input_lengths.to(torch.long)
301
+
302
+ def forward(
303
+ self,
304
+ source,
305
+ padding_mask=None,
306
+ mask=True,
307
+ features_only=False,
308
+ layer=None,
309
+ mask_indices=None,
310
+ mask_channel_indices=None,
311
+ padding_count=None,
312
+ ):
313
+ features = source
314
+
315
+ if self.feature_grad_mult > 0:
316
+ features = self.feature_extractor(features)
317
+ if self.feature_grad_mult != 1.0:
318
+ features = GradMultiply.apply(features, self.feature_grad_mult)
319
+ else:
320
+ with torch.no_grad():
321
+ features = self.feature_extractor(features)
322
+
323
+ features = features.transpose(1, 2)
324
+
325
+ features = self.layer_norm(features)
326
+
327
+ orig_padding_mask = padding_mask
328
+
329
+ if padding_mask is not None and padding_mask.any():
330
+ input_lengths = (1 - padding_mask.long()).sum(-1)
331
+ # apply conv formula to get real output_lengths
332
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
333
+
334
+ padding_mask = torch.zeros(
335
+ features.shape[:2], dtype=features.dtype, device=features.device
336
+ )
337
+
338
+ # these two operations makes sure that all values
339
+ # before the output lengths indices are attended to
340
+ padding_mask[
341
+ (
342
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
343
+ output_lengths - 1,
344
+ )
345
+ ] = 1
346
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
347
+ else:
348
+ padding_mask = None
349
+
350
+ if self.post_extract_proj is not None:
351
+ features = self.post_extract_proj(features)
352
+
353
+ pre_encoder_features = None
354
+ if self.cfg.ema_transformer_only:
355
+ pre_encoder_features = features.clone()
356
+
357
+ features = self.dropout_input(features)
358
+
359
+ if mask:
360
+ x, mask_indices = self.apply_mask(
361
+ features,
362
+ padding_mask,
363
+ mask_indices=mask_indices,
364
+ mask_channel_indices=mask_channel_indices,
365
+ )
366
+ else:
367
+ x = features
368
+ mask_indices = None
369
+
370
+ x, layer_results = self.encoder(
371
+ x,
372
+ padding_mask=padding_mask,
373
+ layer=layer,
374
+ )
375
+
376
+ if features_only:
377
+ return {
378
+ "x": x,
379
+ "padding_mask": padding_mask,
380
+ "layer_results": layer_results,
381
+ }
382
+
383
+ result = {
384
+ "losses": {},
385
+ }
386
+
387
+ with torch.no_grad():
388
+ self.ema.model.eval()
389
+
390
+ if self.cfg.ema_transformer_only:
391
+ y, layer_results = self.ema.model.extract_features(
392
+ pre_encoder_features,
393
+ padding_mask=padding_mask,
394
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
395
+ )
396
+ y = {
397
+ "x": y,
398
+ "padding_mask": padding_mask,
399
+ "layer_results": layer_results,
400
+ }
401
+ else:
402
+ y = self.ema.model.extract_features(
403
+ source=source,
404
+ padding_mask=orig_padding_mask,
405
+ mask=False,
406
+ )
407
+
408
+ target_layer_results = [l[2] for l in y["layer_results"]]
409
+
410
+ permuted = False
411
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
412
+ target_layer_results = [
413
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
414
+ ]
415
+ permuted = True
416
+
417
+ if self.cfg.batch_norm_target_layer:
418
+ target_layer_results = [
419
+ F.batch_norm(
420
+ tl.float(), running_mean=None, running_var=None, training=True
421
+ )
422
+ for tl in target_layer_results
423
+ ]
424
+
425
+ if self.cfg.instance_norm_target_layer:
426
+ target_layer_results = [
427
+ F.instance_norm(tl.float()) for tl in target_layer_results
428
+ ]
429
+
430
+ if permuted:
431
+ target_layer_results = [
432
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
433
+ ]
434
+
435
+ if self.cfg.group_norm_target_layer:
436
+ target_layer_results = [
437
+ F.layer_norm(tl.float(), tl.shape[-2:])
438
+ for tl in target_layer_results
439
+ ]
440
+
441
+ if self.cfg.layer_norm_target_layer:
442
+ target_layer_results = [
443
+ F.layer_norm(tl.float(), tl.shape[-1:])
444
+ for tl in target_layer_results
445
+ ]
446
+
447
+ y = sum(target_layer_results) / len(target_layer_results)
448
+
449
+ if self.cfg.layer_norm_targets:
450
+ y = F.layer_norm(y.float(), y.shape[-1:])
451
+
452
+ if self.cfg.instance_norm_targets:
453
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
454
+
455
+ if not permuted:
456
+ y = y.transpose(0, 1)
457
+
458
+ y = y[mask_indices]
459
+
460
+ x = x[mask_indices]
461
+ x = self.final_proj(x)
462
+
463
+ sz = x.size(-1)
464
+
465
+ if self.loss_beta == 0:
466
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
467
+ else:
468
+ loss = F.smooth_l1_loss(
469
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
470
+ ).sum(dim=-1)
471
+
472
+ if self.loss_scale is not None:
473
+ scale = self.loss_scale
474
+ else:
475
+ scale = 1 / math.sqrt(sz)
476
+
477
+ result["losses"]["regression"] = loss.sum() * scale
478
+
479
+ if "sample_size" not in result:
480
+ result["sample_size"] = loss.numel()
481
+
482
+ with torch.no_grad():
483
+ result["target_var"] = self.compute_var(y)
484
+ result["pred_var"] = self.compute_var(x.float())
485
+
486
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
487
+ logger.error(
488
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
489
+ )
490
+ raise Exception(
491
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
492
+ )
493
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
494
+ logger.error(
495
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
496
+ )
497
+ raise Exception(
498
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
499
+ )
500
+
501
+ if self.ema is not None:
502
+ result["ema_decay"] = self.ema.get_decay() * 1000
503
+
504
+ return result
505
+
506
+ @staticmethod
507
+ def compute_var(y):
508
+ y = y.view(-1, y.size(-1))
509
+ if dist.is_initialized():
510
+ zc = torch.tensor(y.size(0)).cuda()
511
+ zs = y.sum(dim=0)
512
+ zss = (y ** 2).sum(dim=0)
513
+
514
+ dist.all_reduce(zc)
515
+ dist.all_reduce(zs)
516
+ dist.all_reduce(zss)
517
+
518
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
519
+ return torch.sqrt(var + 1e-6).mean()
520
+ else:
521
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
522
+
523
+ def extract_features(
524
+ self, source, padding_mask, mask=False, layer=None
525
+ ):
526
+ res = self.forward(
527
+ source,
528
+ padding_mask,
529
+ mask=mask,
530
+ features_only=True,
531
+ layer=layer,
532
+ )
533
+ return res
534
+
535
+ def remove_pretraining_modules(self, last_layer=None):
536
+ self.final_proj = None
537
+ self.ema = None
538
+ if last_layer is not None:
539
+ self.encoder.layers = nn.ModuleList(
540
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
541
+ )
prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_feature_reader-checkpoint.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import tasks
13
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+ from omegaconf import OmegaConf
16
+
17
+ from data2vec_audio import Data2VecAudioModel
18
+
19
+ logger = logging.getLogger("dump_feature")
20
+
21
+
22
+ class Data2vecFeatureReader(object):
23
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
24
+ state = load_checkpoint_to_cpu(ckpt_path)
25
+ cfg = state["cfg"]
26
+ # load task
27
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
28
+ task.load_state_dict(state["task_state"])
29
+ # load model config
30
+ if "layer_type" not in cfg.model:
31
+ # fix a missing key
32
+ model_config = {k: v for k, v in cfg.model.items()}
33
+ model_config["layer_type"] = "transformer"
34
+ model_config = OmegaConf.create(model_config)
35
+ else:
36
+ model_config = cfg.model
37
+
38
+ # fix param name in the state
39
+ state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
40
+ state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
41
+ del state["model"]["_ema"]
42
+
43
+ # load model
44
+ model = Data2VecAudioModel.build_model(model_config)
45
+ model.load_state_dict(
46
+ state["model"], strict=True, model_cfg=model_config
47
+ )
48
+
49
+ self.device = device
50
+ logger.info(f"device = {self.device}")
51
+
52
+ self.model = model.eval().to(self.device)
53
+ self.task = task
54
+ self.layer = layer - 1  # the layer argument is 1-based; store it as a 0-based index
55
+ self.max_chunk = max_chunk
56
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
57
+ logger.info(f" max_chunk = {self.max_chunk}")
58
+
59
+ def read_audio(self, path, ref_len=None):
60
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
61
+ if wav.ndim == 2:
62
+ wav = wav.mean(-1)
63
+ assert wav.ndim == 1, wav.ndim
64
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
65
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
66
+ return wav
67
+
68
+ def get_feats(self, path, ref_len=None):
69
+ x = self.read_audio(path, ref_len=ref_len)
70
+ with torch.no_grad():
71
+ x = torch.from_numpy(x).float().to(self.device)
72
+ if self.task.cfg.normalize:
73
+ x = F.layer_norm(x, x.shape)
74
+ x = x.view(1, -1)
75
+
76
+ feat = []
77
+ for start in range(0, x.size(1), self.max_chunk):
78
+ x_chunk = x[:, start: start + self.max_chunk]
79
+ res = self.model.extract_features(
80
+ source=x_chunk,
81
+ padding_mask=None,
82
+ mask=False,
83
+ layer=self.layer,
84
+ )
85
+ feat_chunk = res["x"]
86
+ feat.append(feat_chunk)
87
+ return torch.cat(feat, 1).squeeze(0)
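For reference, a hedged sketch of how this reader is used elsewhere in this upload (the checkpoint path, layer, and chunk size match the notebooks; the audio path is a placeholder):

    from data2vec_feature_reader import Data2vecFeatureReader

    reader = Data2vecFeatureReader("./../../models/vox_pretrained.pt", 18, device="cuda:0", max_chunk=1600000)
    feat = reader.get_feats("/path/to/utterance.flac")  # placeholder path; yields a (T, 1024) tensor for the large model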
prompting/RepCodec/examples/.ipynb_checkpoints/dump_feature-checkpoint.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+
12
+ from feature_utils import get_path_iterator, dump_feature
13
+
14
+ logging.basicConfig(
15
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
16
+ datefmt="%Y-%m-%d %H:%M:%S",
17
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
18
+ stream=sys.stdout,
19
+ )
20
+ logger = logging.getLogger("dump_feature")
21
+
22
+
23
+ def main(
24
+ model_type: str,
25
+ tsv_path: str,
26
+ ckpt_path: str,
27
+ whisper_root: str,
28
+ whisper_name: str,
29
+ layer: int,
30
+ nshard: int,
31
+ rank: int,
32
+ feat_dir: str,
33
+ max_chunk: int,
34
+ use_cpu: bool = False
35
+ ):
36
+ device = "cpu" if use_cpu else "cuda"
37
+
38
+ # some checks
39
+ if model_type in ["hubert", "data2vec"]:
40
+ assert ckpt_path and os.path.exists(ckpt_path)
41
+ elif model_type in ["whisper"]:
42
+ assert whisper_name and whisper_root
43
+ else:
44
+ raise ValueError(f"Unsupported model type {model_type}")
45
+
46
+ reader = None
47
+ if model_type == "hubert":
48
+ from hubert_feature_reader import HubertFeatureReader
49
+ reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
50
+ elif model_type == "data2vec":
51
+ from data2vec_feature_reader import Data2vecFeatureReader
52
+ reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
53
+ elif model_type == "whisper":
54
+ from whisper_feature_reader import WhisperFeatureReader
55
+ reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)
56
+
57
+ assert reader is not None
58
+
59
+ generator, num = get_path_iterator(tsv_path, nshard, rank)
60
+ dump_feature(reader, generator, num, nshard, rank, feat_dir)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ import argparse
65
+
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument(
68
+ "--model_type",
69
+ required=True,
70
+ type=str,
71
+ choices=["data2vec", "hubert", "whisper"],
72
+ help="the type of the speech encoder."
73
+ )
74
+ parser.add_argument(
75
+ "--tsv_path",
76
+ required=True,
77
+ type=str,
78
+ help="the path to the tsv file."
79
+ )
80
+ parser.add_argument(
81
+ "--ckpt_path",
82
+ required=False,
83
+ type=str,
84
+ default=None,
85
+ help="path to the speech model. must provide for HuBERT and data2vec"
86
+ )
87
+ parser.add_argument(
88
+ "--whisper_root",
89
+ required=False,
90
+ type=str,
91
+ default=None,
92
+ help="root dir to download/store whisper model. must provide for whisper model."
93
+ )
94
+ parser.add_argument(
95
+ "--whisper_name",
96
+ required=False,
97
+ type=str,
98
+ default=None,
99
+ help="name of whisper model. e.g., large-v2. must provide for whisper model."
100
+ )
101
+ parser.add_argument(
102
+ "--layer",
103
+ required=True,
104
+ type=int,
105
+ help="which layer of the model. this is 1-based."
106
+ )
107
+ parser.add_argument(
108
+ "--feat_dir",
109
+ required=True,
110
+ type=str,
111
+ help="the output dir to save the representations."
112
+ )
113
+ parser.add_argument(
114
+ "--nshard",
115
+ required=False,
116
+ type=int,
117
+ default=1,
118
+ help="total number of shards."
119
+ )
120
+ parser.add_argument(
121
+ "--rank",
122
+ required=False,
123
+ type=int,
124
+ default=0,
125
+ help="shard id of this process."
126
+ )
127
+ parser.add_argument(
128
+ "--max_chunk",
129
+ type=int,
130
+ default=1600000,
131
+ help="max number of frames of each batch."
132
+ )
133
+ parser.add_argument(
134
+ "--use_cpu",
135
+ default=False,
136
+ action="store_true",
137
+ help="whether use cpu instead of gpu."
138
+ )
139
+ args = parser.parse_args()
140
+ logger.info(args)
141
+
142
+ main(**vars(args))
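Since main() takes plain keyword arguments, the same feature dump can be driven from Python instead of the CLI; a hedged sketch with placeholder paths (only the checkpoint path and layer come from this upload):

    from dump_feature import main

    main(
        model_type="data2vec",
        tsv_path="/path/to/train.tsv",               # placeholder manifest
        ckpt_path="./../../models/vox_pretrained.pt",
        whisper_root=None,
        whisper_name=None,
        layer=18,                                    # 1-based layer index, as in the readers above
        nshard=1,
        rank=0,
        feat_dir="/path/to/features",                # placeholder output directory
        max_chunk=1600000,
        use_cpu=False,
    )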
prompting/RepCodec/examples/.ipynb_checkpoints/feature_utils-checkpoint.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py
9
+
10
+ import logging
11
+ import os
12
+ import sys
13
+
14
+ import tqdm
15
+ from npy_append_array import NpyAppendArray
16
+
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("feature_utils")
25
+
26
+
27
+ def get_shard_range(tot, nshard, rank):
28
+ assert rank < nshard and rank >= 0, f"invalid rank/nshard {rank}/{nshard}"
29
+ start = round(tot / nshard * rank)
30
+ end = round(tot / nshard * (rank + 1))
31
+ assert start < end, f"start={start}, end={end}"
32
+ logger.info(
33
+ f"rank {rank} of {nshard}, process {end-start} "
34
+ f"({start}-{end}) out of {tot}"
35
+ )
36
+ return start, end
37
+
38
+
39
+ def get_path_iterator(tsv, nshard, rank):
40
+ with open(tsv, "r") as f:
41
+ root = f.readline().rstrip()
42
+ lines = [line.rstrip() for line in f]
43
+ start, end = get_shard_range(len(lines), nshard, rank)
44
+ lines = lines[start:end]
45
+ def iterate():
46
+ for line in lines:
47
+ subpath, nsample = line.split("\t")
48
+ yield f"{subpath}", int(nsample)
49
+ return iterate, len(lines)
50
+
51
+
52
+ def dump_feature(reader, generator, num, nshard, rank, feat_dir):
53
+ iterator = generator()
54
+
55
+ feat_path = f"{feat_dir}/{rank}_{nshard}.npy"
56
+ leng_path = f"{feat_dir}/{rank}_{nshard}.len"
57
+
58
+ os.makedirs(feat_dir, exist_ok=True)
59
+ if os.path.exists(feat_path):
60
+ os.remove(feat_path)
61
+
62
+ feat_f = NpyAppendArray(feat_path)
63
+ with open(leng_path, "w") as leng_f:
64
+ for path, nsample in tqdm.tqdm(iterator, total=num):
65
+ feat = reader.get_feats(path, nsample)
66
+ feat_f.append(feat.cpu().numpy())
67
+ leng_f.write(f"{len(feat)}\n")
68
+ logger.info("finished successfully")
69
+
70
+
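get_path_iterator reads a fairseq-style manifest: the first line is a root directory, and every following line is a tab-separated subpath and sample count. Note that, unlike the notebook copies of this helper elsewhere in this upload, this version yields the subpath without prepending the root, so the listed paths must resolve on their own. A hedged illustration of the layout (paths and counts are placeholders only):

    /data/LibriSpeech/train-clean-100
    19/198/19-198-0000.flac	236160
    19/198/19-198-0001.flac	161280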
prompting/RepCodec/examples/.ipynb_checkpoints/some_run-checkpoint.py ADDED
@@ -0,0 +1,66 @@
1
+ from datasets import load_dataset
2
+ from tqdm import tqdm
3
+ import pandas as pd
4
+
5
+ cache_dir = "./../../../cache"
6
+
7
+ dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
8
+
9
+ from repcodec.RepCodec import RepCodec
10
+ import torch
11
+ import yaml
12
+
13
+ config = "./../repcodec/configs/repcodec_dim1024.yaml"
14
+ with open(config) as fp:
15
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
16
+
17
+ model = RepCodec(**conf)
18
+ model.load_state_dict(torch.load("./../../models/data2vec_large_l18.pkl", map_location="cuda:0")["model"]["repcodec"])
19
+ model.quantizer.initial()
20
+ model.eval()
21
+ model.to("cuda:0")
22
+
23
+ from data2vec_feature_reader import Data2vecFeatureReader
24
+
25
+ reader = Data2vecFeatureReader("./../../models/vox_pretrained.pt", 18, device="cuda:0", max_chunk=1600000)
26
+
27
+ import torch.nn.functional as F
28
+ import numpy as np
29
+
30
+ for split in dataset.keys():
31
+
32
+ tokens = []
33
+
34
+ for idx in tqdm(range(len(dataset[split]))):
35
+
36
+ sample = dataset[split][idx]
37
+
38
+ x = sample["audio"]["array"]
39
+
40
+ with torch.no_grad():
41
+ x = torch.from_numpy(x).float().to(reader.device)
42
+ if reader.task.cfg.normalize:
43
+ x = F.layer_norm(x, x.shape)
44
+ x = x.view(1, -1)
45
+
46
+ feat = []
47
+ for start in range(0, x.size(1), reader.max_chunk):
48
+ x_chunk = x[:, start: start + reader.max_chunk]
49
+ res = reader.model.extract_features(
50
+ source=x_chunk,
51
+ padding_mask=None,
52
+ mask=False,
53
+ layer=reader.layer,
54
+ )
55
+ feat_chunk = res["x"]
56
+ feat.append(feat_chunk)
57
+
58
+ features = torch.cat(feat, 1).permute(0, 2, 1)
59
+
60
+ x = model.encoder(features)
61
+ z = model.projector(x)
62
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
63
+ tkn = idx.detach().cpu().data.numpy()[0]
64
+
65
+ tokens.append(tkn)
66
+ np.savez(f"./tkns/{split}.npz", *tokens)
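Because np.savez receives the token arrays positionally, each utterance is stored under the default keys arr_0, arr_1, ...; a minimal sketch of reading one split back in its original order (the file name matches the tkns/ files included in this upload):

    import numpy as np

    data = np.load("./tkns/train.clean.100.npz")
    tokens = [data[f"arr_{i}"] for i in range(len(data.files))]  # one token array per utterance
    print(len(tokens), tokens[0][:10])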
prompting/RepCodec/examples/Untitled.ipynb ADDED
@@ -0,0 +1,214 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "72bf1b45-66fd-450d-8d5c-bec9e0b3d08f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from data2vec_feature_reader import Data2vecFeatureReader\n",
11
+ "\n",
12
+ "reader = Data2vecFeatureReader(\"./../../models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "84a9d238-048a-4772-a47b-5aadc50f36df",
19
+ "metadata": {},
20
+ "outputs": [
21
+ {
22
+ "data": {
23
+ "application/vnd.jupyter.widget-view+json": {
24
+ "model_id": "fb01bc434d964db08fde7f9f2c90ea3c",
25
+ "version_major": 2,
26
+ "version_minor": 0
27
+ },
28
+ "text/plain": [
29
+ "Loading dataset shards: 0%| | 0/45 [00:00<?, ?it/s]"
30
+ ]
31
+ },
32
+ "metadata": {},
33
+ "output_type": "display_data"
34
+ },
35
+ {
36
+ "data": {
37
+ "application/vnd.jupyter.widget-view+json": {
38
+ "model_id": "d4adc62013644ed0b16056aa217448a9",
39
+ "version_major": 2,
40
+ "version_minor": 0
41
+ },
42
+ "text/plain": [
43
+ "Loading dataset shards: 0%| | 0/60 [00:00<?, ?it/s]"
44
+ ]
45
+ },
46
+ "metadata": {},
47
+ "output_type": "display_data"
48
+ }
49
+ ],
50
+ "source": [
51
+ "from datasets import load_dataset\n",
52
+ "from tqdm import tqdm\n",
53
+ "import pandas as pd\n",
54
+ "\n",
55
+ "cache_dir = \"./../../../cache\"\n",
56
+ "\n",
57
+ "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 5,
63
+ "id": "cffd49ca-3524-4ac4-8ba5-bc4fcc9e0f53",
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "ename": "ImportError",
68
+ "evalue": "attempted relative import with no known parent package",
69
+ "output_type": "error",
70
+ "traceback": [
71
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
72
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
73
+ "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mRepCodec\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m RepCodec\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n",
74
+ "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
75
+ ]
76
+ }
77
+ ],
78
+ "source": [
79
+ "from .RepCodec import RepCodec\n",
80
+ "import torch\n",
81
+ "import yaml\n",
82
+ "\n",
83
+ "config = \"./../repcodec/configs/repcodec_dim1024.yaml\"\n",
84
+ "with open(config) as fp:\n",
85
+ " conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
86
+ "\n",
87
+ "model = RepCodec(**conf)\n",
88
+ "model.load_state_dict(torch.load(\"./../../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
89
+ "model.quantizer.initial()\n",
90
+ "model.eval()\n",
91
+ "model.to(\"cuda:0\")"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "a9a1731e-052c-4af0-a29c-b171a988b300",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "import torch.nn.functional as F\n",
102
+ "\n",
103
+ "sample = dataset[\"train.clean.100\"][1]\n",
104
+ "\n",
105
+ "x = sample[\"audio\"][\"array\"]\n",
106
+ "\n",
107
+ "with torch.no_grad():\n",
108
+ " x = torch.from_numpy(x).float().to(reader.device)\n",
109
+ " if reader.task.cfg.normalize:\n",
110
+ " x = F.layer_norm(x, x.shape)\n",
111
+ " x = x.view(1, -1)\n",
112
+ "\n",
113
+ " feat = []\n",
114
+ " for start in range(0, x.size(1), reader.max_chunk):\n",
115
+ " x_chunk = x[:, start: start + reader.max_chunk]\n",
116
+ " res = reader.model.extract_features(\n",
117
+ " source=x_chunk,\n",
118
+ " padding_mask=None,\n",
119
+ " mask=False,\n",
120
+ " layer=reader.layer,\n",
121
+ " )\n",
122
+ " feat_chunk = res[\"x\"]\n",
123
+ " feat.append(feat_chunk)\n",
124
+ " \n",
125
+ " features = torch.cat(feat, 1).permute(0, 2, 1)\n",
126
+ "\n",
127
+ " x = model.encoder(features)\n",
128
+ " z = model.projector(x)\n",
129
+ " _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
130
+ " tokens = idx.cpu().data.numpy().tolist()[0]"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "id": "1810e6dc-2ece-4aca-a29a-e1933b8ce82a",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "import logging\n",
141
+ "import os\n",
142
+ "import sys\n",
143
+ "\n",
144
+ "import tqdm\n",
145
+ "from npy_append_array import NpyAppendArray\n",
146
+ "\n",
147
+ "def get_shard_range(tot, nshard, rank):\n",
148
+ " assert rank < nshard and rank >= 0, f\"invaid rank/nshard {rank}/{nshard}\"\n",
149
+ " start = round(tot / nshard * rank)\n",
150
+ " end = round(tot / nshard * (rank + 1))\n",
151
+ " assert start < end, f\"start={start}, end={end}\"\n",
152
+ " logger.info(\n",
153
+ " f\"rank {rank} of {nshard}, process {end-start} \"\n",
154
+ " f\"({start}-{end}) out of {tot}\"\n",
155
+ " )\n",
156
+ " return start, end\n",
157
+ "\n",
158
+ "def get_path_iterator(tsv, nshard, rank):\n",
159
+ " with open(tsv, \"r\") as f:\n",
160
+ " root = f.readline().rstrip()\n",
161
+ " lines = [line.rstrip() for line in f]\n",
162
+ " start, end = get_shard_range(len(lines), nshard, rank)\n",
163
+ " lines = lines[start:end]\n",
164
+ " def iterate():\n",
165
+ " for line in lines:\n",
166
+ " subpath, nsample = line.split(\"\\t\")\n",
167
+ " yield f\"{root}/{subpath}\", int(nsample)\n",
168
+ " return iterate, len(lines)\n",
169
+ "\n",
170
+ "def dump_feature(reader, generator, num, nshard, rank, feat_dir):\n",
171
+ " iterator = generator()\n",
172
+ "\n",
173
+ " feat_path = f\"{feat_dir}/{rank}_{nshard}.npy\"\n",
174
+ " leng_path = f\"{feat_dir}/{rank}_{nshard}.len\"\n",
175
+ "\n",
176
+ " os.makedirs(feat_dir, exist_ok=True)\n",
177
+ " if os.path.exists(feat_path):\n",
178
+ " os.remove(feat_path)\n",
179
+ "\n",
180
+ " feat_f = NpyAppendArray(feat_path)\n",
181
+ " with open(leng_path, \"w\") as leng_f:\n",
182
+ " for path, nsample in tqdm.tqdm(iterator, total=num):\n",
183
+ " feat = reader.get_feats(path, nsample)\n",
184
+ " feat_f.append(feat.cpu().numpy())\n",
185
+ " leng_f.write(f\"{len(feat)}\\n\")\n",
186
+ " logger.info(\"finished successfully\")\n",
187
+ "\n",
188
+ "generator, num = get_path_iterator(tsv_path, nshard, rank)\n",
189
+ "dump_feature(reader, generator, num, nshard, rank, feat_dir)"
190
+ ]
191
+ }
192
+ ],
193
+ "metadata": {
194
+ "kernelspec": {
195
+ "display_name": "Python 3 (ipykernel)",
196
+ "language": "python",
197
+ "name": "python3"
198
+ },
199
+ "language_info": {
200
+ "codemirror_mode": {
201
+ "name": "ipython",
202
+ "version": 3
203
+ },
204
+ "file_extension": ".py",
205
+ "mimetype": "text/x-python",
206
+ "name": "python",
207
+ "nbconvert_exporter": "python",
208
+ "pygments_lexer": "ipython3",
209
+ "version": "3.8.10"
210
+ }
211
+ },
212
+ "nbformat": 4,
213
+ "nbformat_minor": 5
214
+ }
prompting/RepCodec/examples/__pycache__/data2vec_audio.cpython-38.pyc ADDED
Binary file (12.2 kB).
 
prompting/RepCodec/examples/__pycache__/data2vec_feature_reader.cpython-38.pyc ADDED
Binary file (3 kB).
 
prompting/RepCodec/examples/__pycache__/feature_utils.cpython-38.pyc ADDED
Binary file (2.23 kB).
 
prompting/RepCodec/examples/__pycache__/hubert_feature_reader.cpython-38.pyc ADDED
Binary file (2.22 kB).
 
prompting/RepCodec/examples/__pycache__/tokenize.cpython-38.pyc ADDED
Binary file (1.89 kB).
 
prompting/RepCodec/examples/data2vec_audio.py ADDED
@@ -0,0 +1,541 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
9
+
10
+ import logging
11
+ import math
12
+ from dataclasses import dataclass, field
13
+ from typing import Optional
14
+
15
+ from omegaconf import II
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.distributed as dist
21
+
22
+ from fairseq.modules import EMAModule, EMAModuleConfig
23
+ from fairseq.data.data_utils import compute_mask_indices
24
+ from fairseq.models import BaseFairseqModel, register_model
25
+ from fairseq.models.wav2vec import (
26
+ ConvFeatureExtractionModel,
27
+ Wav2Vec2Config,
28
+ TransformerEncoder,
29
+ )
30
+ from fairseq.modules import (
31
+ GradMultiply,
32
+ LayerNorm,
33
+ )
34
+ from fairseq.utils import index_put
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ @dataclass
41
+ class Data2VecAudioConfig(Wav2Vec2Config):
42
+
43
+ loss_beta: float = field(
44
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
45
+ )
46
+ loss_scale: Optional[float] = field(
47
+ default=None,
48
+ metadata={
49
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
50
+ },
51
+ )
52
+ average_top_k_layers: int = field(
53
+ default=8, metadata={"help": "how many layers to average"}
54
+ )
55
+
56
+ layer_norm_target_layer: bool = False
57
+ instance_norm_target_layer: bool = False
58
+ instance_norm_targets: bool = False
59
+ layer_norm_targets: bool = False
60
+ batch_norm_target_layer: bool = False
61
+ group_norm_target_layer: bool = False
62
+
63
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
64
+ ema_end_decay: float = field(
65
+ default=0.9999, metadata={"help": "final ema decay rate"}
66
+ )
67
+
68
+ # when to finish annealing ema decay rate
69
+ ema_anneal_end_step: int = II("optimization.max_update")
70
+
71
+ ema_transformer_only: bool = field(
72
+ default=True,
73
+ metadata={"help": "whether to momentum update only the transformer"},
74
+ )
75
+ ema_layers_only: bool = field(
76
+ default=True,
77
+ metadata={"help": "whether to momentum update only the transformer layers"},
78
+ )
79
+
80
+ max_update: int = II("optimization.max_update")
81
+
82
+ min_target_var: float = field(
83
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
84
+ )
85
+ min_pred_var: float = field(
86
+ default=0.01,
87
+ metadata={"help": "stop training if prediction var falls below this"},
88
+ )
89
+
90
+
91
+ def get_annealed_rate(start, end, curr_step, total_steps):
92
+ r = end - start
93
+ pct_remaining = 1 - curr_step / total_steps
94
+ return end - r * pct_remaining
95
+
96
+
97
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
98
+ class Data2VecAudioModel(BaseFairseqModel):
99
+ def __init__(self, cfg: Data2VecAudioConfig):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ feature_enc_layers = eval(cfg.conv_feature_layers)
104
+ self.extractor_embed = feature_enc_layers[-1][0]
105
+
106
+ self.ema = None
107
+ self.embed = cfg.encoder_embed_dim
108
+
109
+ self.average_top_k_layers = cfg.average_top_k_layers
110
+ self.loss_beta = cfg.loss_beta
111
+ self.loss_scale = cfg.loss_scale
112
+
113
+ self.feature_extractor = ConvFeatureExtractionModel(
114
+ conv_layers=feature_enc_layers,
115
+ dropout=0.0,
116
+ mode=cfg.extractor_mode,
117
+ conv_bias=cfg.conv_bias,
118
+ )
119
+
120
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
121
+
122
+ self.mask_prob = cfg.mask_prob
123
+ self.mask_selection = cfg.mask_selection
124
+ self.mask_other = cfg.mask_other
125
+ self.mask_length = cfg.mask_length
126
+ self.no_mask_overlap = cfg.no_mask_overlap
127
+ self.mask_min_space = cfg.mask_min_space
128
+
129
+ self.mask_channel_prob = cfg.mask_channel_prob
130
+ self.mask_channel_before = cfg.mask_channel_before
131
+ self.mask_channel_selection = cfg.mask_channel_selection
132
+ self.mask_channel_other = cfg.mask_channel_other
133
+ self.mask_channel_length = cfg.mask_channel_length
134
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
135
+ self.mask_channel_min_space = cfg.mask_channel_min_space
136
+
137
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
138
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
139
+
140
+ self.feature_grad_mult = cfg.feature_grad_mult
141
+
142
+ self.mask_emb = nn.Parameter(
143
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
144
+ )
145
+
146
+ self.encoder = TransformerEncoder(cfg)
147
+ self.layer_norm = LayerNorm(self.extractor_embed)
148
+
149
+ self.final_proj = nn.Linear(self.embed, self.embed)
150
+
151
+ self.num_updates = 0
152
+
153
+ def make_ema_teacher(self):
154
+ ema_config = EMAModuleConfig(
155
+ ema_decay=self.cfg.ema_decay,
156
+ ema_fp32=True,
157
+ )
158
+ skip_keys = set()
159
+ if self.cfg.ema_layers_only:
160
+ self.cfg.ema_transformer_only = True
161
+ for k, _ in self.encoder.pos_conv.named_parameters():
162
+ skip_keys.add(f"pos_conv.{k}")
163
+
164
+ self.ema = EMAModule(
165
+ self.encoder if self.cfg.ema_transformer_only else self,
166
+ ema_config,
167
+ skip_keys=skip_keys,
168
+ )
169
+
170
+ def set_num_updates(self, num_updates):
171
+ super().set_num_updates(num_updates)
172
+
173
+ if self.ema is None and self.final_proj is not None:
174
+ logger.info("making ema teacher")
175
+ self.make_ema_teacher()
176
+ elif self.training and self.ema is not None:
177
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
178
+ if num_updates >= self.cfg.ema_anneal_end_step:
179
+ decay = self.cfg.ema_end_decay
180
+ else:
181
+ decay = get_annealed_rate(
182
+ self.cfg.ema_decay,
183
+ self.cfg.ema_end_decay,
184
+ num_updates,
185
+ self.cfg.ema_anneal_end_step,
186
+ )
187
+ self.ema.set_decay(decay)
188
+ if self.ema.get_decay() < 1:
189
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
190
+
191
+ self.num_updates = num_updates
192
+
193
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
194
+ state = super().state_dict(destination, prefix, keep_vars)
195
+
196
+ if self.ema is not None:
197
+ state[prefix + "_ema"] = self.ema.fp32_params
198
+
199
+ return state
200
+
201
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
202
+ if self.ema is not None:
203
+ k = prefix + "_ema"
204
+ assert k in state_dict
205
+ self.ema.restore(state_dict[k], True)
206
+ del state_dict[k]
207
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
208
+
209
+ @classmethod
210
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
211
+ """Build a new model instance."""
212
+
213
+ return cls(cfg)
214
+
215
+ def apply_mask(
216
+ self,
217
+ x,
218
+ padding_mask,
219
+ mask_indices=None,
220
+ mask_channel_indices=None,
221
+ ):
222
+ B, T, C = x.shape
223
+
224
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
225
+ mask_channel_indices = compute_mask_indices(
226
+ (B, C),
227
+ None,
228
+ self.mask_channel_prob,
229
+ self.mask_channel_length,
230
+ self.mask_channel_selection,
231
+ self.mask_channel_other,
232
+ no_overlap=self.no_mask_channel_overlap,
233
+ min_space=self.mask_channel_min_space,
234
+ )
235
+ mask_channel_indices = (
236
+ torch.from_numpy(mask_channel_indices)
237
+ .to(x.device)
238
+ .unsqueeze(1)
239
+ .expand(-1, T, -1)
240
+ )
241
+ x[mask_channel_indices] = 0
242
+
243
+ if self.mask_prob > 0:
244
+ if mask_indices is None:
245
+ mask_indices = compute_mask_indices(
246
+ (B, T),
247
+ padding_mask,
248
+ self.mask_prob,
249
+ self.mask_length,
250
+ self.mask_selection,
251
+ self.mask_other,
252
+ min_masks=1,
253
+ no_overlap=self.no_mask_overlap,
254
+ min_space=self.mask_min_space,
255
+ require_same_masks=self.cfg.require_same_masks,
256
+ mask_dropout=self.cfg.mask_dropout,
257
+ )
258
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
259
+ x = index_put(x, mask_indices, self.mask_emb)
260
+ else:
261
+ mask_indices = None
262
+
263
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
264
+ if mask_channel_indices is None:
265
+ mask_channel_indices = compute_mask_indices(
266
+ (B, C),
267
+ None,
268
+ self.mask_channel_prob,
269
+ self.mask_channel_length,
270
+ self.mask_channel_selection,
271
+ self.mask_channel_other,
272
+ no_overlap=self.no_mask_channel_overlap,
273
+ min_space=self.mask_channel_min_space,
274
+ )
275
+ mask_channel_indices = (
276
+ torch.from_numpy(mask_channel_indices)
277
+ .to(x.device)
278
+ .unsqueeze(1)
279
+ .expand(-1, T, -1)
280
+ )
281
+ x = index_put(x, mask_channel_indices, 0)
282
+
283
+ return x, mask_indices
284
+
285
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
286
+ """
287
+ Computes the output length of the convolutional layers
288
+ """
289
+
290
+ def _conv_out_length(input_length, kernel_size, stride):
291
+ return torch.floor((input_length - kernel_size) / stride + 1)
292
+
293
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
294
+
295
+ for i in range(len(conv_cfg_list)):
296
+ input_lengths = _conv_out_length(
297
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
298
+ )
299
+
300
+ return input_lengths.to(torch.long)
301
+
302
+ def forward(
303
+ self,
304
+ source,
305
+ padding_mask=None,
306
+ mask=True,
307
+ features_only=False,
308
+ layer=None,
309
+ mask_indices=None,
310
+ mask_channel_indices=None,
311
+ padding_count=None,
312
+ ):
313
+ features = source
314
+
315
+ if self.feature_grad_mult > 0:
316
+ features = self.feature_extractor(features)
317
+ if self.feature_grad_mult != 1.0:
318
+ features = GradMultiply.apply(features, self.feature_grad_mult)
319
+ else:
320
+ with torch.no_grad():
321
+ features = self.feature_extractor(features)
322
+
323
+ features = features.transpose(1, 2)
324
+
325
+ features = self.layer_norm(features)
326
+
327
+ orig_padding_mask = padding_mask
328
+
329
+ if padding_mask is not None and padding_mask.any():
330
+ input_lengths = (1 - padding_mask.long()).sum(-1)
331
+ # apply conv formula to get real output_lengths
332
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
333
+
334
+ padding_mask = torch.zeros(
335
+ features.shape[:2], dtype=features.dtype, device=features.device
336
+ )
337
+
338
+ # these two operations make sure that all values
339
+ # before the output lengths indices are attended to
340
+ padding_mask[
341
+ (
342
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
343
+ output_lengths - 1,
344
+ )
345
+ ] = 1
346
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
347
+ else:
348
+ padding_mask = None
349
+
350
+ if self.post_extract_proj is not None:
351
+ features = self.post_extract_proj(features)
352
+
353
+ pre_encoder_features = None
354
+ if self.cfg.ema_transformer_only:
355
+ pre_encoder_features = features.clone()
356
+
357
+ features = self.dropout_input(features)
358
+
359
+ if mask:
360
+ x, mask_indices = self.apply_mask(
361
+ features,
362
+ padding_mask,
363
+ mask_indices=mask_indices,
364
+ mask_channel_indices=mask_channel_indices,
365
+ )
366
+ else:
367
+ x = features
368
+ mask_indices = None
369
+
370
+ x, layer_results = self.encoder(
371
+ x,
372
+ padding_mask=padding_mask,
373
+ layer=layer,
374
+ )
375
+
376
+ if features_only:
377
+ return {
378
+ "x": x,
379
+ "padding_mask": padding_mask,
380
+ "layer_results": layer_results,
381
+ }
382
+
383
+ result = {
384
+ "losses": {},
385
+ }
386
+
387
+ with torch.no_grad():
388
+ self.ema.model.eval()
389
+
390
+ if self.cfg.ema_transformer_only:
391
+ y, layer_results = self.ema.model.extract_features(
392
+ pre_encoder_features,
393
+ padding_mask=padding_mask,
394
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
395
+ )
396
+ y = {
397
+ "x": y,
398
+ "padding_mask": padding_mask,
399
+ "layer_results": layer_results,
400
+ }
401
+ else:
402
+ y = self.ema.model.extract_features(
403
+ source=source,
404
+ padding_mask=orig_padding_mask,
405
+ mask=False,
406
+ )
407
+
408
+ target_layer_results = [l[2] for l in y["layer_results"]]
409
+
410
+ permuted = False
411
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
412
+ target_layer_results = [
413
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
414
+ ]
415
+ permuted = True
416
+
417
+ if self.cfg.batch_norm_target_layer:
418
+ target_layer_results = [
419
+ F.batch_norm(
420
+ tl.float(), running_mean=None, running_var=None, training=True
421
+ )
422
+ for tl in target_layer_results
423
+ ]
424
+
425
+ if self.cfg.instance_norm_target_layer:
426
+ target_layer_results = [
427
+ F.instance_norm(tl.float()) for tl in target_layer_results
428
+ ]
429
+
430
+ if permuted:
431
+ target_layer_results = [
432
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
433
+ ]
434
+
435
+ if self.cfg.group_norm_target_layer:
436
+ target_layer_results = [
437
+ F.layer_norm(tl.float(), tl.shape[-2:])
438
+ for tl in target_layer_results
439
+ ]
440
+
441
+ if self.cfg.layer_norm_target_layer:
442
+ target_layer_results = [
443
+ F.layer_norm(tl.float(), tl.shape[-1:])
444
+ for tl in target_layer_results
445
+ ]
446
+
447
+ y = sum(target_layer_results) / len(target_layer_results)
448
+
449
+ if self.cfg.layer_norm_targets:
450
+ y = F.layer_norm(y.float(), y.shape[-1:])
451
+
452
+ if self.cfg.instance_norm_targets:
453
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
454
+
455
+ if not permuted:
456
+ y = y.transpose(0, 1)
457
+
458
+ y = y[mask_indices]
459
+
460
+ x = x[mask_indices]
461
+ x = self.final_proj(x)
462
+
463
+ sz = x.size(-1)
464
+
465
+ if self.loss_beta == 0:
466
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
467
+ else:
468
+ loss = F.smooth_l1_loss(
469
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
470
+ ).sum(dim=-1)
471
+
472
+ if self.loss_scale is not None:
473
+ scale = self.loss_scale
474
+ else:
475
+ scale = 1 / math.sqrt(sz)
476
+
477
+ result["losses"]["regression"] = loss.sum() * scale
478
+
479
+ if "sample_size" not in result:
480
+ result["sample_size"] = loss.numel()
481
+
482
+ with torch.no_grad():
483
+ result["target_var"] = self.compute_var(y)
484
+ result["pred_var"] = self.compute_var(x.float())
485
+
486
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
487
+ logger.error(
488
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
489
+ )
490
+ raise Exception(
491
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
492
+ )
493
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
494
+ logger.error(
495
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
496
+ )
497
+ raise Exception(
498
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
499
+ )
500
+
501
+ if self.ema is not None:
502
+ result["ema_decay"] = self.ema.get_decay() * 1000
503
+
504
+ return result
505
+
506
+ @staticmethod
507
+ def compute_var(y):
508
+ y = y.view(-1, y.size(-1))
509
+ if dist.is_initialized():
510
+ zc = torch.tensor(y.size(0)).cuda()
511
+ zs = y.sum(dim=0)
512
+ zss = (y ** 2).sum(dim=0)
513
+
514
+ dist.all_reduce(zc)
515
+ dist.all_reduce(zs)
516
+ dist.all_reduce(zss)
517
+
518
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
519
+ return torch.sqrt(var + 1e-6).mean()
520
+ else:
521
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
522
+
523
+ def extract_features(
524
+ self, source, padding_mask, mask=False, layer=None
525
+ ):
526
+ res = self.forward(
527
+ source,
528
+ padding_mask,
529
+ mask=mask,
530
+ features_only=True,
531
+ layer=layer,
532
+ )
533
+ return res
534
+
535
+ def remove_pretraining_modules(self, last_layer=None):
536
+ self.final_proj = None
537
+ self.ema = None
538
+ if last_layer is not None:
539
+ self.encoder.layers = nn.ModuleList(
540
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
541
+ )
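
The teacher that produces the regression targets is an EMA copy of the encoder whose decay is annealed linearly from ema_decay to ema_end_decay over ema_anneal_end_step updates (see set_num_updates above). An illustrative check of get_annealed_rate with the default decay values; the 100k-step horizon is made up:

# Copy of the helper defined above, exercised with the default decay values.
def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining

for step in (0, 50_000, 100_000):
    print(step, get_annealed_rate(0.999, 0.9999, step, 100_000))
# 0 -> 0.999, 50000 -> 0.99945, 100000 -> 0.9999
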
prompting/RepCodec/examples/data2vec_feature_reader.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from fairseq import tasks
13
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+ from omegaconf import OmegaConf
16
+
17
+ from data2vec_audio import Data2VecAudioModel
18
+
19
+ logger = logging.getLogger("dump_feature")
20
+
21
+
22
+ class Data2vecFeatureReader(object):
23
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
24
+ state = load_checkpoint_to_cpu(ckpt_path)
25
+ cfg = state["cfg"]
26
+ # load task
27
+ task = tasks.setup_task(cfg.task, from_checkpoint=True)
28
+ task.load_state_dict(state["task_state"])
29
+ # load model config
30
+ if "layer_type" not in cfg.model:
31
+ # fix a missing key
32
+ model_config = {k: v for k, v in cfg.model.items()}
33
+ model_config["layer_type"] = "transformer"
34
+ model_config = OmegaConf.create(model_config)
35
+ else:
36
+ model_config = cfg.model
37
+
38
+ # fix param name in the state
39
+ state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
40
+ state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
41
+ del state["model"]["_ema"]
42
+
43
+ # load model
44
+ model = Data2VecAudioModel.build_model(model_config)
45
+ model.load_state_dict(
46
+ state["model"], strict=True, model_cfg=model_config
47
+ )
48
+
49
+ self.device = device
50
+ logger.info(f"device = {self.device}")
51
+
52
+ self.model = model.eval().to(self.device)
53
+ self.task = task
54
+ self.layer = layer - 1  # convert the 1-based layer argument to a 0-based index
55
+ self.max_chunk = max_chunk
56
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
57
+ logger.info(f" max_chunk = {self.max_chunk}")
58
+
59
+ def read_audio(self, path, ref_len=None):
60
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
61
+ if wav.ndim == 2:
62
+ wav = wav.mean(-1)
63
+ assert wav.ndim == 1, wav.ndim
64
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
65
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
66
+ return wav
67
+
68
+ def get_feats(self, path, ref_len=None):
69
+ x = self.read_audio(path, ref_len=ref_len)
70
+ with torch.no_grad():
71
+ x = torch.from_numpy(x).float().to(self.device)
72
+ if self.task.cfg.normalize:
73
+ x = F.layer_norm(x, x.shape)
74
+ x = x.view(1, -1)
75
+
76
+ feat = []
77
+ for start in range(0, x.size(1), self.max_chunk):
78
+ x_chunk = x[:, start: start + self.max_chunk]
79
+ res = self.model.extract_features(
80
+ source=x_chunk,
81
+ padding_mask=None,
82
+ mask=False,
83
+ layer=self.layer,
84
+ )
85
+ feat_chunk = res["x"]
86
+ feat.append(feat_chunk)
87
+ return torch.cat(feat, 1).squeeze(0)
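
A minimal usage sketch, assuming a local copy of the data2vec vox_pretrained.pt checkpoint; the paths are placeholders, and some_run.py below uses the same reader with layer 18, which is what the data2vec_large_l18 RepCodec tokenizer expects:

from data2vec_feature_reader import Data2vecFeatureReader

# Placeholder paths; layer is given 1-based and shifted to 0-based internally.
reader = Data2vecFeatureReader("/path/to/vox_pretrained.pt", layer=18,
                               device="cuda:0", max_chunk=1600000)
feat = reader.get_feats("/path/to/utterance.flac")  # torch.Tensor of shape (T, D)
print(feat.shape)
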
prompting/RepCodec/examples/dump_feature.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+
12
+ from feature_utils import get_path_iterator, dump_feature
13
+
14
+ logging.basicConfig(
15
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
16
+ datefmt="%Y-%m-%d %H:%M:%S",
17
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
18
+ stream=sys.stdout,
19
+ )
20
+ logger = logging.getLogger("dump_feature")
21
+
22
+
23
+ def main(
24
+ model_type: str,
25
+ tsv_path: str,
26
+ ckpt_path: str,
27
+ whisper_root: str,
28
+ whisper_name: str,
29
+ layer: int,
30
+ nshard: int,
31
+ rank: int,
32
+ feat_dir: str,
33
+ max_chunk: int,
34
+ use_cpu: bool = False
35
+ ):
36
+ device = "cpu" if use_cpu else "cuda"
37
+
38
+ # some checks
39
+ if model_type in ["hubert", "data2vec"]:
40
+ assert ckpt_path and os.path.exists(ckpt_path)
41
+ elif model_type in ["whisper"]:
42
+ assert whisper_name and whisper_root
43
+ else:
44
+ raise ValueError(f"Unsupported model type {model_type}")
45
+
46
+ reader = None
47
+ if model_type == "hubert":
48
+ from hubert_feature_reader import HubertFeatureReader
49
+ reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
50
+ elif model_type == "data2vec":
51
+ from data2vec_feature_reader import Data2vecFeatureReader
52
+ reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
53
+ elif model_type == "whisper":
54
+ from whisper_feature_reader import WhisperFeatureReader
55
+ reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)
56
+
57
+ assert reader is not None
58
+
59
+ generator, num = get_path_iterator(tsv_path, nshard, rank)
60
+ dump_feature(reader, generator, num, nshard, rank, feat_dir)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ import argparse
65
+
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument(
68
+ "--model_type",
69
+ required=True,
70
+ type=str,
71
+ choices=["data2vec", "hubert", "whisper"],
72
+ help="the type of the speech encoder."
73
+ )
74
+ parser.add_argument(
75
+ "--tsv_path",
76
+ required=True,
77
+ type=str,
78
+ help="the path to the tsv file."
79
+ )
80
+ parser.add_argument(
81
+ "--ckpt_path",
82
+ required=False,
83
+ type=str,
84
+ default=None,
85
+ help="path to the speech model. must provide for HuBERT and data2vec"
86
+ )
87
+ parser.add_argument(
88
+ "--whisper_root",
89
+ required=False,
90
+ type=str,
91
+ default=None,
92
+ help="root dir to download/store whisper model. must provide for whisper model."
93
+ )
94
+ parser.add_argument(
95
+ "--whisper_name",
96
+ required=False,
97
+ type=str,
98
+ default=None,
99
+ help="name of whisper model. e.g., large-v2. must provide for whisper model."
100
+ )
101
+ parser.add_argument(
102
+ "--layer",
103
+ required=True,
104
+ type=int,
105
+ help="which layer of the model. this is 1-based."
106
+ )
107
+ parser.add_argument(
108
+ "--feat_dir",
109
+ required=True,
110
+ type=str,
111
+ help="the output dir to save the representations."
112
+ )
113
+ parser.add_argument(
114
+ "--nshard",
115
+ required=False,
116
+ type=int,
117
+ default=1,
118
+ help="total number of shards."
119
+ )
120
+ parser.add_argument(
121
+ "--rank",
122
+ required=False,
123
+ type=int,
124
+ default=0,
125
+ help="shard id of this process."
126
+ )
127
+ parser.add_argument(
128
+ "--max_chunk",
129
+ type=int,
130
+ default=1600000,
131
+ help="max number of audio samples fed to the encoder per chunk."
132
+ )
133
+ parser.add_argument(
134
+ "--use_cpu",
135
+ default=False,
136
+ action="store_true",
137
+ help="whether to use CPU instead of GPU."
138
+ )
139
+ args = parser.parse_args()
140
+ logger.info(args)
141
+
142
+ main(**vars(args))
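
The script is normally driven through the CLI flags defined above (--model_type, --tsv_path, --ckpt_path, --layer, --feat_dir, --nshard, --rank, --max_chunk), but main can also be called directly. A sketch for the data2vec case with placeholder paths:

from dump_feature import main

main(
    model_type="data2vec",
    tsv_path="/path/to/dev-clean.tsv",       # manifest; see feature_utils.py below
    ckpt_path="/path/to/vox_pretrained.pt",  # required for data2vec and hubert
    whisper_root=None,
    whisper_name=None,
    layer=18,                                # 1-based
    nshard=1,
    rank=0,
    feat_dir="/path/to/features",
    max_chunk=1600000,
    use_cpu=False,
)
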
prompting/RepCodec/examples/feature_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py
9
+
10
+ import logging
11
+ import os
12
+ import sys
13
+
14
+ import tqdm
15
+ from npy_append_array import NpyAppendArray
16
+
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("feature_utils")
25
+
26
+
27
+ def get_shard_range(tot, nshard, rank):
28
+ assert rank < nshard and rank >= 0, f"invalid rank/nshard {rank}/{nshard}"
29
+ start = round(tot / nshard * rank)
30
+ end = round(tot / nshard * (rank + 1))
31
+ assert start < end, f"start={start}, end={end}"
32
+ logger.info(
33
+ f"rank {rank} of {nshard}, process {end-start} "
34
+ f"({start}-{end}) out of {tot}"
35
+ )
36
+ return start, end
37
+
38
+
39
+ def get_path_iterator(tsv, nshard, rank):
40
+ with open(tsv, "r") as f:
41
+ root = f.readline().rstrip()
42
+ lines = [line.rstrip() for line in f]
43
+ start, end = get_shard_range(len(lines), nshard, rank)
44
+ lines = lines[start:end]
45
+ def iterate():
46
+ for line in lines:
47
+ subpath, nsample = line.split("\t")
48
+ yield f"{subpath}", int(nsample)
49
+ return iterate, len(lines)
50
+
51
+
52
+ def dump_feature(reader, generator, num, nshard, rank, feat_dir):
53
+ iterator = generator()
54
+
55
+ feat_path = f"{feat_dir}/{rank}_{nshard}.npy"
56
+ leng_path = f"{feat_dir}/{rank}_{nshard}.len"
57
+
58
+ os.makedirs(feat_dir, exist_ok=True)
59
+ if os.path.exists(feat_path):
60
+ os.remove(feat_path)
61
+
62
+ feat_f = NpyAppendArray(feat_path)
63
+ with open(leng_path, "w") as leng_f:
64
+ for path, nsample in tqdm.tqdm(iterator, total=num):
65
+ feat = reader.get_feats(path, nsample)
66
+ feat_f.append(feat.cpu().numpy())
67
+ leng_f.write(f"{len(feat)}\n")
68
+ logger.info("finished successfully")
69
+
70
+
prompting/RepCodec/examples/hubert_feature_reader.py ADDED
@@ -0,0 +1,64 @@
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
7
+
8
+ import logging
9
+
10
+ import fairseq
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
15
+
16
+ logger = logging.getLogger("dump_feature")
17
+
18
+
19
+ class HubertFeatureReader(object):
20
+ def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
21
+ (
22
+ model,
23
+ cfg,
24
+ task,
25
+ ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
26
+
27
+ self.device = device
28
+ logger.info(f"device = {self.device}")
29
+
30
+ self.model = model[0].eval().to(self.device)
31
+ self.task = task
32
+ self.layer = layer
33
+ self.max_chunk = max_chunk
34
+ logger.info(f"TASK CONFIG:\n{self.task.cfg}")
35
+ logger.info(f" max_chunk = {self.max_chunk}")
36
+
37
+ def read_audio(self, path, ref_len=None):
38
+ wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
39
+ if wav.ndim == 2:
40
+ wav = wav.mean(-1)
41
+ assert wav.ndim == 1, wav.ndim
42
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
43
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
44
+ return wav
45
+
46
+ def get_feats(self, path, ref_len=None):
47
+ x = self.read_audio(path, ref_len=ref_len)
48
+ with torch.no_grad():
49
+ x = torch.from_numpy(x).float().to(self.device)
50
+ if self.task.cfg.normalize:
51
+ x = F.layer_norm(x, x.shape)
52
+ x = x.view(1, -1)
53
+
54
+ feat = []
55
+ for start in range(0, x.size(1), self.max_chunk):
56
+ x_chunk = x[:, start: start + self.max_chunk]
57
+ feat_chunk, _ = self.model.extract_features(
58
+ source=x_chunk,
59
+ padding_mask=None,
60
+ mask=False,
61
+ output_layer=self.layer,
62
+ )
63
+ feat.append(feat_chunk)
64
+ return torch.cat(feat, 1).squeeze(0)
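
Usage mirrors the data2vec reader, except that layer is handed straight to HuBERT's output_layer (1-based in fairseq), so no index shift is applied here. A sketch with placeholder paths:

from hubert_feature_reader import HubertFeatureReader

reader = HubertFeatureReader("/path/to/hubert_large_ll60k.pt", layer=18, device="cuda:0")
feat = reader.get_feats("/path/to/utterance.flac")  # torch.Tensor of shape (T, D)
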
prompting/RepCodec/examples/some_run.py ADDED
@@ -0,0 +1,66 @@
1
+ from datasets import load_dataset
2
+ from tqdm import tqdm
3
+ import pandas as pd
4
+
5
+ cache_dir = "./../../../cache"
6
+
7
+ dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
8
+
9
+ from repcodec.RepCodec import RepCodec
10
+ import torch
11
+ import yaml
12
+
13
+ config = "./../repcodec/configs/repcodec_dim1024.yaml"
14
+ with open(config) as fp:
15
+ conf = yaml.load(fp, Loader=yaml.FullLoader)
16
+
17
+ model = RepCodec(**conf)
18
+ model.load_state_dict(torch.load("./../../models/data2vec_large_l18.pkl", map_location="cuda:0")["model"]["repcodec"])
19
+ model.quantizer.initial()
20
+ model.eval()
21
+ model.to("cuda:0")
22
+
23
+ from data2vec_feature_reader import Data2vecFeatureReader
24
+
25
+ reader = Data2vecFeatureReader("./../../models/vox_pretrained.pt", 18, device="cuda:0", max_chunk=1600000)
26
+
27
+ import torch.nn.functional as F
28
+ import numpy as np
29
+
30
+ for split in dataset.keys():
31
+
32
+ tokens = []
33
+
34
+ for idx in tqdm(range(len(dataset[split]))):
35
+
36
+ sample = dataset[split][idx]
37
+
38
+ x = sample["audio"]["array"]
39
+
40
+ with torch.no_grad():
41
+ x = torch.from_numpy(x).float().to(reader.device)
42
+ if reader.task.cfg.normalize:
43
+ x = F.layer_norm(x, x.shape)
44
+ x = x.view(1, -1)
45
+
46
+ feat = []
47
+ for start in range(0, x.size(1), reader.max_chunk):
48
+ x_chunk = x[:, start: start + reader.max_chunk]
49
+ res = reader.model.extract_features(
50
+ source=x_chunk,
51
+ padding_mask=None,
52
+ mask=False,
53
+ layer=reader.layer,
54
+ )
55
+ feat_chunk = res["x"]
56
+ feat.append(feat_chunk)
57
+
58
+ features = torch.cat(feat, 1).permute(0, 2, 1)
59
+
60
+ x = model.encoder(features)
61
+ z = model.projector(x)
62
+ _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
63
+ tkn = idx.detach().cpu().data.numpy()[0]
64
+
65
+ tokens.append(tkn)
66
+ np.savez(f"./tkns/{split}.npz", *tokens)
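
Because np.savez is called with positional arguments, the per-utterance token arrays are stored under the keys arr_0, arr_1, ... in dataset order. A sketch for loading them back:

import numpy as np

data = np.load("./tkns/test.clean.npz")
tokens = [data[k] for k in sorted(data.files, key=lambda k: int(k.split("_")[1]))]
print(len(tokens), tokens[0][:10])  # one 1-D array of RepCodec indices per utterance
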
prompting/RepCodec/examples/tkns/test.clean.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b2ee3dbd59f84e1903c251bdce1e05a8ed6d1f30bf492ab368d4873dc0e713b
3
+ size 8415186
prompting/RepCodec/examples/tkns/test.other.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:188625227fc71069d6607920a733b76b0c025f6d75bd698abf87bbcfd6089d2c
3
+ size 8403370
prompting/RepCodec/examples/tkns/train.clean.100.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5966bdb4154a433ffe605d4cb152fdda7349ab6456747f0e2275f154739c004
3
+ size 151819656
prompting/RepCodec/examples/tkns/train.clean.360.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ed187075b7d515f3a1162df967548bf0d0c384aa31419ac12e7e5b95616ab2c
3
+ size 549056222
prompting/RepCodec/examples/tkns/train.other.500.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14441ffb0f6e0ca5edd1772ddf8fbc6065f8b3e70796b4244dd62a5c286467a7
3
+ size 751972686
prompting/RepCodec/examples/tkns/validation.clean.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5bb0560fa0169e387353d30f13a72eb40704a6518c0b8f72bdfbccada9a8660
3
+ size 8412602
prompting/RepCodec/examples/tkns/validation.other.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:314083f9013d459db663631f237f4ae68860cf25a3137dd7c9d42377e7efb8c1
3
+ size 8067914
prompting/RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens ADDED
The diff for this file is too large to render. See raw diff
 
prompting/RepCodec/examples/whisper_feature_reader.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq) and
7
+ # Whisper (https://github.com/openai/whisper/)
8
+
9
+ import io
10
+ import logging
11
+ import os
12
+ from typing import Optional, Union
13
+
14
+ import soundfile as sf
15
+ import torch
16
+ from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
17
+ from whisper.audio import log_mel_spectrogram
18
+ from whisper.model import ModelDimensions
19
+
20
+ from whisper_model import Whisper_
21
+
22
+ logger = logging.getLogger("dump_feature")
23
+
24
+
25
+ def load_model(
26
+ name: str,
27
+ device: Optional[Union[str, torch.device]] = None,
28
+ download_root: str = None,
29
+ in_memory: bool = False,
30
+ ) -> Whisper_:
31
+ """
32
+ Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
33
+ But we will load a `Whisper_` model for feature extraction.
34
+
35
+ Parameters
36
+ ----------
37
+ name : str
38
+ one of the official model names listed by `whisper.available_models()`, or
39
+ path to a model checkpoint containing the model dimensions and the model state_dict.
40
+ device : Union[str, torch.device]
41
+ the PyTorch device to put the model into
42
+ download_root: str
43
+ path to download the model files; by default, it uses "~/.cache/whisper"
44
+ in_memory: bool
45
+ whether to preload the model weights into host memory
46
+
47
+ Returns
48
+ -------
49
+ model : Whisper
50
+ The Whisper ASR model instance
51
+ """
52
+
53
+ if device is None:
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ if download_root is None:
56
+ default = os.path.join(os.path.expanduser("~"), ".cache")
57
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
58
+
59
+ if name in _MODELS:
60
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
61
+ alignment_heads = _ALIGNMENT_HEADS[name]
62
+ elif os.path.isfile(name):
63
+ checkpoint_file = open(name, "rb").read() if in_memory else name
64
+ alignment_heads = None
65
+ else:
66
+ raise RuntimeError(
67
+ f"Model {name} not found; available models = {available_models()}"
68
+ )
69
+
70
+ with (
71
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
72
+ ) as fp:
73
+ checkpoint = torch.load(fp, map_location=device)
74
+ del checkpoint_file
75
+
76
+ dims = ModelDimensions(**checkpoint["dims"])
77
+ model = Whisper_(dims)
78
+ model.load_state_dict(checkpoint["model_state_dict"])
79
+
80
+ if alignment_heads is not None:
81
+ model.set_alignment_heads(alignment_heads)
82
+
83
+ return model.to(device)
84
+
85
+
86
+ class WhisperFeatureReader(object):
87
+ def __init__(self, root, ckpt, layer, device):
88
+ self.device = device
89
+ logger.info(f"device = {self.device}")
90
+
91
+ self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
92
+ self.model.decoder = None # to save some memory by deleting the decoder
93
+ self.layer = layer # one-based
94
+
95
+ def read_audio(self, path, ref_len=None):
96
+ wav, sample_rate = sf.read(path)
97
+ assert sample_rate == 16000, sample_rate
98
+ if ref_len is not None and abs(ref_len - len(wav)) > 160:
99
+ logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
100
+ return wav
101
+
102
+ def get_feats(self, path, ref_len=None):
103
+ wav = self.read_audio(path, ref_len)
104
+ audio_length = len(wav)
105
+ with torch.no_grad():
106
+ mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
107
+ hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
108
+ feature_length = audio_length // 320
109
+ hidden = hidden[0, :feature_length]
110
+ return hidden.contiguous()
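
A minimal usage sketch; the named Whisper checkpoint is downloaded into whisper_root if it is not already cached, layer is 1-based (the token files above use layer 32 of "large" and layer 24 of "medium"), and read_audio expects 16 kHz audio:

from whisper_feature_reader import WhisperFeatureReader

# Placeholder paths; "large" is one of whisper.available_models().
reader = WhisperFeatureReader(root="/path/to/whisper_cache", ckpt="large",
                              layer=32, device="cuda:0")
feat = reader.get_feats("/path/to/utterance_16k.wav")  # (T, D), one frame per 20 ms
print(feat.shape)
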
prompting/RepCodec/examples/whisper_model.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on fairseq (https://github.com/facebookresearch/fairseq) and
7
+ # Whisper (https://github.com/openai/whisper/)
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions
15
+
16
+
17
+ class AudioEncoder_(AudioEncoder):
18
+ def __init__(self, *args, **kwargs):
19
+ super(AudioEncoder_, self).__init__(*args, **kwargs)
20
+
21
+ def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
22
+ """
23
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
24
+ the mel spectrogram of the audio
25
+ """
26
+ x = F.gelu(self.conv1(x))
27
+ x = F.gelu(self.conv2(x))
28
+ x = x.permute(0, 2, 1)
29
+
30
+ length_x = x.shape[1]
31
+ if length_x > self.positional_embedding.shape[0]:
32
+ self.register_buffer("positional_embedding", sinusoids(length_x, self.positional_embedding.shape[1]))
33
+ self.positional_embedding = self.positional_embedding.to(x.device)
34
+ x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)
35
+
36
+ if target_layer is None:
37
+ target_layer = len(self.blocks)
38
+
39
+ for block in self.blocks[:target_layer]:
40
+ x = block(x)
41
+
42
+ return x
43
+
44
+
45
+ class Whisper_(Whisper):
46
+ def __init__(self, dims: ModelDimensions):
47
+ super(Whisper_, self).__init__(dims)
48
+ # replace audio encoder with our audio encoder
49
+ self.encoder = AudioEncoder_(
50
+ self.dims.n_mels,
51
+ self.dims.n_audio_ctx,
52
+ self.dims.n_audio_state,
53
+ self.dims.n_audio_head,
54
+ self.dims.n_audio_layer,
55
+ )
56
+
57
+ def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
58
+ return self.encoder.extract_feature(mel, target_layer)