Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +7 -0
- prompting/.ipynb_checkpoints/dataset_generation-checkpoint.ipynb +332 -0
- prompting/.ipynb_checkpoints/generate_rare_words-checkpoint.py +62 -0
- prompting/.ipynb_checkpoints/generate_transcripts-checkpoint.py +60 -0
- prompting/.ipynb_checkpoints/get_error_word_count-checkpoint.py +106 -0
- prompting/.ipynb_checkpoints/model-checkpoint.py +53 -0
- prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json +3 -0
- prompting/.ipynb_checkpoints/train_lora-checkpoint.py +137 -0
- prompting/.ipynb_checkpoints/train_phi-checkpoint.py +86 -0
- prompting/.ipynb_checkpoints/training-checkpoint.ipynb +278 -0
- prompting/RepCodec/.gitignore +160 -0
- prompting/RepCodec/.ipynb_checkpoints/tinker-checkpoint.ipynb +267 -0
- prompting/RepCodec/LICENSE +428 -0
- prompting/RepCodec/README.md +273 -0
- prompting/RepCodec/dataloader/__init__.py +2 -0
- prompting/RepCodec/dataloader/collater.py +22 -0
- prompting/RepCodec/dataloader/dataset.py +90 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/Untitled-checkpoint.ipynb +334 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_audio-checkpoint.py +541 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_feature_reader-checkpoint.py +87 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/dump_feature-checkpoint.py +142 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/feature_utils-checkpoint.py +70 -0
- prompting/RepCodec/examples/.ipynb_checkpoints/some_run-checkpoint.py +66 -0
- prompting/RepCodec/examples/Untitled.ipynb +214 -0
- prompting/RepCodec/examples/__pycache__/data2vec_audio.cpython-38.pyc +0 -0
- prompting/RepCodec/examples/__pycache__/data2vec_feature_reader.cpython-38.pyc +0 -0
- prompting/RepCodec/examples/__pycache__/feature_utils.cpython-38.pyc +0 -0
- prompting/RepCodec/examples/__pycache__/hubert_feature_reader.cpython-38.pyc +0 -0
- prompting/RepCodec/examples/__pycache__/tokenize.cpython-38.pyc +0 -0
- prompting/RepCodec/examples/data2vec_audio.py +541 -0
- prompting/RepCodec/examples/data2vec_feature_reader.py +87 -0
- prompting/RepCodec/examples/dump_feature.py +142 -0
- prompting/RepCodec/examples/feature_utils.py +70 -0
- prompting/RepCodec/examples/hubert_feature_reader.py +64 -0
- prompting/RepCodec/examples/some_run.py +66 -0
- prompting/RepCodec/examples/tkns/test.clean.npz +3 -0
- prompting/RepCodec/examples/tkns/test.other.npz +3 -0
- prompting/RepCodec/examples/tkns/train.clean.100.npz +3 -0
- prompting/RepCodec/examples/tkns/train.clean.360.npz +3 -0
- prompting/RepCodec/examples/tkns/train.other.500.npz +3 -0
- prompting/RepCodec/examples/tkns/validation.clean.npz +3 -0
- prompting/RepCodec/examples/tkns/validation.other.npz +3 -0
- prompting/RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens +0 -0
- prompting/RepCodec/examples/whisper_feature_reader.py +110 -0
- prompting/RepCodec/examples/whisper_model.py +58 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json filter=lfs diff=lfs merge=lfs -text
+prompting/train_clean_100_error.json filter=lfs diff=lfs merge=lfs -text
+prompting/train_data/train.clean.360.json filter=lfs diff=lfs merge=lfs -text
+prompting/train_data/train.other.500.json filter=lfs diff=lfs merge=lfs -text
+prompting/transcripts/train.clean.360.txt filter=lfs diff=lfs merge=lfs -text
+prompting/transcripts/train.other.500.txt filter=lfs diff=lfs merge=lfs -text
+prompting/wandb/run-20240615_114519-wfpe2teb/run-wfpe2teb.wandb filter=lfs diff=lfs merge=lfs -text
prompting/.ipynb_checkpoints/dataset_generation-checkpoint.ipynb
ADDED
@@ -0,0 +1,332 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "59efc3d7-a57f-43cc-8aa3-34bb57de0251",
   "metadata": {},
   "source": [
    "## Librispeech"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327243de-fd0f-449d-998a-63282a1c67a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "cache_dir = \"./../cache\"\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "456889e1-f8cc-440b-bf6b-f6fbfafc367d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics import WordErrorRate, CharErrorRate\n",
    "from edit_distance import SequenceMatcher\n",
    "from tqdm import tqdm\n",
    "import jiwer\n",
    "\n",
    "def correct_text(text):\n",
    "    transforms = jiwer.Compose(\n",
    "        [\n",
    "            jiwer.ExpandCommonEnglishContractions(),\n",
    "            jiwer.ToLowerCase(),\n",
    "            jiwer.RemoveMultipleSpaces(),\n",
    "            jiwer.Strip(),\n",
    "            jiwer.RemovePunctuation(),\n",
    "            jiwer.ReduceToListOfListOfWords(),\n",
    "        ]\n",
    "    )\n",
    "    return transforms(text)\n",
    "\n",
    "def align_gt_asr(gt, asr):\n",
    "\n",
    "    sm = SequenceMatcher(a=gt, b=asr)\n",
    "    best_path = []\n",
    "    opcodes = sm.get_opcodes()\n",
    "\n",
    "    for tag, i1, i2, j1, j2 in opcodes:\n",
    "\n",
    "        if tag == \"delete\":\n",
    "            for i in range(i1, i2):\n",
    "                best_path.append([gt[i], \"\"])\n",
    "\n",
    "        if tag == \"replace\" or tag == \"equal\":\n",
    "            for i, j in zip(range(i1, i2), range(j1, j2)):\n",
    "                best_path.append([gt[i], asr[j]])\n",
    "\n",
    "        if tag == \"insert\":\n",
    "            for j in range(j1, j2):\n",
    "                best_path.append([\"\", asr[j]])\n",
    "\n",
    "    return best_path\n",
    "\n",
    "import string\n",
    "def process(text):\n",
    "\n",
    "    # Lower case every letter\n",
    "    text = text.lower()\n",
    "\n",
    "    # Remove punctuation\n",
    "    punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
    "    translation_table = str.maketrans('', '', punctuation_to_remove)\n",
    "    text = text.translate(translation_table)\n",
    "\n",
    "    # Remove whitespaces from front and behind\n",
    "    while text[0] == ' ' or text[-1] == ' ':\n",
    "        if text[0] == ' ':\n",
    "            text = text[1:]\n",
    "        if text[-1] == ' ':\n",
    "            text = text[:-1]\n",
    "    \n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bc907b0-2ebe-46ac-b6a1-02919e69af88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "gens = []\n",
    "texts = []\n",
    "\n",
    "unmatches = []\n",
    "\n",
    "for split in [\"validation.clean\"]:\n",
    "    data = dataset[split]\n",
    "    with open(f\"./transcripts/{split}.txt\", \"r\") as f:\n",
    "        for idx, line in enumerate(tqdm(f)):\n",
    "            preds = process(line.rstrip())\n",
    "            text = data[idx][\"text\"]\n",
    "\n",
    "            path = align_gt_asr(correct_text(text)[0], correct_text(preds)[0])\n",
    "            un = 0\n",
    "            for a, b in path:\n",
    "                if a!=b:\n",
    "                    un+=1\n",
    "            \n",
    "            unmatches.append(un)\n",
    "\n",
    "            # texts.append(process(text))\n",
    "            # gens.append(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac10009-1b47-4e2f-a232-f71b23ee983e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.count_nonzero(unmatches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afdc9f74-c2cf-4d52-8563-1bd827f6d900",
   "metadata": {},
   "outputs": [],
   "source": [
    "def align_gt_asr(gt, asr):\n",
    "\n",
    "    sm = SequenceMatcher(a=gt, b=asr)\n",
    "    best_path = []\n",
    "    opcodes = sm.get_opcodes()\n",
    "    \n",
    "    for tag, i1, i2, j1, j2 in opcodes:\n",
    "        \n",
    "        if tag == \"delete\":\n",
    "            for i in range(i1, i2):\n",
    "                best_path.append([gt[i], \"\"])\n",
    "        \n",
    "        if tag == \"replace\" or tag == \"equal\":\n",
    "            for i, j in zip(range(i1, i2), range(j1, j2)):\n",
    "                best_path.append([gt[i], asr[j]])\n",
    "        \n",
    "        if tag == \"insert\":\n",
    "            for j in range(j1, j2):\n",
    "                best_path.append([\"\", asr[j]])\n",
    "        \n",
    "    return best_path\n",
    "\n",
    "# align_gt_asr(correct_text(text), correct_text(preds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cdfd3d9-6c22-4ccd-a22b-df8e79fc20b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_text(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c33f46a-f3dd-435f-81e3-e7b10ae03470",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_text([\"hello\", \"hey\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfab12a-2b2c-4c00-bd80-ab571c012f29",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Transcript of whisper small WER\n",
    "## validation.clean 4.62\n",
    "## validation.other 8.11\n",
    "## test.clean 4.22\n",
    "## test.other 8.56\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cb2d8f-9ce2-42f2-bbf0-522106078aac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
    "from datasets import load_dataset\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "device = \"cuda:0\"\n",
    "dtype = torch.float16\n",
    "cache_dir = \"./../cache\"\n",
    "model_id = \"openai/whisper-small\"\n",
    "\n",
    "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", cache_dir=cache_dir)\n",
    "model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=cache_dir, attn_implementation=\"sdpa\").to(device).to(dtype).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5fa6f8e-43f2-44ce-b719-2d8fde4067ce",
   "metadata": {},
   "source": [
    "## Biasing List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc0f934-d208-445e-aecd-31df73be6986",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import json\n",
    "import string\n",
    "from tqdm import tqdm\n",
    "def process(text):\n",
    "\n",
    "    # Lower case every letter\n",
    "    text = text.lower()\n",
    "\n",
    "    # Remove punctuation\n",
    "    punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
    "    translation_table = str.maketrans('', '', punctuation_to_remove)\n",
    "    text = text.translate(translation_table)\n",
    "\n",
    "    # Remove whitespaces from front and behind\n",
    "    while text[0] == ' ' or text[-1] == ' ':\n",
    "        if text[0] == ' ':\n",
    "            text = text[1:]\n",
    "        if text[-1] == ' ':\n",
    "            text = text[:-1]\n",
    "    \n",
    "    return text\n",
    "\n",
    "split_name = \"train.clean.100\"\n",
    "\n",
    "with open(\"./blist/all_rare_words.txt\") as fin:\n",
    "    rarewords = [process(word.strip()) for word in fin]\n",
    "\n",
    "with open(f\"./transcripts/{split_name}.txt\") as fin:\n",
    "    transcripts = [line.strip() for line in fin]\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "cache_dir = \"./../cache\"\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "train_data = []\n",
    "\n",
    "pbar = tqdm(dataset[split_name])\n",
    "for idx, sample in enumerate(pbar):\n",
    "    \n",
    "    text = process(sample[\"text\"])\n",
    "    transcript = transcripts[idx]\n",
    "    \n",
    "    bwords = []\n",
    "    for word in text.split():\n",
    "        if word in rarewords and word not in transcript:\n",
    "            bwords.append(word)\n",
    "    \n",
    "    if len(bwords) > 0:\n",
    "        train_data.append({\n",
    "            \"split\": split_name,\n",
    "            \"idx\": idx,\n",
    "            \"text\": text,\n",
    "            \"transcript\": transcript,\n",
    "            \"b_words\": bwords,\n",
    "        })\n",
    "    pbar.set_description(f\"Len of train data: {len(train_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac9a909-e1ce-426a-bda3-b65ba3985d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(f\"./train_data/{split_name}.json\", \"w\") as fout:\n",
    "    json.dump(train_data, fout, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
prompting/.ipynb_checkpoints/generate_rare_words-checkpoint.py
ADDED
@@ -0,0 +1,62 @@
import sys, os
import json
import string
from tqdm import tqdm

def process(text):

    # Lower case every letter
    text = text.lower()

    # Remove punctuation
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # Remove whitespaces from front and behind
    while text[0] == ' ' or text[-1] == ' ':
        if text[0] == ' ':
            text = text[1:]
        if text[-1] == ' ':
            text = text[:-1]

    return text

split_name = "train.other.500"

with open("./blist/all_rare_words.txt") as fin:
    rarewords = [process(word.strip()) for word in fin]

with open(f"./transcripts/{split_name}.txt") as fin:
    transcripts = [line.strip() for line in fin]

from datasets import load_dataset

cache_dir = "./../cache"
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

train_data = []

pbar = tqdm(dataset[split_name])
for idx, sample in enumerate(pbar):

    text = process(sample["text"])
    transcript = transcripts[idx]

    bwords = []
    for word in text.split():
        if word in rarewords and word not in transcript:
            bwords.append(word)

    if len(bwords) > 0:
        train_data.append({
            "split": split_name,
            "idx": idx,
            "text": text,
            "transcript": transcript,
            "b_words": bwords,
        })
    pbar.set_description(f"Len of train data: {len(train_data)}")

with open(f"./train_data/{split_name}.json", "w") as fout:
    json.dump(train_data, fout, indent=4)
prompting/.ipynb_checkpoints/generate_transcripts-checkpoint.py
ADDED
@@ -0,0 +1,60 @@
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm
from math import ceil
from model import generate, flush
import numpy as np
import os
import torch
import string

def process(text):

    # Lower case every letter
    text = text.lower()

    # Remove punctuation
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # Remove whitespaces from front and behind
    while text[0] == ' ' or text[-1] == ' ':
        if text[0] == ' ':
            text = text[1:]
        if text[-1] == ' ':
            text = text[:-1]

    return text

device = "cuda:0"
dtype = torch.float16
cache_dir = "./../cache"
model_id = "openai/whisper-small"
batch_size = 250
out_dir = "./transcripts"

dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=cache_dir, attn_implementation="sdpa").to(device).to(dtype).eval()

for split in dataset.keys():

    data = dataset[split]

    os.makedirs(out_dir, exist_ok=True)

    for idx in tqdm(range(ceil(len(data)/batch_size))):

        audios = data[idx * batch_size: (idx + 1) * batch_size]["audio"]

        arrays = [a["array"] for a in audios]

        transcripts = generate(arrays, model, processor)

        with open(os.path.join(out_dir, f"{split}.txt"), "a") as disk:
            disk.writelines([process(text) + "\n" for text in transcripts])
            disk.close()

        flush()
prompting/.ipynb_checkpoints/get_error_word_count-checkpoint.py
ADDED
@@ -0,0 +1,106 @@
import sys, os
# from normalizers.english import EnglishTextNormalizer

error_words_freqs = {}
infile = sys.argv[1]
# setname = sys.argv[2]
insert_error = 0
insert_rare = 0
freqlist_test = {}
# eng_norm = EnglishTextNormalizer()

freqlist = {}
with open("./blist/word_freq.txt") as fin:
    for line in fin:
        word, freq = line.split()
        freqlist[word.upper()] = int(freq)

with open("./blist/all_rare_words.txt") as fin:
    rareset = set()
    for line in fin:
        rareset.add(line.strip().upper())

project_set = set()
with open(infile) as fin:
    lines = fin.readlines()
    for i, line in enumerate(lines):
        if line.startswith('id:'):
            project = line.strip(')\n').split('-')[-3:]
            project = '-'.join(project)
        if "REF:" in line:
            nextline = lines[i+1].split()
            for j, word in enumerate(line.split()):
                if '*' in word:
                    insert_error += 1
                    if nextline[j].upper() in rareset:
                        insert_rare += 1
            line = line.replace('*', '')
            line.replace('%BCACK', '')
            for word in line.split()[1:]:
                if not word.startswith('('):
                    if word.upper() not in freqlist_test:
                        freqlist_test[word.upper()] = 1
                    else:
                        freqlist_test[word.upper()] += 1

                    if word != word.lower() and word.upper() in error_words_freqs:
                        error_words_freqs[word.upper()] += 1
                    elif word != word.lower() and word.upper() not in error_words_freqs:
                        error_words_freqs[word.upper()] = 1
                    elif word == word.lower() and word.upper() not in error_words_freqs:
                        error_words_freqs[word.upper()] = 0
print(len(error_words_freqs.keys()))
print(insert_rare)

commonwords = []
rarewords = []
oovwords = []
common_freq = 0
rare_freq = 0
oov_freq = 0
common_error = 0
rare_error = 0
oov_error = 0
partial_error = 0
partial_freq = 0
very_common_error = 0
very_common_words = 0
words_error_freq = {}
words_total_freq = {}
for word, error in error_words_freqs.items():
    if word in rareset:
        rarewords.append(word)
        rare_freq += freqlist_test[word]
        rare_error += error
    elif word not in freqlist:
        oovwords.append(word)
        oov_freq += freqlist_test[word] if word in freqlist_test else 1
        oov_error += error
    else:
        if freqlist[word] <= 10 and freqlist[word] >= 3:
            if freqlist[word] not in words_error_freq:
                words_error_freq[freqlist[word]] = error
                words_total_freq[freqlist[word]] = freqlist_test[word]
            else:
                words_error_freq[freqlist[word]] += error
                words_total_freq[freqlist[word]] += freqlist_test[word]
        if freqlist[word] <= 10 and freqlist[word] >= 3:
            very_common_error += error
            very_common_words += freqlist_test[word]
        commonwords.append(word)
        common_freq += freqlist_test[word]
        common_error += error

total_words = common_freq + rare_freq + oov_freq
total_errors = common_error+rare_error+oov_error + insert_error
WER = total_errors / total_words
print('='*89)
print('Common words error freq: {} / {} = {}'.format(common_error, common_freq, common_error/common_freq))
print('Rare words error freq: {} / {} = {}'.format(rare_error+insert_rare, rare_freq, (rare_error + insert_rare)/rare_freq))
print('OOV words error freq: {} / {} = {}'.format(oov_error, oov_freq, oov_error/max(oov_freq, 1)))
print('WER estimate: {} / {} = {}'.format(total_errors, total_words, WER))
# print('Partial word count: {} / {}'.format(partial_error, partial_freq))
print('Insert error: {} / {} = {}'.format(insert_error - insert_rare, total_words, (insert_error - insert_rare)/total_words))
print('Insertion + OOV error {}'.format((insert_error + oov_error - insert_rare) / total_words))
# print('Very common words error freq: {} / {} = {}'.format(very_common_error, very_common_words, very_common_error/very_common_words))
print('='*89)
prompting/.ipynb_checkpoints/model-checkpoint.py
ADDED
@@ -0,0 +1,53 @@
import torch
import numpy as np
import gc
from typing import List

def flush():
    gc.collect()
    torch.cuda.empty_cache()

@torch.no_grad()
def generate(arrays, model, processor, max_new_tokens = 444) -> List[str]:
    """
    arrays: a list of audio arrays
    model: the whisper model to use
    processor: the wisper processor to use
    """

    inputs = processor(arrays, sampling_rate=16000, return_tensors="pt").input_features

    # Cache the encoder hidden states
    encoder_hidden_states = model.model.encoder(inputs.to(model.device).to(model.dtype)).last_hidden_state

    decoder_ids = torch.tensor([[50258, 50259, 50359, 50363] for _ in range(inputs.shape[0])]).to(model.device)

    # Tensor to keep track of which samples have reached the end of text token
    inference_continues = torch.ones(inputs.shape[0], dtype=torch.bool).to(model.device)

    while inference_continues.any() and max_new_tokens > 0:

        last_hidden_state = model.model.decoder(input_ids = decoder_ids, encoder_hidden_states = encoder_hidden_states).last_hidden_state

        # A small optimization to only project the hidden states of the last token
        last_token_hidden_state = last_hidden_state[:, -1, :]
        logits = model.proj_out(last_token_hidden_state)

        # Greedy Sampling
        probas = torch.softmax(logits, dim=-1)
        pred_idx = torch.argmax(probas, dim=-1, keepdim=True)

        # Fill the samples where inference has stopped with <|end of text|> token
        pred_idx[~inference_continues, :] = 50257

        decoder_ids = torch.cat((decoder_ids, pred_idx), dim=-1)

        # Check if any sample has reached the end of text token
        reached_end_of_text = (pred_idx.squeeze(-1) == 50257)
        inference_continues &= ~reached_end_of_text

        max_new_tokens -= 1

    transcripts = processor.batch_decode(decoder_ids, skip_special_tokens=True)

    return transcripts
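The generate helper above implements batched greedy decoding by hand: it caches the Whisper encoder states once, then appends one argmax token per step until every sample has emitted the end-of-text token. A minimal usage sketch follows, mirroring how generate_transcripts.py in this same commit calls it; the zero-filled array is a stand-in assumption for a real 16 kHz waveform, not part of the repository.

# Sketch: assumes model.py from this commit is importable and a CUDA device is available.
import numpy as np
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from model import generate, flush

processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = (WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
         .to("cuda:0").to(torch.float16).eval())

# One second of silence stands in for a real 16 kHz mono recording.
arrays = [np.zeros(16000, dtype=np.float32)]
print(generate(arrays, model, processor))
flush()  # free cached GPU memory between batches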
prompting/.ipynb_checkpoints/train_clean_100_error-checkpoint.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71787a0de524627122b64b419a50d551ca4a9ddb5ac3888c2e54990850cde26e
size 12684279
prompting/.ipynb_checkpoints/train_lora-checkpoint.py
ADDED
@@ -0,0 +1,137 @@
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, TaskType
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    set_seed,
    pipeline
)
from trl import SFTTrainer, SFTConfig
from random import randrange
import torch
import wandb

cache_dir = "./../cache"
model_id = "microsoft/Phi-3-mini-4k-instruct"
new_model = "python-phi-3-mini-4k-instruct"
username = "ellipticaloranges"
device_map = {"": 0}
hf_model_repo = username + "/" + new_model

## ------------------------LoRA Configs------------------------------------------------------

lora_r = 16
lora_alpha = 16
lora_dropout = 0.05
target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]

## ------------------------------------------------------------------------------------------

dataset_name = "flytech/python-codes-25k"
dataset_split= "train"

dataset = load_dataset(dataset_name, split=dataset_split, cache_dir=cache_dir)
print(f"Dataset size: {len(dataset)}")


tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True, add_eos_token=True, use_fast=True)
# The padding token is set to the unknown token.
tokenizer.pad_token = tokenizer.unk_token
# The ID of the padding token is set to the ID of the unknown token.
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
# ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input.
tokenizer.padding_side = 'left'


def create_message_column(row):
    messages = []
    user = {
        "content": f"{row['instruction']}",
        "role": "user"
    }
    messages.append(user)
    assistant = {
        "content": f"{row['input']}\n{row['output']}",
        "role": "assistant"
    }
    messages.append(assistant)
    return {"messages": messages}

def format_dataset_chatml(row):
    return {"text": tokenizer.apply_chat_template(row["messages"], add_generation_prompt=False, tokenize=False)}

dataset_chatml = dataset.map(create_message_column)
dataset_chatml = dataset_chatml.map(format_dataset_chatml)
dataset_chatml = dataset_chatml.train_test_split(test_size=0.05, seed=1234)

# print("Max Seq Length", max(map(lambda x: len(tokenizer.encode(x["text"])), dataset)))

if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

print(f"Using {compute_dtype} with {attn_implementation} implementation")

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = compute_dtype,
    trust_remote_code = True,
    device_map = device_map,
    attn_implementation = attn_implementation,
    cache_dir = cache_dir
)

args = SFTConfig(
    output_dir="./phi-3-mini-LoRA",
    eval_strategy="steps",
    do_eval=True,
    optim="adamw_torch",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    log_level="debug",
    save_strategy="epoch",
    logging_steps=10,
    learning_rate=1e-4,
    fp16 = not torch.cuda.is_bf16_supported(),
    bf16 = torch.cuda.is_bf16_supported(),
    eval_steps=100,
    dataset_text_field="text",
    max_seq_length=512,
    num_train_epochs=3,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    report_to="wandb",
    seed=42,
)

peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    task_type=TaskType.CAUSAL_LM,
    target_modules=target_modules,
)

model.add_adapter(peft_config)

wandb.init(project = "Phi 3", name = "python-phi-3-lora")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_chatml['train'],
    eval_dataset=dataset_chatml['test'],
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=args,
)

trainer.train()

# Save the model to the `output_dir` after training
model.save_pretrained("./out/")
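The script above ends by writing the trained LoRA adapter to ./out/. A minimal sketch of loading that adapter back for inference is shown below; it assumes ./out/ contains a standard PEFT adapter (adapter_config.json plus weights), and the prompt and generation settings are illustrative assumptions rather than part of the training script.

# Sketch: reload the adapter on top of the Phi-3 base model (assumed adapter layout in ./out/).
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained("./out/", torch_dtype=torch.float16, device_map={"": 0})
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# Build a chat-formatted prompt and generate a short completion.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a function that reverses a string."}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)
out = model.generate(inputs, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))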
prompting/.ipynb_checkpoints/train_phi-checkpoint.py
ADDED
@@ -0,0 +1,86 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from huggingface_hub import ModelCard, ModelCardData, HfApi
from datasets import load_dataset
from jinja2 import Template
from trl import SFTTrainer
import yaml
import torch

# Model Configs
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
NEW_MODEL_NAME = "opus-phi-3-mini-4k-instruct"
CACHE_DIR = "./../cache"

# Dataset Configs
DATASET_NAME = ""
SPLIT = "train"

# the maximum length of the sequences that the model will handle
MAX_SEQ_LENGTH = 4096
num_train_epochs = 1
license = "apache-2.0"
username = "darshanmakwana412"
learning_rate = 1.41e-5
per_device_train_batch_size = 4
gradient_accumulation_steps = 1

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)
dataset = load_dataset(DATASET_NAME, split=SPLIT)

# EOS Token is used to mark the end of a sentence
EOS_TOKEN=tokenizer.eos_token_id

def formatting_prompts_func(examples):
    # Extract the conversations from the examples.
    convos = examples["conversations"]
    # Initialize an empty list to store the formatted texts.
    texts = []
    # Define a dictionary to map the 'from' field in the conversation to a prefix.
    mapper = {"system": "system\n", "human": "\nuser\n", "gpt": "\nassistant\n"}
    # Define a dictionary to map the 'from' field in the conversation to a suffix.
    end_mapper = {"system": "", "human": "", "gpt": ""}
    # Iterate over each conversation.
    for convo in convos:
        # Format the conversation by joining each turn with its corresponding prefix and suffix.
        # Append the EOS token to the end of the conversation.
        text = "".join(f"{mapper[(turn := x['from'])]} {x['value']}\n{end_mapper[turn]}" for x in convo)
        texts.append(f"{text}{EOS_TOKEN}")
    # Return the formatted texts.
    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched=True)

args = TrainingArguments(
    evaluation_strategy="steps",
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=True,
    learning_rate=learning_rate,
    fp16 = not torch.cuda.is_bf16_supported(),
    bf16 = torch.cuda.is_bf16_supported(),
    max_steps=-1,
    num_train_epochs=num_train_epochs,
    save_strategy="epoch",
    logging_steps=10,
    output_dir=NEW_MODEL_NAME,
    optim="paged_adamw_32bit",
    lr_scheduler_type="linear"
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    formatting_func=formatting_prompts_func
)

import gc
import os

gc.collect()
torch.cuda.empty_cache()

trainer.train()
prompting/.ipynb_checkpoints/training-checkpoint.ipynb
ADDED
@@ -0,0 +1,278 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "12315053-0630-4d3d-8028-02035c2dbf14",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Slide Speech Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8eed0bf-d822-4091-8762-df6582095ab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Dir Structure:\n",
    "    - data\n",
    "    - info\n",
    "    - test\n",
    "    - val\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e999f437-d756-492d-b873-6ee656279b53",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Phi 3 Tinkering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f5f1033-9118-4106-a7fb-6c3b527fe075",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Prompt template for Phi-3\n",
    "<|system|>\n",
    "You are a python developer.<|end|>\n",
    "<|user|>\n",
    "Help me generate a bubble sort algorithm<|end|>\n",
    "<|assistant|>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9e81b0-ae3d-46f1-97ff-138984d07a28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
    "\n",
    "cache_dir = \"./../cache\"\n",
    "model_id = \"microsoft/Phi-3-mini-4k-instruct\"\n",
    "device = \"cuda:0\"\n",
    "dtype = torch.float16\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    device_map = device,\n",
    "    torch_dtype = dtype,\n",
    "    trust_remote_code=True,\n",
    "    cache_dir = cache_dir,\n",
    "    attn_implementation = \"flash_attention_2\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir = cache_dir)\n",
    "\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model = model,\n",
    "    tokenizer = tokenizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b9fb26-4661-4a35-ad62-b87834f577bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a python developer\"},\n",
    "    {\"role\": \"user\", \"content\": \"Help me generate a bubble sort algorithm\"}\n",
    "]\n",
    "\n",
    "generation_args = {\n",
    "    \"max_new_tokens\": 600,\n",
    "    \"return_full_text\": False,\n",
    "    \"temperature\": 1.0,\n",
    "    \"do_sample\": True\n",
    "}\n",
    "\n",
    "output = pipe(messages, **generation_args)\n",
    "print(output[0][\"generated_text\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62db724a-9e20-422d-a5b3-dd55cae55cc7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Training Phi 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2036e6b5-c794-4668-9a79-8a53a2736cfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig\n",
    "from huggingface_hub import ModelCard, ModelCardData, HfApi\n",
    "from datasets import load_dataset\n",
    "from jinja2 import Template\n",
    "from trl import SFTTrainer\n",
    "import yaml\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dcff92e-ab3e-407a-8595-31ffae5f7acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Configs\n",
    "MODEL_ID = \"microsoft/Phi-3-mini-4k-instruct\"\n",
    "NEW_MODEL_NAME = \"opus-phi-3-mini-4k-instruct\"\n",
    "CACHE_DIR = \"./../cache\"\n",
    "\n",
    "# Dataset Configs\n",
    "DATASET_NAME = \"\"\n",
    "SPLIT = \"train\"\n",
    "\n",
    "# the maximum length of the sequences that the model will handle\n",
    "MAX_SEQ_LENGTH = 4096\n",
    "num_train_epochs = 1\n",
    "license = \"apache-2.0\"\n",
    "username = \"darshanmakwana412\"\n",
    "learning_rate = 1.41e-5\n",
    "per_device_train_batch_size = 4\n",
    "gradient_accumulation_steps = 1\n",
    "\n",
    "# If bd16 is supported use bf16 otherwise use f16\n",
    "if torch.cuda.is_bf16_supported():\n",
    "    compute_dtype = torch.bfloat16\n",
    "else:\n",
    "    compute_dtype = torch.float16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fbcf3c6-dd94-4133-a805-910d57c9f974",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)\n",
    "# dataset = load_dataset(DATASET_NAME, split=SPLIT)\n",
    "\n",
    "# EOS Token is used to mark the end of a sentence\n",
    "EOS_TOKEN=tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e31954ea-1838-4992-9132-32e59c42a128",
   "metadata": {},
   "outputs": [],
   "source": [
    "def formatting_prompts_func(examples):\n",
    "    # Extract the conversations from the examples.\n",
    "    convos = examples[\"conversations\"]\n",
    "    # Initialize an empty list to store the formatted texts.\n",
    "    texts = []\n",
    "    # Define a dictionary to map the 'from' field in the conversation to a prefix.\n",
    "    mapper = {\"system\": \"system\\n\", \"human\": \"\\nuser\\n\", \"gpt\": \"\\nassistant\\n\"}\n",
    "    # Define a dictionary to map the 'from' field in the conversation to a suffix.\n",
    "    end_mapper = {\"system\": \"\", \"human\": \"\", \"gpt\": \"\"}\n",
    "    # Iterate over each conversation.\n",
    "    for convo in convos:\n",
    "        # Format the conversation by joining each turn with its corresponding prefix and suffix.\n",
    "        # Append the EOS token to the end of the conversation.\n",
    "        text = \"\".join(f\"{mapper[(turn := x['from'])]} {x['value']}\\n{end_mapper[turn]}\" for x in convo)\n",
    "        texts.append(f\"{text}{EOS_TOKEN}\")\n",
    "    # Return the formatted texts.\n",
    "    return {\"text\": texts}\n",
    "\n",
    "dataset = dataset.map(formatting_prompts_func, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3086c2a4-7cca-461e-894c-376046089fab",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    evaluation_strategy=\"steps\",\n",
    "    per_device_train_batch_size=7,\n",
    "    gradient_accumulation_steps=4,\n",
    "    gradient_checkpointing=True,\n",
    "    learning_rate=1e-4,\n",
    "    fp16 = not torch.cuda.is_bf16_supported(),\n",
    "    bf16 = torch.cuda.is_bf16_supported(),\n",
    "    max_steps=-1,\n",
    "    num_train_epochs=3,\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_steps=10,\n",
    "    output_dir=NEW_MODEL_NAME,\n",
    "    optim=\"paged_adamw_32bit\",\n",
    "    lr_scheduler_type=\"linear\"\n",
    ")\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=dataset,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=128,\n",
    "    formatting_func=formatting_prompts_func\n",
    ")\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "import gc\n",
    "import os\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
prompting/RepCodec/.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
prompting/RepCodec/.ipynb_checkpoints/tinker-checkpoint.ipynb
ADDED
@@ -0,0 +1,267 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997bed07-1181-4562-962a-cb8aa18e1d16",
   "metadata": {},
   "outputs": [],
   "source": [
    "from repcodec.RepCodec import RepCodec\n",
    "import torch\n",
    "import yaml\n",
    "\n",
    "config = \"repcodec/configs/repcodec_dim1024.yaml\"\n",
    "with open(config) as fp:\n",
    "    conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
    "\n",
    "model = RepCodec(**conf)\n",
    "model.load_state_dict(torch.load(\"./../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
    "model.quantizer.initial()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c5516f1-3565-4080-8612-d5ce52ea2a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# input shape: (batch size, hidden dim, sequence length)\n",
    "random_features = torch.randn(size=(1, 1024, 100))\n",
    "with torch.no_grad():\n",
    "    x = model.encoder(random_features)\n",
    "    z = model.projector(x)\n",
    "    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
    "    tokens = idx.cpu().data.numpy().tolist()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "439ecea7-f0d4-4a61-80c2-729138beee32",
   "metadata": {},
   "source": [
    "## Dump Representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efa1891-0810-4cfb-9552-764297209e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "python3 examples/dump_feature.py --model_type data2vec --tsv_path \"./files/train.clean.100.tsv\" --ckpt_path \"./../models/vox_pretrained.pt\" --layer 18 --feat_dir \"./features/train.clean.100\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd1c550-0606-4217-ac65-55ae92843f19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "cache_dir = \"./../../cache\"\n",
    "\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "# for split in dataset.keys():\n",
    "#     data = dataset[split]\n",
    "#     num_frames = []\n",
    "#     for idx in tqdm(range(len(data))):\n",
    "#         audio = data[idx][\"audio\"]\n",
    "#         num_frames.append(int(len(audio[\"array\"]) * 16000 // audio[\"sampling_rate\"]))\n",
    "    \n",
    "#     df = pd.DataFrame.from_dict({\n",
    "#         \"file_path\": list(data[\"file\"]),\n",
    "#         \"num_frames\": num_frames\n",
    "#     })\n",
    "#     df.to_csv(f\"./files/{split}.tsv\", sep=\"\\t\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b4af1af-5726-4899-8272-dfe867cb48a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[\"train.clean.100\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae6a0ef4-8c0a-4f6e-9a81-a9c3350e1266",
   "metadata": {},
   "source": [
    "## Prepare the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1247988-5eaa-492a-a3ab-2b11505126a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset, load_dataset\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import string\n",
    "\n",
    "cache_dir = \"./../../cache\"\n",
    "\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "def process(text):\n",
    "\n",
    "    # Lower case every letter\n",
    "    text = text.lower()\n",
    "\n",
|
127 |
+
" # Remove punctuation\n",
|
128 |
+
" punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
|
129 |
+
" translation_table = str.maketrans('', '', punctuation_to_remove)\n",
|
130 |
+
" text = text.translate(translation_table)\n",
|
131 |
+
"\n",
|
132 |
+
" # Remove whitespaces from front and behind\n",
|
133 |
+
" while text[0] == ' ' or text[-1] == ' ':\n",
|
134 |
+
" if text[0] == ' ':\n",
|
135 |
+
" text = text[1:]\n",
|
136 |
+
" if text[-1] == ' ':\n",
|
137 |
+
" text = text[:-1]\n",
|
138 |
+
" \n",
|
139 |
+
" return text\n",
|
140 |
+
"\n",
|
141 |
+
"dataset = dataset.remove_columns([\"audio\", \"speaker_id\", \"chapter_id\"])\n",
|
142 |
+
"\n",
|
143 |
+
"tokenized_ds = defaultdict(lambda: [])\n",
|
144 |
+
"\n",
|
145 |
+
"for split in dataset.keys():\n",
|
146 |
+
"\n",
|
147 |
+
" texts = []\n",
|
148 |
+
" tokens = []\n",
|
149 |
+
" tkns = np.load(f\"./examples/tkns/{split}.npz\")\n",
|
150 |
+
"\n",
|
151 |
+
" for idx, key in enumerate(tqdm(tkns.files)):\n",
|
152 |
+
" tokens.append(list(tkns[key]))\n",
|
153 |
+
" texts.append(process(dataset[split][idx][\"text\"]))\n",
|
154 |
+
"\n",
|
155 |
+
" tokenized_ds[split] = Dataset.from_dict({\n",
|
156 |
+
" \"text\": texts,\n",
|
157 |
+
" \"audio_tokens\": tokens\n",
|
158 |
+
" })"
|
159 |
+
]
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": null,
|
164 |
+
"id": "bfc82444-3081-4138-aa06-6fb0b7cbc6c3",
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"from datasets import dataset_dict, DatasetDict\n",
|
169 |
+
"\n",
|
170 |
+
"tds = DatasetDict(tokenized_ds)"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "code",
|
175 |
+
"execution_count": null,
|
176 |
+
"id": "006171b9-d479-4462-9642-d126f77edfc2",
|
177 |
+
"metadata": {},
|
178 |
+
"outputs": [],
|
179 |
+
"source": [
|
180 |
+
"tds.save_to_disk(\"librispeech_tokenized.hf\")"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 2,
|
186 |
+
"id": "12970376-4f6f-4926-a954-29c32043b64c",
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [
|
189 |
+
{
|
190 |
+
"ename": "ValueError",
|
191 |
+
"evalue": "Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}",
|
192 |
+
"output_type": "error",
|
193 |
+
"traceback": [
|
194 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
195 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
196 |
+
"Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset\n\u001b[0;32m----> 3\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./librispeech_tokenized.hf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
197 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2594\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2589\u001b[0m verification_mode \u001b[38;5;241m=\u001b[39m VerificationMode(\n\u001b[1;32m 2590\u001b[0m (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mBASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2591\u001b[0m )\n\u001b[1;32m 2593\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2594\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2595\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2596\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2597\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2598\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2599\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2600\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2601\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2604\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2605\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2606\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2607\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2608\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2609\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2611\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2612\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
|
198 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2266\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2264\u001b[0m download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 2265\u001b[0m download_config\u001b[38;5;241m.\u001b[39mstorage_options\u001b[38;5;241m.\u001b[39mupdate(storage_options)\n\u001b[0;32m-> 2266\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2267\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2268\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2269\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2270\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2271\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2272\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2273\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2275\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_require_default_config_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2276\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_custom_configs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2277\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2278\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 2279\u001b[0m builder_kwargs \u001b[38;5;241m=\u001b[39m dataset_module\u001b[38;5;241m.\u001b[39mbuilder_kwargs\n",
|
199 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1825\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m 1818\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m LocalDatasetModuleFactoryWithScript(\n\u001b[1;32m 1819\u001b[0m combined_path,\n\u001b[1;32m 1820\u001b[0m download_mode\u001b[38;5;241m=\u001b[39mdownload_mode,\n\u001b[1;32m 1821\u001b[0m dynamic_modules_path\u001b[38;5;241m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m 1822\u001b[0m trust_remote_code\u001b[38;5;241m=\u001b[39mtrust_remote_code,\n\u001b[1;32m 1823\u001b[0m )\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1824\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[0;32m-> 1825\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mLocalDatasetModuleFactoryWithoutScript\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1826\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\n\u001b[1;32m 1827\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1828\u001b[0m \u001b[38;5;66;03m# Try remotely\u001b[39;00m\n\u001b[1;32m 1829\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_relative_path(path) \u001b[38;5;129;01mand\u001b[39;00m path\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
|
200 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1040\u001b[0m, in \u001b[0;36mLocalDatasetModuleFactoryWithoutScript.get_module\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1034\u001b[0m patterns \u001b[38;5;241m=\u001b[39m get_data_patterns(base_path)\n\u001b[1;32m 1035\u001b[0m data_files \u001b[38;5;241m=\u001b[39m DataFilesDict\u001b[38;5;241m.\u001b[39mfrom_patterns(\n\u001b[1;32m 1036\u001b[0m patterns,\n\u001b[1;32m 1037\u001b[0m base_path\u001b[38;5;241m=\u001b[39mbase_path,\n\u001b[1;32m 1038\u001b[0m allowed_extensions\u001b[38;5;241m=\u001b[39mALL_ALLOWED_EXTENSIONS,\n\u001b[1;32m 1039\u001b[0m )\n\u001b[0;32m-> 1040\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43minfer_module_for_data_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1041\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1042\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1043\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1044\u001b[0m data_files \u001b[38;5;241m=\u001b[39m data_files\u001b[38;5;241m.\u001b[39mfilter_extensions(_MODULE_TO_EXTENSIONS[module_name])\n\u001b[1;32m 1045\u001b[0m \u001b[38;5;66;03m# Collect metadata files if the module supports them\u001b[39;00m\n",
|
201 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:596\u001b[0m, in \u001b[0;36minfer_module_for_data_files\u001b[0;34m(data_files, path, download_config)\u001b[0m\n\u001b[1;32m 594\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(split_modules\u001b[38;5;241m.\u001b[39mvalues()))\n\u001b[1;32m 595\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m((module_name, default_builder_kwargs) \u001b[38;5;241m!=\u001b[39m split_module \u001b[38;5;28;01mfor\u001b[39;00m split_module \u001b[38;5;129;01min\u001b[39;00m split_modules\u001b[38;5;241m.\u001b[39mvalues()):\n\u001b[0;32m--> 596\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt infer the same data file format for all splits. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msplit_modules\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m module_name:\n\u001b[1;32m 598\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DataFilesNotFoundError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo (supported) data files found\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m path \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
|
202 |
+
"\u001b[0;31mValueError\u001b[0m: Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}"
|
203 |
+
]
|
204 |
+
}
|
205 |
+
],
|
206 |
+
"source": [
|
207 |
+
"from datasets import load_dataset\n",
|
208 |
+
"\n",
|
209 |
+
"dataset = load_dataset(\"./librispeech_tokenized.hf\")"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": 8,
|
215 |
+
"id": "b3ba0d58-b788-43b5-87a7-726aaa12dbbd",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [],
|
218 |
+
"source": [
|
219 |
+
"from datasets import dataset_dict, DatasetDict, Dataset\n",
|
220 |
+
"\n",
|
221 |
+
"dataset = DatasetDict.load_from_disk(\"./librispeech_tokenized.hf\")"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"cell_type": "code",
|
226 |
+
"execution_count": 13,
|
227 |
+
"id": "b7239186-73ae-407a-b9f6-b5a16f3a7ddc",
|
228 |
+
"metadata": {},
|
229 |
+
"outputs": [
|
230 |
+
{
|
231 |
+
"data": {
|
232 |
+
"text/plain": [
|
233 |
+
"726"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
"execution_count": 13,
|
237 |
+
"metadata": {},
|
238 |
+
"output_type": "execute_result"
|
239 |
+
}
|
240 |
+
],
|
241 |
+
"source": [
|
242 |
+
"len(dataset[\"train.clean.100\"][0][\"audio_tokens\"])"
|
243 |
+
]
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"metadata": {
|
247 |
+
"kernelspec": {
|
248 |
+
"display_name": "Python 3 (ipykernel)",
|
249 |
+
"language": "python",
|
250 |
+
"name": "python3"
|
251 |
+
},
|
252 |
+
"language_info": {
|
253 |
+
"codemirror_mode": {
|
254 |
+
"name": "ipython",
|
255 |
+
"version": 3
|
256 |
+
},
|
257 |
+
"file_extension": ".py",
|
258 |
+
"mimetype": "text/x-python",
|
259 |
+
"name": "python",
|
260 |
+
"nbconvert_exporter": "python",
|
261 |
+
"pygments_lexer": "ipython3",
|
262 |
+
"version": "3.8.10"
|
263 |
+
}
|
264 |
+
},
|
265 |
+
"nbformat": 4,
|
266 |
+
"nbformat_minor": 5
|
267 |
+
}
|
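The checkpoint above builds a tokenized `DatasetDict`, saves it with `save_to_disk`, and then hits a `ValueError` when reloading it with `load_dataset`, before switching to `DatasetDict.load_from_disk`. A minimal sketch of the round-trip that works, with a toy split mirroring the notebook's schema:

```python
from datasets import Dataset, DatasetDict

# A tiny DatasetDict in the same shape as the notebook's tokenized_ds (toy values).
tds = DatasetDict({
    "train.clean.100": Dataset.from_dict({
        "text": ["hello world"],
        "audio_tokens": [[696, 696, 198]],
    })
})

tds.save_to_disk("librispeech_tokenized.hf")

# save_to_disk output must be reloaded with load_from_disk; calling load_dataset on this
# directory is what raised "Couldn't infer the same data file format for all splits" above.
reloaded = DatasetDict.load_from_disk("librispeech_tokenized.hf")
print(len(reloaded["train.clean.100"][0]["audio_tokens"]))
```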
prompting/RepCodec/LICENSE
ADDED
@@ -0,0 +1,428 @@
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) ByteDance, Inc. and its affiliates.
|
4 |
+
Copyright (c) Chutong Meng
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
Attribution-NonCommercial 4.0 International
|
31 |
+
|
32 |
+
=======================================================================
|
33 |
+
|
34 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
35 |
+
does not provide legal services or legal advice. Distribution of
|
36 |
+
Creative Commons public licenses does not create a lawyer-client or
|
37 |
+
other relationship. Creative Commons makes its licenses and related
|
38 |
+
information available on an "as-is" basis. Creative Commons gives no
|
39 |
+
warranties regarding its licenses, any material licensed under their
|
40 |
+
terms and conditions, or any related information. Creative Commons
|
41 |
+
disclaims all liability for damages resulting from their use to the
|
42 |
+
fullest extent possible.
|
43 |
+
|
44 |
+
Using Creative Commons Public Licenses
|
45 |
+
|
46 |
+
Creative Commons public licenses provide a standard set of terms and
|
47 |
+
conditions that creators and other rights holders may use to share
|
48 |
+
original works of authorship and other material subject to copyright
|
49 |
+
and certain other rights specified in the public license below. The
|
50 |
+
following considerations are for informational purposes only, are not
|
51 |
+
exhaustive, and do not form part of our licenses.
|
52 |
+
|
53 |
+
Considerations for licensors: Our public licenses are
|
54 |
+
intended for use by those authorized to give the public
|
55 |
+
permission to use material in ways otherwise restricted by
|
56 |
+
copyright and certain other rights. Our licenses are
|
57 |
+
irrevocable. Licensors should read and understand the terms
|
58 |
+
and conditions of the license they choose before applying it.
|
59 |
+
Licensors should also secure all rights necessary before
|
60 |
+
applying our licenses so that the public can reuse the
|
61 |
+
material as expected. Licensors should clearly mark any
|
62 |
+
material not subject to the license. This includes other CC-
|
63 |
+
licensed material, or material used under an exception or
|
64 |
+
limitation to copyright. More considerations for licensors:
|
65 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
66 |
+
|
67 |
+
Considerations for the public: By using one of our public
|
68 |
+
licenses, a licensor grants the public permission to use the
|
69 |
+
licensed material under specified terms and conditions. If
|
70 |
+
the licensor's permission is not necessary for any reason--for
|
71 |
+
example, because of any applicable exception or limitation to
|
72 |
+
copyright--then that use is not regulated by the license. Our
|
73 |
+
licenses grant only permissions under copyright and certain
|
74 |
+
other rights that a licensor has authority to grant. Use of
|
75 |
+
the licensed material may still be restricted for other
|
76 |
+
reasons, including because others have copyright or other
|
77 |
+
rights in the material. A licensor may make special requests,
|
78 |
+
such as asking that all changes be marked or described.
|
79 |
+
Although not required by our licenses, you are encouraged to
|
80 |
+
respect those requests where reasonable. More_considerations
|
81 |
+
for the public:
|
82 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
83 |
+
|
84 |
+
=======================================================================
|
85 |
+
|
86 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
87 |
+
License
|
88 |
+
|
89 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
90 |
+
to be bound by the terms and conditions of this Creative Commons
|
91 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
92 |
+
License"). To the extent this Public License may be interpreted as a
|
93 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
94 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
95 |
+
such rights in consideration of benefits the Licensor receives from
|
96 |
+
making the Licensed Material available under these terms and
|
97 |
+
conditions.
|
98 |
+
|
99 |
+
Section 1 -- Definitions.
|
100 |
+
|
101 |
+
a. Adapted Material means material subject to Copyright and Similar
|
102 |
+
Rights that is derived from or based upon the Licensed Material
|
103 |
+
and in which the Licensed Material is translated, altered,
|
104 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
105 |
+
permission under the Copyright and Similar Rights held by the
|
106 |
+
Licensor. For purposes of this Public License, where the Licensed
|
107 |
+
Material is a musical work, performance, or sound recording,
|
108 |
+
Adapted Material is always produced where the Licensed Material is
|
109 |
+
synched in timed relation with a moving image.
|
110 |
+
|
111 |
+
b. Adapter's License means the license You apply to Your Copyright
|
112 |
+
and Similar Rights in Your contributions to Adapted Material in
|
113 |
+
accordance with the terms and conditions of this Public License.
|
114 |
+
|
115 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
116 |
+
closely related to copyright including, without limitation,
|
117 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
118 |
+
Rights, without regard to how the rights are labeled or
|
119 |
+
categorized. For purposes of this Public License, the rights
|
120 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
121 |
+
Rights.
|
122 |
+
d. Effective Technological Measures means those measures that, in the
|
123 |
+
absence of proper authority, may not be circumvented under laws
|
124 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
125 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
126 |
+
agreements.
|
127 |
+
|
128 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
129 |
+
any other exception or limitation to Copyright and Similar Rights
|
130 |
+
that applies to Your use of the Licensed Material.
|
131 |
+
|
132 |
+
f. Licensed Material means the artistic or literary work, database,
|
133 |
+
or other material to which the Licensor applied this Public
|
134 |
+
License.
|
135 |
+
|
136 |
+
g. Licensed Rights means the rights granted to You subject to the
|
137 |
+
terms and conditions of this Public License, which are limited to
|
138 |
+
all Copyright and Similar Rights that apply to Your use of the
|
139 |
+
Licensed Material and that the Licensor has authority to license.
|
140 |
+
|
141 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
142 |
+
under this Public License.
|
143 |
+
|
144 |
+
i. NonCommercial means not primarily intended for or directed towards
|
145 |
+
commercial advantage or monetary compensation. For purposes of
|
146 |
+
this Public License, the exchange of the Licensed Material for
|
147 |
+
other material subject to Copyright and Similar Rights by digital
|
148 |
+
file-sharing or similar means is NonCommercial provided there is
|
149 |
+
no payment of monetary compensation in connection with the
|
150 |
+
exchange.
|
151 |
+
|
152 |
+
j. Share means to provide material to the public by any means or
|
153 |
+
process that requires permission under the Licensed Rights, such
|
154 |
+
as reproduction, public display, public performance, distribution,
|
155 |
+
dissemination, communication, or importation, and to make material
|
156 |
+
available to the public including in ways that members of the
|
157 |
+
public may access the material from a place and at a time
|
158 |
+
individually chosen by them.
|
159 |
+
|
160 |
+
k. Sui Generis Database Rights means rights other than copyright
|
161 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
162 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
163 |
+
as amended and/or succeeded, as well as other essentially
|
164 |
+
equivalent rights anywhere in the world.
|
165 |
+
|
166 |
+
l. You means the individual or entity exercising the Licensed Rights
|
167 |
+
under this Public License. Your has a corresponding meaning.
|
168 |
+
|
169 |
+
Section 2 -- Scope.
|
170 |
+
|
171 |
+
a. License grant.
|
172 |
+
|
173 |
+
1. Subject to the terms and conditions of this Public License,
|
174 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
175 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
176 |
+
exercise the Licensed Rights in the Licensed Material to:
|
177 |
+
|
178 |
+
a. reproduce and Share the Licensed Material, in whole or
|
179 |
+
in part, for NonCommercial purposes only; and
|
180 |
+
|
181 |
+
b. produce, reproduce, and Share Adapted Material for
|
182 |
+
NonCommercial purposes only.
|
183 |
+
|
184 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
185 |
+
Exceptions and Limitations apply to Your use, this Public
|
186 |
+
License does not apply, and You do not need to comply with
|
187 |
+
its terms and conditions.
|
188 |
+
|
189 |
+
3. Term. The term of this Public License is specified in Section
|
190 |
+
6(a).
|
191 |
+
|
192 |
+
4. Media and formats; technical modifications allowed. The
|
193 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
194 |
+
all media and formats whether now known or hereafter created,
|
195 |
+
and to make technical modifications necessary to do so. The
|
196 |
+
Licensor waives and/or agrees not to assert any right or
|
197 |
+
authority to forbid You from making technical modifications
|
198 |
+
necessary to exercise the Licensed Rights, including
|
199 |
+
technical modifications necessary to circumvent Effective
|
200 |
+
Technological Measures. For purposes of this Public License,
|
201 |
+
simply making modifications authorized by this Section 2(a)
|
202 |
+
(4) never produces Adapted Material.
|
203 |
+
|
204 |
+
5. Downstream recipients.
|
205 |
+
|
206 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
207 |
+
recipient of the Licensed Material automatically
|
208 |
+
receives an offer from the Licensor to exercise the
|
209 |
+
Licensed Rights under the terms and conditions of this
|
210 |
+
Public License.
|
211 |
+
|
212 |
+
b. No downstream restrictions. You may not offer or impose
|
213 |
+
any additional or different terms or conditions on, or
|
214 |
+
apply any Effective Technological Measures to, the
|
215 |
+
Licensed Material if doing so restricts exercise of the
|
216 |
+
Licensed Rights by any recipient of the Licensed
|
217 |
+
Material.
|
218 |
+
|
219 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
220 |
+
may be construed as permission to assert or imply that You
|
221 |
+
are, or that Your use of the Licensed Material is, connected
|
222 |
+
with, or sponsored, endorsed, or granted official status by,
|
223 |
+
the Licensor or others designated to receive attribution as
|
224 |
+
provided in Section 3(a)(1)(A)(i).
|
225 |
+
|
226 |
+
b. Other rights.
|
227 |
+
|
228 |
+
1. Moral rights, such as the right of integrity, are not
|
229 |
+
licensed under this Public License, nor are publicity,
|
230 |
+
privacy, and/or other similar personality rights; however, to
|
231 |
+
the extent possible, the Licensor waives and/or agrees not to
|
232 |
+
assert any such rights held by the Licensor to the limited
|
233 |
+
extent necessary to allow You to exercise the Licensed
|
234 |
+
Rights, but not otherwise.
|
235 |
+
|
236 |
+
2. Patent and trademark rights are not licensed under this
|
237 |
+
Public License.
|
238 |
+
|
239 |
+
3. To the extent possible, the Licensor waives any right to
|
240 |
+
collect royalties from You for the exercise of the Licensed
|
241 |
+
Rights, whether directly or through a collecting society
|
242 |
+
under any voluntary or waivable statutory or compulsory
|
243 |
+
licensing scheme. In all other cases the Licensor expressly
|
244 |
+
reserves any right to collect such royalties, including when
|
245 |
+
the Licensed Material is used other than for NonCommercial
|
246 |
+
purposes.
|
247 |
+
|
248 |
+
Section 3 -- License Conditions.
|
249 |
+
|
250 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
251 |
+
following conditions.
|
252 |
+
|
253 |
+
a. Attribution.
|
254 |
+
|
255 |
+
1. If You Share the Licensed Material (including in modified
|
256 |
+
form), You must:
|
257 |
+
|
258 |
+
a. retain the following if it is supplied by the Licensor
|
259 |
+
with the Licensed Material:
|
260 |
+
|
261 |
+
i. identification of the creator(s) of the Licensed
|
262 |
+
Material and any others designated to receive
|
263 |
+
attribution, in any reasonable manner requested by
|
264 |
+
the Licensor (including by pseudonym if
|
265 |
+
designated);
|
266 |
+
|
267 |
+
ii. a copyright notice;
|
268 |
+
|
269 |
+
iii. a notice that refers to this Public License;
|
270 |
+
|
271 |
+
iv. a notice that refers to the disclaimer of
|
272 |
+
warranties;
|
273 |
+
|
274 |
+
v. a URI or hyperlink to the Licensed Material to the
|
275 |
+
extent reasonably practicable;
|
276 |
+
|
277 |
+
b. indicate if You modified the Licensed Material and
|
278 |
+
retain an indication of any previous modifications; and
|
279 |
+
|
280 |
+
c. indicate the Licensed Material is licensed under this
|
281 |
+
Public License, and include the text of, or the URI or
|
282 |
+
hyperlink to, this Public License.
|
283 |
+
|
284 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
285 |
+
reasonable manner based on the medium, means, and context in
|
286 |
+
which You Share the Licensed Material. For example, it may be
|
287 |
+
reasonable to satisfy the conditions by providing a URI or
|
288 |
+
hyperlink to a resource that includes the required
|
289 |
+
information.
|
290 |
+
|
291 |
+
3. If requested by the Licensor, You must remove any of the
|
292 |
+
information required by Section 3(a)(1)(A) to the extent
|
293 |
+
reasonably practicable.
|
294 |
+
|
295 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
296 |
+
License You apply must not prevent recipients of the Adapted
|
297 |
+
Material from complying with this Public License.
|
298 |
+
|
299 |
+
Section 4 -- Sui Generis Database Rights.
|
300 |
+
|
301 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
302 |
+
apply to Your use of the Licensed Material:
|
303 |
+
|
304 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
305 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
306 |
+
portion of the contents of the database for NonCommercial purposes
|
307 |
+
only;
|
308 |
+
|
309 |
+
b. if You include all or a substantial portion of the database
|
310 |
+
contents in a database in which You have Sui Generis Database
|
311 |
+
Rights, then the database in which You have Sui Generis Database
|
312 |
+
Rights (but not its individual contents) is Adapted Material; and
|
313 |
+
|
314 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
315 |
+
all or a substantial portion of the contents of the database.
|
316 |
+
|
317 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
318 |
+
replace Your obligations under this Public License where the Licensed
|
319 |
+
Rights include other Copyright and Similar Rights.
|
320 |
+
|
321 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
322 |
+
|
323 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
324 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
325 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
326 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
327 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
328 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
329 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
330 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
331 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
332 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
333 |
+
|
334 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
335 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
336 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
337 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
338 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
339 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
340 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
341 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
342 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
343 |
+
|
344 |
+
c. The disclaimer of warranties and limitation of liability provided
|
345 |
+
above shall be interpreted in a manner that, to the extent
|
346 |
+
possible, most closely approximates an absolute disclaimer and
|
347 |
+
waiver of all liability.
|
348 |
+
|
349 |
+
Section 6 -- Term and Termination.
|
350 |
+
|
351 |
+
a. This Public License applies for the term of the Copyright and
|
352 |
+
Similar Rights licensed here. However, if You fail to comply with
|
353 |
+
this Public License, then Your rights under this Public License
|
354 |
+
terminate automatically.
|
355 |
+
|
356 |
+
b. Where Your right to use the Licensed Material has terminated under
|
357 |
+
Section 6(a), it reinstates:
|
358 |
+
|
359 |
+
1. automatically as of the date the violation is cured, provided
|
360 |
+
it is cured within 30 days of Your discovery of the
|
361 |
+
violation; or
|
362 |
+
|
363 |
+
2. upon express reinstatement by the Licensor.
|
364 |
+
|
365 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
366 |
+
right the Licensor may have to seek remedies for Your violations
|
367 |
+
of this Public License.
|
368 |
+
|
369 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
370 |
+
Licensed Material under separate terms or conditions or stop
|
371 |
+
distributing the Licensed Material at any time; however, doing so
|
372 |
+
will not terminate this Public License.
|
373 |
+
|
374 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
375 |
+
License.
|
376 |
+
|
377 |
+
Section 7 -- Other Terms and Conditions.
|
378 |
+
|
379 |
+
a. The Licensor shall not be bound by any additional or different
|
380 |
+
terms or conditions communicated by You unless expressly agreed.
|
381 |
+
|
382 |
+
b. Any arrangements, understandings, or agreements regarding the
|
383 |
+
Licensed Material not stated herein are separate from and
|
384 |
+
independent of the terms and conditions of this Public License.
|
385 |
+
|
386 |
+
Section 8 -- Interpretation.
|
387 |
+
|
388 |
+
a. For the avoidance of doubt, this Public License does not, and
|
389 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
390 |
+
conditions on any use of the Licensed Material that could lawfully
|
391 |
+
be made without permission under this Public License.
|
392 |
+
|
393 |
+
b. To the extent possible, if any provision of this Public License is
|
394 |
+
deemed unenforceable, it shall be automatically reformed to the
|
395 |
+
minimum extent necessary to make it enforceable. If the provision
|
396 |
+
cannot be reformed, it shall be severed from this Public License
|
397 |
+
without affecting the enforceability of the remaining terms and
|
398 |
+
conditions.
|
399 |
+
|
400 |
+
c. No term or condition of this Public License will be waived and no
|
401 |
+
failure to comply consented to unless expressly agreed to by the
|
402 |
+
Licensor.
|
403 |
+
|
404 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
405 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
406 |
+
that apply to the Licensor or You, including from the legal
|
407 |
+
processes of any jurisdiction or authority.
|
408 |
+
|
409 |
+
=======================================================================
|
410 |
+
|
411 |
+
Creative Commons is not a party to its public
|
412 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
413 |
+
its public licenses to material it publishes and in those instances
|
414 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
415 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
416 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
417 |
+
material is shared under a Creative Commons public license or as
|
418 |
+
otherwise permitted by the Creative Commons policies published at
|
419 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
420 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
421 |
+
of Creative Commons without its prior written consent including,
|
422 |
+
without limitation, in connection with any unauthorized modifications
|
423 |
+
to any of its public licenses or any other arrangements,
|
424 |
+
understandings, or agreements concerning use of licensed material. For
|
425 |
+
the avoidance of doubt, this paragraph does not form part of the
|
426 |
+
public licenses.
|
427 |
+
|
428 |
+
Creative Commons may be contacted at creativecommons.org.
|
prompting/RepCodec/README.md
ADDED
@@ -0,0 +1,273 @@
1 |
+
# RepCodec: A Speech Representation Codec for Speech Tokenization
|
2 |
+
|
3 |
+
> [**RepCodec: A Speech Representation Codec for Speech Tokenization**](https://arxiv.org/abs/2309.00169)
|
4 |
+
|
5 |
+
## Introduction
|
6 |
+
|
7 |
+
**RepCodec** is a speech tokenization method for converting a speech waveform into a sequence of discrete semantic
|
8 |
+
tokens.
|
9 |
+
The main idea is to train a representation codec which learns a vector quantization codebook through reconstructing the
|
10 |
+
input speech representations from speech encoders like HuBERT or data2vec.
|
11 |
+
Extensive experiments show that RepCodec significantly outperforms the widely used k-means clustering approach in both
|
12 |
+
speech understanding and generation.
|
13 |
+
Also, RepCodec generalizes well across various speech encoders and languages.
|
14 |
+
|
15 |
+
<img src="images/RepCodec.png" alt="se" width="1000" />
|
16 |
+
|
17 |
+
## RepCodec Models
|
18 |
+
|
19 |
+
| Feature Type | Speech Data | RepCodec Model |
|
20 |
+
|-----------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------|----------------------------------------------------------------------------------------------------------|
|
21 |
+
| [HuBERT base](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 9 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_base_l9](https://drive.google.com/file/d/1XD0HKl607FFjri2-VJT7lHQeSpxsCCFO/view?usp=sharing) |
|
22 |
+
| [HuBERT large](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [hubert_large_l18](https://drive.google.com/file/d/1mTbm5GeJ7gp_5L3QLP-JGXdf8RnRw5n6/view?usp=sharing) |
|
23 |
+
| [data2vec base](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 6 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_base_l6](https://drive.google.com/file/d/1d8sf3Ko_fYM9zlaiwxK_4xusLRKV5EMd/view?usp=sharing) |
|
24 |
+
| [data2vec large](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2) layer 18 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [data2vec_large_l18](https://drive.google.com/file/d/1nuRIHaejT-uVi4cluftbT8o_JZqar5SU/view?usp=sharing) |
|
25 |
+
| [Whisper medium](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 24 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_medium_l24](https://drive.google.com/file/d/1V6YJSA2V4iywXrecJAN0oqsa3aHowexZ/view?usp=sharing) |
|
26 |
+
| [Whisper large-v2](https://github.com/openai/whisper/tree/main#available-models-and-languages) layer 32 | [Librispeech](http://www.openslr.org/12) train-clean-100 | [whisper_large_l32](https://drive.google.com/file/d/1k_X7ZMPg8iOeDrIJe70v6CHfFygzufXC/view?usp=sharing) |
|
27 |
+
|
28 |
+
## Speech Tokenization Using Pre-Trained Models
|
29 |
+
|
30 |
+
### Installation
|
31 |
+
|
32 |
+
Please first install RepCodec by
|
33 |
+
|
34 |
+
```
|
35 |
+
git clone https://github.com/mct10/RepCodec.git
|
36 |
+
cd RepCodec
|
37 |
+
pip install .
|
38 |
+
```
|
39 |
+
|
40 |
+
We used Python 3.9.18 and PyTorch 1.12.1 to test the usage, but the code should be compatible with other recent Python
|
41 |
+
and PyTorch versions.
|
42 |
+
|
43 |
+
### Representation Preparation
|
44 |
+
|
45 |
+
We adapt the `dump_hubert_feature.py` script
|
46 |
+
from [fairseq](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans#hubert-feature)
|
47 |
+
to support dumping representations from **data2vec**, **HuBERT**, or **Whisper** encoders.
|
48 |
+
|
49 |
+
If you use our script (`examples/dump_feature.py`), please also install the following packages:
|
50 |
+
|
51 |
+
```
|
52 |
+
pip install npy_append_array soundfile
|
53 |
+
```
|
54 |
+
|
55 |
+
Additionally, if you want to dump representations from
|
56 |
+
|
57 |
+
- **data2vec** or **HuBERT**: please
|
58 |
+
follow [fairseq's instruction](https://github.com/facebookresearch/fairseq#requirements-and-installation) to install
|
59 |
+
the latest fairseq.
|
60 |
+
|
61 |
+
- **Whisper**: please follow [Whisper's instructions](https://github.com/openai/whisper/tree/main#setup) to install the
|
62 |
+
latest
|
63 |
+
Whisper.
|
64 |
+
|
65 |
+
Then, you can follow the given examples to dump representations:
|
66 |
+
|
67 |
+
```
|
68 |
+
# Example 1: dump from HuBERT base layer 9
|
69 |
+
# (for data2vec, simply change "model_type" to data2vec and "ckpt_path" to the path of data2vec model)
|
70 |
+
|
71 |
+
layer=9
|
72 |
+
|
73 |
+
python3 examples/dump_feature.py \
|
74 |
+
--model_type hubert \
|
75 |
+
--tsv_path /path/to/tsv/file \
|
76 |
+
--ckpt_path /path/to/HuBERT/model \
|
77 |
+
--layer ${layer} \
|
78 |
+
--feat_dir /dir/to/save/representations
|
79 |
+
|
80 |
+
|
81 |
+
# Example 2: dump from Whisper medium layer 24
|
82 |
+
|
83 |
+
layer=24
|
84 |
+
|
85 |
+
python3 examples/dump_feature.py \
|
86 |
+
--model_type whisper \
|
87 |
+
--tsv_path /path/to/tsv/file \
|
88 |
+
--whisper_root /directory/to/save/whisper/model \
|
89 |
+
--whisper_name medium \
|
90 |
+
--layer ${layer} \
|
91 |
+
--feat_dir /dir/to/save/representations
|
92 |
+
```
|
93 |
+
|
94 |
+
Explanations about the args:
|
95 |
+
|
96 |
+
- **model_type:** choose from `data2vec`, `hubert`, and `whisper`.
|
97 |
+
|
98 |
+
- **tsv_path:** path of the tsv file.
|
99 |
+
Should have the format of
|
100 |
+
|
101 |
+
```
|
102 |
+
/dir/to/dataset
|
103 |
+
path_of_utterance_1 number_of_frames
|
104 |
+
path_of_utterance_2 number_of_frames
|
105 |
+
```
|
106 |
+
|
107 |
+
You can follow [this script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py)
|
108 |
+
to generate the tsv file.
|
109 |
+
|
110 |
+
For example, by running
|
111 |
+
|
112 |
+
```
|
113 |
+
python wav2vec_manifest.py \
|
114 |
+
/dir/to/LibriSpeech/dev-clean \
|
115 |
+
--dest /dir/to/manifest \
|
116 |
+
--ext flac \
|
117 |
+
--valid-percent 0
|
118 |
+
```
|
119 |
+
|
120 |
+
you can obtain the `dev-clean.tsv` in `/dir/to/manifest` for LibriSpeech. (By default, the output file name
|
121 |
+
is `train.tsv`. Remember to rename the file.)
|
122 |
+
|
123 |
+
It should be similar to:
|
124 |
+
|
125 |
+
```
|
126 |
+
/dir/to/LibriSpeech/dev-clean
|
127 |
+
2277/149896/2277-149896-0026.flac 78720
|
128 |
+
2277/149896/2277-149896-0005.flac 89600
|
129 |
+
2277/149896/2277-149896-0033.flac 45520
|
130 |
+
```
|
131 |
+
|
132 |
+
- **ckpt_path**:
|
133 |
+
must provide for data2vec and HuBERT.
|
134 |
+
You need to download the model
|
135 |
+
from [data2vec website](https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/README.md#speech-2)
|
136 |
+
or [HuBERT website](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#pre-trained-and-fine-tuned-asr-models)
|
137 |
+
yourself.
|
138 |
+
`--ckpt_path` is the path of the data2vec/HuBERT model.
|
139 |
+
- **whisper_root** and **whisper_name**:
|
140 |
+
must provide **BOTH** `--whisper_root` and `--whisper_name` for Whisper.
|
141 |
+
If there is no corresponding model in `--whisper_root`, the script will download for you.
|
142 |
+
|
143 |
+
- **layer**:
|
144 |
+
which Transformer encoder layer of the model should the representations be extracted from.
|
145 |
+
It is **1-based**.
|
146 |
+
For example, if layer=9, then the outputs from the 9<sup>th</sup> Transformer encoder layer are dumped.
|
147 |
+
Range: [1, number of Transformer encoder layers]
|
148 |
+
|
149 |
+
- **feat_dir**: The output representations will be saved to `${feat_dir}/0_1.npy`
|
150 |
+
and `${feat_dir}/0_1.len`.
|
151 |
+
|
152 |
+
For other useful functionalities (e.g., sharding), please check the argument list in `examples/dump_feature.py`.
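The `0_1.npy` / `0_1.len` pair follows the fairseq-style feature dump: the `.npy` file holds every utterance's representations concatenated along the time axis, and the `.len` file lists one frame count per utterance in tsv order (this layout is our reading of the dump script, not stated above). A minimal sketch for slicing the dump back into per-utterance arrays:

```python
import numpy as np

feat_dir = "./features/dev-clean"  # hypothetical dump directory

# All utterance representations concatenated along time: shape (total_frames, hidden_dim).
feats = np.load(f"{feat_dir}/0_1.npy")

# One frame count per utterance, in the same order as the tsv file.
with open(f"{feat_dir}/0_1.len") as f:
    lengths = [int(line) for line in f if line.strip()]

offsets = np.cumsum([0] + lengths)
utterances = [feats[offsets[i]:offsets[i + 1]] for i in range(len(lengths))]
print(len(utterances), utterances[0].shape)
```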
|
153 |
+
|
154 |
+
### Command Line Usage
|
155 |
+
|
156 |
+
We expect to have `${feat_dir}/0_1.npy` and `${feat_dir}/0_1.len` in the provided
|
157 |
+
directory `/dir/to/representations`.
|
158 |
+
|
159 |
+
Also, the tsv file should be the **same** as the one used in [Representation Preparation](#representation-preparation).
|
160 |
+
|
161 |
+
```
|
162 |
+
repcodec /dir/to/representations \
|
163 |
+
--model /path/to/repcodec/model \
|
164 |
+
--tsv_path /path/to/tsv/file \
|
165 |
+
[--model_config_path /path/to/train/config] \
|
166 |
+
[--use_gpu] \
|
167 |
+
[--out_dir /path/to/output]
|
168 |
+
```
|
169 |
+
|
170 |
+
If you trained the model yourself following [Training New RepCodec Models](#training-new-repcodec-models),
|
171 |
+
please provide the training config file using `--model_config_path`.
|
172 |
+
If you use the model we provide [here](#repcodec-models), then you do not have to provide that.
|
173 |
+
|
174 |
+
This command will tokenize the representations and the output discrete tokens will be saved to `${out_dir}/tokens`.
|
175 |
+
The tokens are in the same order as the provided tsv file.
|
176 |
+
|
177 |
+
An example of the output file:
|
178 |
+
|
179 |
+
```
|
180 |
+
/dir/to/LibriSpeech/dev-clean
|
181 |
+
2277/149896/2277-149896-0026.flac 696 696 198 198 198 498 ...
|
182 |
+
2277/149896/2277-149896-0005.flac 696 696 198 198 198 907 ...
|
183 |
+
2277/149896/2277-149896-0033.flac 696 696 198 198 198 696 ...
|
184 |
+
```
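Since the first line of the token file is the dataset root and every following line pairs a relative audio path with its discrete tokens, the output can be read back with a few lines of Python; a minimal sketch (the token file path is hypothetical):

```python
def read_repcodec_tokens(path):
    """Parse a RepCodec token file into (dataset root, {relative audio path: [token ids]})."""
    tokens = {}
    with open(path) as f:
        root = next(f).strip()  # first line is the dataset root directory
        for line in f:
            name, *ids = line.split()
            tokens[name] = [int(i) for i in ids]
    return root, tokens


root, tokens = read_repcodec_tokens("/path/to/output/tokens")
print(root, len(tokens))
```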
|
185 |
+
|
186 |
+
Under `examples/tokens`, we provide some token files as references. They are obtained from LibriSpeech dev-clean subset
|
187 |
+
using the 6 types of representations and corresponding [RepCodec Models](#repcodec-models).
|
188 |
+
Your results should be very similar to ours.
|
189 |
+
|
190 |
+
### Python Usage
|
191 |
+
|
192 |
+
```python
|
193 |
+
import torch
|
194 |
+
import yaml
|
195 |
+
|
196 |
+
from repcodec.RepCodec import RepCodec
|
197 |
+
|
198 |
+
# for feature types of HuBERT base & data2vec base, please use repcodec_dim768.yaml;
|
199 |
+
# for feature types of HuBERT large & data2vec large & Whisper medium, please use repcodec_dim1024.yaml;
|
200 |
+
# for feature types of Whisper large-v2, please use repcodec_dim1280.yaml
|
201 |
+
config = "repcodec/configs/repcodec_dim768.yaml"
|
202 |
+
with open(config) as fp:
|
203 |
+
conf = yaml.load(fp, Loader=yaml.FullLoader)
|
204 |
+
|
205 |
+
model = RepCodec(**conf)
|
206 |
+
model.load_state_dict(torch.load("./hubert_base_l9.pkl", map_location="cpu")["model"]["repcodec"])
|
207 |
+
model.quantizer.initial()
|
208 |
+
model.eval()
|
209 |
+
|
210 |
+
# input shape: (batch size, hidden dim, sequence length)
|
211 |
+
random_features = torch.randn(size=(1, 768, 100))
|
212 |
+
with torch.no_grad():
|
213 |
+
x = model.encoder(random_features)
|
214 |
+
z = model.projector(x)
|
215 |
+
_, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
|
216 |
+
tokens = idx.cpu().data.numpy().tolist()[0]
|
217 |
+
```
|
218 |
+
|
219 |
+
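The same forward pass also works on real dumped representations: build the `(batch size, hidden dim, sequence length)` tensor from a slice of the `.npy` block instead of `torch.randn`. A minimal sketch reusing the `model` built above, assuming a 768-dimensional feature dump at a placeholder path:

```python
import numpy as np
import torch

feat_dir = "/dir/to/representations"          # placeholder path
block = np.load(f"{feat_dir}/0_1.npy", mmap_mode="r")
with open(f"{feat_dir}/0_1.len") as fp:
    first_len = int(fp.readline())            # frames of the first utterance

feats = torch.from_numpy(np.array(block[:first_len])).float()  # (frames, 768)
feats = feats.transpose(0, 1).unsqueeze(0)                     # (1, 768, frames)

with torch.no_grad():
    x = model.encoder(feats)
    z = model.projector(x)
    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
    tokens = idx.cpu().data.numpy().tolist()[0]
```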
+## Training New RepCodec Models
+
+We use a config file to set up all the training configurations, e.g., data, model architecture,
+optimizer, scheduler.
+We provide an example [here](./train_configs/ex_dim768_mse.yaml).
+
+Please first install required packages following [Installation](#installation)
+and prepare the representations following [Representation Preparation](#representation-preparation).
+
+The input data directory is expected to have the following structure
+```
+/dir/to/representations/
+    train_set_name/
+        0_1.npy
+        0_1.len
+    valid_set_name/
+        0_1.npy
+        0_1.len
+    test_set_name/
+        0_1.npy
+        0_1.len
+```
+
+The names of subsets should be the same as the fields in the config file.
+
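One way to build this layout is to link the dumped shards into per-subset directories; the data loader in `dataloader/dataset.py` globs `*.npy`/`*.len` pairs inside whatever directory it is given. A minimal sketch, assuming the subset names below match the fields in your config file (all paths are placeholders):

```python
import os

dumped = {
    "train_set_name": "/dir/to/dumped/train_split",
    "valid_set_name": "/dir/to/dumped/valid_split",
    "test_set_name": "/dir/to/dumped/test_split",
}
root = "/dir/to/representations"

for subset, src in dumped.items():
    dst = os.path.join(root, subset)
    os.makedirs(dst, exist_ok=True)
    for name in os.listdir(src):
        if name.endswith((".npy", ".len")):
            os.symlink(os.path.join(src, name), os.path.join(dst, name))
```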
+Then, you can run training by
+```
+python train.py \
+  -c /path/to/config/file \
+  --tag $tag \
+  --exp_root exp
+```
+
+`tag` is the name of the output folder.
+All outputs will be saved to `exp_root/tag/`.
+
+## Acknowledgement
+
+Our implementation is based on [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec).
+We thank them for open-sourcing their code!
+
+## Citation
+
+If you find our work useful, please cite the following article.
+
+```
+@misc{huang2023repcodec,
+      title={RepCodec: A Speech Representation Codec for Speech Tokenization},
+      author={Zhichao Huang and Chutong Meng and Tom Ko},
+      year={2023},
+      eprint={2309.00169},
+      archivePrefix={arXiv},
+      primaryClass={eess.AS}
+}
+```
prompting/RepCodec/dataloader/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .collater import *
+from .dataset import *
prompting/RepCodec/dataloader/collater.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import numpy as np
+import torch
+
+
+class ReprCollater(object):
+    def __call__(self, batch):
+        xs = []
+        for b in batch:
+            if b is not None:
+                xs.append(b)
+
+        x_batch = np.stack(xs, axis=0)
+        x_batch = torch.tensor(x_batch, dtype=torch.float).transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+        return x_batch
prompting/RepCodec/dataloader/dataset.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import glob
+import logging
+import os
+from typing import List
+
+import numpy as np
+from torch.utils.data import Dataset
+
+logging.basicConfig(
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=os.environ.get("LOGLEVEL", "INFO").upper(),
+)
+logger = logging.getLogger("dataset")
+
+
+class ReprDataset(Dataset):
+    def __init__(
+            self,
+            data_dir: str,
+            batch_len: int,
+    ):
+        self.batch_len = batch_len
+
+        self.blocks = self._load_blocks(data_dir)
+        self.offsets = self._load_offsets(data_dir)
+        assert len(self.blocks) == len(self.offsets)
+        # check len
+        for i in range(len(self.blocks)):
+            assert self.blocks[i].shape[0] == self.offsets[i][-1]
+
+        self.n_examples = np.cumsum([0] + [offset.shape[0] - 1 for offset in self.offsets])
+
+    def __len__(self):
+        return self.n_examples[-1]
+
+    def __getitem__(self, idx):
+        # find which block
+        block_id = -1
+        for n in range(len(self.n_examples) - 1):
+            if self.n_examples[n] <= idx < self.n_examples[n + 1]:
+                block_id = n
+                break
+        assert 0 <= block_id < len(self.blocks), f"Failed to find {idx}"
+        block_offset = idx - self.n_examples[block_id]
+        start = self.offsets[block_id][block_offset]
+        end = self.offsets[block_id][block_offset + 1]
+
+        # randomly choose a slice
+        if end - start < self.batch_len:
+            return None
+        elif end - start == self.batch_len:
+            return self.blocks[block_id][start:end]
+        else:
+            start_offset = np.random.randint(low=start, high=end - self.batch_len)
+            return self.blocks[block_id][start_offset:start_offset + self.batch_len]
+
+    @staticmethod
+    def _load_blocks(feat_dir: str) -> List[np.ndarray]:
+        # e.g., 0_2.npy, 1_2.npy
+        file_names = glob.glob(os.path.join(feat_dir, "*.npy"), recursive=False)
+        # sort by index
+        file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
+        logger.info(f"Found following blocks: {file_names}")
+        blocks = [np.load(name, mmap_mode="r") for name in file_names]
+        return blocks
+
+    @staticmethod
+    def _load_offsets(feat_dir: str):
+        def load_lens(file_name: str):
+            with open(file_name, mode="r") as fp:
+                res = fp.read().strip().split("\n")
+            # for easy use. [res[i], res[i+1]) denotes the range for ith element
+            res = [0] + [int(r) for r in res]
+            return np.cumsum(res, dtype=int)
+
+        # e.g., 0_2.len, 1_2.len
+        file_names = glob.glob(os.path.join(feat_dir, "*.len"), recursive=False)
+        file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
+        file_lens = []
+        for name in file_names:
+            file_lens.append(load_lens(name))
+        return file_lens
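For reference, the two classes above are meant to be used together: `ReprDataset` returns fixed-length `(batch_len, dim)` slices (or `None` for clips shorter than `batch_len`), and `ReprCollater` drops the `None`s and stacks the rest into a `(B, C, T)` float batch. A minimal sketch of wiring them into a PyTorch `DataLoader`, assuming the RepCodec root is on `PYTHONPATH`; the `batch_len` and paths are arbitrary examples:

```python
from torch.utils.data import DataLoader

from dataloader import ReprCollater, ReprDataset

dataset = ReprDataset("/dir/to/representations/train_set_name", batch_len=96)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=ReprCollater())

for x_batch in loader:
    # (kept_examples, hidden_dim, 96); examples shorter than batch_len are dropped
    print(x_batch.shape)
    break
```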
prompting/RepCodec/examples/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,334 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "72bf1b45-66fd-450d-8d5c-bec9e0b3d08f",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from data2vec_feature_reader import Data2vecFeatureReader\n",
|
11 |
+
"\n",
|
12 |
+
"reader = Data2vecFeatureReader(\"./../../models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 2,
|
18 |
+
"id": "84a9d238-048a-4772-a47b-5aadc50f36df",
|
19 |
+
"metadata": {},
|
20 |
+
"outputs": [
|
21 |
+
{
|
22 |
+
"data": {
|
23 |
+
"application/vnd.jupyter.widget-view+json": {
|
24 |
+
"model_id": "490421d1c2f54cca9855f1a5397185f8",
|
25 |
+
"version_major": 2,
|
26 |
+
"version_minor": 0
|
27 |
+
},
|
28 |
+
"text/plain": [
|
29 |
+
"Loading dataset shards: 0%| | 0/45 [00:00<?, ?it/s]"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
"metadata": {},
|
33 |
+
"output_type": "display_data"
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"data": {
|
37 |
+
"application/vnd.jupyter.widget-view+json": {
|
38 |
+
"model_id": "be44942581b34d5388b0264e7b40d472",
|
39 |
+
"version_major": 2,
|
40 |
+
"version_minor": 0
|
41 |
+
},
|
42 |
+
"text/plain": [
|
43 |
+
"Loading dataset shards: 0%| | 0/60 [00:00<?, ?it/s]"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
"metadata": {},
|
47 |
+
"output_type": "display_data"
|
48 |
+
}
|
49 |
+
],
|
50 |
+
"source": [
|
51 |
+
"from datasets import load_dataset\n",
|
52 |
+
"from tqdm import tqdm\n",
|
53 |
+
"import pandas as pd\n",
|
54 |
+
"\n",
|
55 |
+
"cache_dir = \"./../../../cache\"\n",
|
56 |
+
"\n",
|
57 |
+
"dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": 3,
|
63 |
+
"id": "cffd49ca-3524-4ac4-8ba5-bc4fcc9e0f53",
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [
|
66 |
+
{
|
67 |
+
"data": {
|
68 |
+
"text/plain": [
|
69 |
+
"RepCodec(\n",
|
70 |
+
" (encoder): Encoder(\n",
|
71 |
+
" (conv): Conv1d(\n",
|
72 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
73 |
+
" )\n",
|
74 |
+
" (conv_blocks): ModuleList(\n",
|
75 |
+
" (0-1): 2 x EncoderBlock(\n",
|
76 |
+
" (res_units): ModuleList(\n",
|
77 |
+
" (0-1): 2 x ResidualUnit(\n",
|
78 |
+
" (activation): ELU(alpha=1.0)\n",
|
79 |
+
" (conv1): Conv1d(\n",
|
80 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
81 |
+
" )\n",
|
82 |
+
" (conv2): Conv1d1x1(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n",
|
83 |
+
" )\n",
|
84 |
+
" )\n",
|
85 |
+
" (conv): Conv1d(\n",
|
86 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
|
87 |
+
" )\n",
|
88 |
+
" )\n",
|
89 |
+
" )\n",
|
90 |
+
" )\n",
|
91 |
+
" (decoder): Decoder(\n",
|
92 |
+
" (conv1): Conv1d(\n",
|
93 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
94 |
+
" )\n",
|
95 |
+
" (conv_blocks): ModuleList(\n",
|
96 |
+
" (0-1): 2 x DecoderBlock(\n",
|
97 |
+
" (conv): Conv1d(\n",
|
98 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))\n",
|
99 |
+
" )\n",
|
100 |
+
" (res_units): ModuleList(\n",
|
101 |
+
" (0-1): 2 x ResidualUnit(\n",
|
102 |
+
" (activation): ELU(alpha=1.0)\n",
|
103 |
+
" (conv1): Conv1d(\n",
|
104 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
105 |
+
" )\n",
|
106 |
+
" (conv2): Conv1d1x1(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n",
|
107 |
+
" )\n",
|
108 |
+
" )\n",
|
109 |
+
" )\n",
|
110 |
+
" )\n",
|
111 |
+
" (conv2): Conv1d(\n",
|
112 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
113 |
+
" )\n",
|
114 |
+
" )\n",
|
115 |
+
" (projector): Projector(\n",
|
116 |
+
" (project): Conv1d(\n",
|
117 |
+
" (conv): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
|
118 |
+
" )\n",
|
119 |
+
" )\n",
|
120 |
+
" (quantizer): Quantizer(\n",
|
121 |
+
" (codebook): ResidualVQ(\n",
|
122 |
+
" (layers): ModuleList(\n",
|
123 |
+
" (0): VectorQuantize()\n",
|
124 |
+
" )\n",
|
125 |
+
" )\n",
|
126 |
+
" )\n",
|
127 |
+
")"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
"execution_count": 3,
|
131 |
+
"metadata": {},
|
132 |
+
"output_type": "execute_result"
|
133 |
+
}
|
134 |
+
],
|
135 |
+
"source": [
|
136 |
+
"from repcodec.RepCodec import RepCodec\n",
|
137 |
+
"import torch\n",
|
138 |
+
"import yaml\n",
|
139 |
+
"\n",
|
140 |
+
"config = \"./../repcodec/configs/repcodec_dim1024.yaml\"\n",
|
141 |
+
"with open(config) as fp:\n",
|
142 |
+
" conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
|
143 |
+
"\n",
|
144 |
+
"model = RepCodec(**conf)\n",
|
145 |
+
"model.load_state_dict(torch.load(\"./../../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
|
146 |
+
"model.quantizer.initial()\n",
|
147 |
+
"model.eval()"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": 22,
|
153 |
+
"id": "a9a1731e-052c-4af0-a29c-b171a988b300",
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [
|
156 |
+
{
|
157 |
+
"ename": "RuntimeError",
|
158 |
+
"evalue": "Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same",
|
159 |
+
"output_type": "error",
|
160 |
+
"traceback": [
|
161 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
162 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
163 |
+
"Cell \u001b[0;32mIn[22], line 27\u001b[0m\n\u001b[1;32m 23\u001b[0m feat\u001b[38;5;241m.\u001b[39mappend(feat_chunk)\n\u001b[1;32m 25\u001b[0m features \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(feat, \u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 27\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m z \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mprojector(x)\n\u001b[1;32m 29\u001b[0m _, idx \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mquantizer\u001b[38;5;241m.\u001b[39mcodebook\u001b[38;5;241m.\u001b[39mforward_index(z\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m))\n",
|
164 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
165 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
166 |
+
"File \u001b[0;32m/jupyter_workspace/users/Darshan/RepCodec/repcodec/modules/encoder.py:86\u001b[0m, in \u001b[0;36mEncoder.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 86\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_blocks):\n\u001b[1;32m 88\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv_blocks[i](x)\n",
|
167 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
168 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
169 |
+
"File \u001b[0;32m/jupyter_workspace/users/Darshan/RepCodec/repcodec/layers/conv_layer.py:55\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 49\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124;03m x (Tensor): Float tensor variable with the shape (B, C, T).\u001b[39;00m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;124;03m Returns:\u001b[39;00m\n\u001b[1;32m 53\u001b[0m \u001b[38;5;124;03m Tensor: Float tensor variable with the shape (B, C, T).\u001b[39;00m\n\u001b[1;32m 54\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 55\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
|
170 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
171 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
172 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:310\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
173 |
+
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py:306\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 304\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 305\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 307\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n",
|
174 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same"
|
175 |
+
]
|
176 |
+
}
|
177 |
+
],
|
178 |
+
"source": [
|
179 |
+
"import torch.nn.functional as F\n",
|
180 |
+
"\n",
|
181 |
+
"sample = dataset[\"train.clean.100\"][1]\n",
|
182 |
+
"\n",
|
183 |
+
"x = sample[\"audio\"][\"array\"]\n",
|
184 |
+
"\n",
|
185 |
+
"with torch.no_grad():\n",
|
186 |
+
" x = torch.from_numpy(x).float().to(reader.device)\n",
|
187 |
+
" if reader.task.cfg.normalize:\n",
|
188 |
+
" x = F.layer_norm(x, x.shape)\n",
|
189 |
+
" x = x.view(1, -1)\n",
|
190 |
+
"\n",
|
191 |
+
" feat = []\n",
|
192 |
+
" for start in range(0, x.size(1), reader.max_chunk):\n",
|
193 |
+
" x_chunk = x[:, start: start + reader.max_chunk]\n",
|
194 |
+
" res = reader.model.extract_features(\n",
|
195 |
+
" source=x_chunk,\n",
|
196 |
+
" padding_mask=None,\n",
|
197 |
+
" mask=False,\n",
|
198 |
+
" layer=reader.layer,\n",
|
199 |
+
" )\n",
|
200 |
+
" feat_chunk = res[\"x\"]\n",
|
201 |
+
" feat.append(feat_chunk)\n",
|
202 |
+
" \n",
|
203 |
+
" features = torch.cat(feat, 1).permute(0, 2, 1)\n",
|
204 |
+
"\n",
|
205 |
+
" x = model.encoder(features)\n",
|
206 |
+
" z = model.projector(x)\n",
|
207 |
+
" _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
|
208 |
+
" tokens = idx.cpu().data.numpy().tolist()[0]"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": 14,
|
214 |
+
"id": "d51709a9-6fb3-450b-a517-005367095663",
|
215 |
+
"metadata": {},
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"data": {
|
219 |
+
"text/plain": [
|
220 |
+
"torch.Size([1, 804, 1024])"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
"execution_count": 14,
|
224 |
+
"metadata": {},
|
225 |
+
"output_type": "execute_result"
|
226 |
+
}
|
227 |
+
],
|
228 |
+
"source": [
|
229 |
+
"features.shape"
|
230 |
+
]
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"execution_count": 8,
|
235 |
+
"id": "dfc977d7-f27c-40d7-b545-fbdf26728cbe",
|
236 |
+
"metadata": {},
|
237 |
+
"outputs": [
|
238 |
+
{
|
239 |
+
"data": {
|
240 |
+
"text/plain": [
|
241 |
+
"torch.Size([726, 1024])"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
"execution_count": 8,
|
245 |
+
"metadata": {},
|
246 |
+
"output_type": "execute_result"
|
247 |
+
}
|
248 |
+
],
|
249 |
+
"source": [
|
250 |
+
"feat.shape"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": null,
|
256 |
+
"id": "1810e6dc-2ece-4aca-a29a-e1933b8ce82a",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [],
|
259 |
+
"source": [
|
260 |
+
"import logging\n",
|
261 |
+
"import os\n",
|
262 |
+
"import sys\n",
|
263 |
+
"\n",
|
264 |
+
"import tqdm\n",
|
265 |
+
"from npy_append_array import NpyAppendArray\n",
|
266 |
+
"\n",
|
267 |
+
"def get_shard_range(tot, nshard, rank):\n",
|
268 |
+
" assert rank < nshard and rank >= 0, f\"invaid rank/nshard {rank}/{nshard}\"\n",
|
269 |
+
" start = round(tot / nshard * rank)\n",
|
270 |
+
" end = round(tot / nshard * (rank + 1))\n",
|
271 |
+
" assert start < end, f\"start={start}, end={end}\"\n",
|
272 |
+
" logger.info(\n",
|
273 |
+
" f\"rank {rank} of {nshard}, process {end-start} \"\n",
|
274 |
+
" f\"({start}-{end}) out of {tot}\"\n",
|
275 |
+
" )\n",
|
276 |
+
" return start, end\n",
|
277 |
+
"\n",
|
278 |
+
"def get_path_iterator(tsv, nshard, rank):\n",
|
279 |
+
" with open(tsv, \"r\") as f:\n",
|
280 |
+
" root = f.readline().rstrip()\n",
|
281 |
+
" lines = [line.rstrip() for line in f]\n",
|
282 |
+
" start, end = get_shard_range(len(lines), nshard, rank)\n",
|
283 |
+
" lines = lines[start:end]\n",
|
284 |
+
" def iterate():\n",
|
285 |
+
" for line in lines:\n",
|
286 |
+
" subpath, nsample = line.split(\"\\t\")\n",
|
287 |
+
" yield f\"{root}/{subpath}\", int(nsample)\n",
|
288 |
+
" return iterate, len(lines)\n",
|
289 |
+
"\n",
|
290 |
+
"def dump_feature(reader, generator, num, nshard, rank, feat_dir):\n",
|
291 |
+
" iterator = generator()\n",
|
292 |
+
"\n",
|
293 |
+
" feat_path = f\"{feat_dir}/{rank}_{nshard}.npy\"\n",
|
294 |
+
" leng_path = f\"{feat_dir}/{rank}_{nshard}.len\"\n",
|
295 |
+
"\n",
|
296 |
+
" os.makedirs(feat_dir, exist_ok=True)\n",
|
297 |
+
" if os.path.exists(feat_path):\n",
|
298 |
+
" os.remove(feat_path)\n",
|
299 |
+
"\n",
|
300 |
+
" feat_f = NpyAppendArray(feat_path)\n",
|
301 |
+
" with open(leng_path, \"w\") as leng_f:\n",
|
302 |
+
" for path, nsample in tqdm.tqdm(iterator, total=num):\n",
|
303 |
+
" feat = reader.get_feats(path, nsample)\n",
|
304 |
+
" feat_f.append(feat.cpu().numpy())\n",
|
305 |
+
" leng_f.write(f\"{len(feat)}\\n\")\n",
|
306 |
+
" logger.info(\"finished successfully\")\n",
|
307 |
+
"\n",
|
308 |
+
"generator, num = get_path_iterator(tsv_path, nshard, rank)\n",
|
309 |
+
"dump_feature(reader, generator, num, nshard, rank, feat_dir)"
|
310 |
+
]
|
311 |
+
}
|
312 |
+
],
|
313 |
+
"metadata": {
|
314 |
+
"kernelspec": {
|
315 |
+
"display_name": "Python 3 (ipykernel)",
|
316 |
+
"language": "python",
|
317 |
+
"name": "python3"
|
318 |
+
},
|
319 |
+
"language_info": {
|
320 |
+
"codemirror_mode": {
|
321 |
+
"name": "ipython",
|
322 |
+
"version": 3
|
323 |
+
},
|
324 |
+
"file_extension": ".py",
|
325 |
+
"mimetype": "text/x-python",
|
326 |
+
"name": "python",
|
327 |
+
"nbconvert_exporter": "python",
|
328 |
+
"pygments_lexer": "ipython3",
|
329 |
+
"version": "3.8.10"
|
330 |
+
}
|
331 |
+
},
|
332 |
+
"nbformat": 4,
|
333 |
+
"nbformat_minor": 5
|
334 |
+
}
|
prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_audio-checkpoint.py
ADDED
@@ -0,0 +1,541 @@
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on fairseq (https://github.com/facebookresearch/fairseq)
|
7 |
+
|
8 |
+
# ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import math
|
12 |
+
from dataclasses import dataclass, field
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
+
from omegaconf import II
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torch.distributed as dist
|
21 |
+
|
22 |
+
from fairseq.modules import EMAModule, EMAModuleConfig
|
23 |
+
from fairseq.data.data_utils import compute_mask_indices
|
24 |
+
from fairseq.models import BaseFairseqModel, register_model
|
25 |
+
from fairseq.models.wav2vec import (
|
26 |
+
ConvFeatureExtractionModel,
|
27 |
+
Wav2Vec2Config,
|
28 |
+
TransformerEncoder,
|
29 |
+
)
|
30 |
+
from fairseq.modules import (
|
31 |
+
GradMultiply,
|
32 |
+
LayerNorm,
|
33 |
+
)
|
34 |
+
from fairseq.utils import index_put
|
35 |
+
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class Data2VecAudioConfig(Wav2Vec2Config):
|
42 |
+
|
43 |
+
loss_beta: float = field(
|
44 |
+
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
|
45 |
+
)
|
46 |
+
loss_scale: Optional[float] = field(
|
47 |
+
default=None,
|
48 |
+
metadata={
|
49 |
+
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
|
50 |
+
},
|
51 |
+
)
|
52 |
+
average_top_k_layers: int = field(
|
53 |
+
default=8, metadata={"help": "how many layers to average"}
|
54 |
+
)
|
55 |
+
|
56 |
+
layer_norm_target_layer: bool = False
|
57 |
+
instance_norm_target_layer: bool = False
|
58 |
+
instance_norm_targets: bool = False
|
59 |
+
layer_norm_targets: bool = False
|
60 |
+
batch_norm_target_layer: bool = False
|
61 |
+
group_norm_target_layer: bool = False
|
62 |
+
|
63 |
+
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
|
64 |
+
ema_end_decay: float = field(
|
65 |
+
default=0.9999, metadata={"help": "final ema decay rate"}
|
66 |
+
)
|
67 |
+
|
68 |
+
# when to finish annealing ema decay rate
|
69 |
+
ema_anneal_end_step: int = II("optimization.max_update")
|
70 |
+
|
71 |
+
ema_transformer_only: bool = field(
|
72 |
+
default=True,
|
73 |
+
metadata={"help": "whether to momentum update only the transformer"},
|
74 |
+
)
|
75 |
+
ema_layers_only: bool = field(
|
76 |
+
default=True,
|
77 |
+
metadata={"help": "whether to momentum update only the transformer layers"},
|
78 |
+
)
|
79 |
+
|
80 |
+
max_update: int = II("optimization.max_update")
|
81 |
+
|
82 |
+
min_target_var: float = field(
|
83 |
+
default=0.1, metadata={"help": "stop training if target var falls below this"}
|
84 |
+
)
|
85 |
+
min_pred_var: float = field(
|
86 |
+
default=0.01,
|
87 |
+
metadata={"help": "stop training if prediction var falls below this"},
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
def get_annealed_rate(start, end, curr_step, total_steps):
|
92 |
+
r = end - start
|
93 |
+
pct_remaining = 1 - curr_step / total_steps
|
94 |
+
return end - r * pct_remaining
|
95 |
+
|
96 |
+
|
97 |
+
@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
|
98 |
+
class Data2VecAudioModel(BaseFairseqModel):
|
99 |
+
def __init__(self, cfg: Data2VecAudioConfig):
|
100 |
+
super().__init__()
|
101 |
+
self.cfg = cfg
|
102 |
+
|
103 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
104 |
+
self.extractor_embed = feature_enc_layers[-1][0]
|
105 |
+
|
106 |
+
self.ema = None
|
107 |
+
self.embed = cfg.encoder_embed_dim
|
108 |
+
|
109 |
+
self.average_top_k_layers = cfg.average_top_k_layers
|
110 |
+
self.loss_beta = cfg.loss_beta
|
111 |
+
self.loss_scale = cfg.loss_scale
|
112 |
+
|
113 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
114 |
+
conv_layers=feature_enc_layers,
|
115 |
+
dropout=0.0,
|
116 |
+
mode=cfg.extractor_mode,
|
117 |
+
conv_bias=cfg.conv_bias,
|
118 |
+
)
|
119 |
+
|
120 |
+
self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
|
121 |
+
|
122 |
+
self.mask_prob = cfg.mask_prob
|
123 |
+
self.mask_selection = cfg.mask_selection
|
124 |
+
self.mask_other = cfg.mask_other
|
125 |
+
self.mask_length = cfg.mask_length
|
126 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
127 |
+
self.mask_min_space = cfg.mask_min_space
|
128 |
+
|
129 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
130 |
+
self.mask_channel_before = cfg.mask_channel_before
|
131 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
132 |
+
self.mask_channel_other = cfg.mask_channel_other
|
133 |
+
self.mask_channel_length = cfg.mask_channel_length
|
134 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
135 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
136 |
+
|
137 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
138 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
139 |
+
|
140 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
141 |
+
|
142 |
+
self.mask_emb = nn.Parameter(
|
143 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
144 |
+
)
|
145 |
+
|
146 |
+
self.encoder = TransformerEncoder(cfg)
|
147 |
+
self.layer_norm = LayerNorm(self.extractor_embed)
|
148 |
+
|
149 |
+
self.final_proj = nn.Linear(self.embed, self.embed)
|
150 |
+
|
151 |
+
self.num_updates = 0
|
152 |
+
|
153 |
+
def make_ema_teacher(self):
|
154 |
+
ema_config = EMAModuleConfig(
|
155 |
+
ema_decay=self.cfg.ema_decay,
|
156 |
+
ema_fp32=True,
|
157 |
+
)
|
158 |
+
skip_keys = set()
|
159 |
+
if self.cfg.ema_layers_only:
|
160 |
+
self.cfg.ema_transformer_only = True
|
161 |
+
for k, _ in self.encoder.pos_conv.named_parameters():
|
162 |
+
skip_keys.add(f"pos_conv.{k}")
|
163 |
+
|
164 |
+
self.ema = EMAModule(
|
165 |
+
self.encoder if self.cfg.ema_transformer_only else self,
|
166 |
+
ema_config,
|
167 |
+
skip_keys=skip_keys,
|
168 |
+
)
|
169 |
+
|
170 |
+
def set_num_updates(self, num_updates):
|
171 |
+
super().set_num_updates(num_updates)
|
172 |
+
|
173 |
+
if self.ema is None and self.final_proj is not None:
|
174 |
+
logger.info(f"making ema teacher")
|
175 |
+
self.make_ema_teacher()
|
176 |
+
elif self.training and self.ema is not None:
|
177 |
+
if self.cfg.ema_decay != self.cfg.ema_end_decay:
|
178 |
+
if num_updates >= self.cfg.ema_anneal_end_step:
|
179 |
+
decay = self.cfg.ema_end_decay
|
180 |
+
else:
|
181 |
+
decay = get_annealed_rate(
|
182 |
+
self.cfg.ema_decay,
|
183 |
+
self.cfg.ema_end_decay,
|
184 |
+
num_updates,
|
185 |
+
self.cfg.ema_anneal_end_step,
|
186 |
+
)
|
187 |
+
self.ema.set_decay(decay)
|
188 |
+
if self.ema.get_decay() < 1:
|
189 |
+
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
|
190 |
+
|
191 |
+
self.num_updates = num_updates
|
192 |
+
|
193 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
194 |
+
state = super().state_dict(destination, prefix, keep_vars)
|
195 |
+
|
196 |
+
if self.ema is not None:
|
197 |
+
state[prefix + "_ema"] = self.ema.fp32_params
|
198 |
+
|
199 |
+
return state
|
200 |
+
|
201 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
202 |
+
if self.ema is not None:
|
203 |
+
k = prefix + "_ema"
|
204 |
+
assert k in state_dict
|
205 |
+
self.ema.restore(state_dict[k], True)
|
206 |
+
del state_dict[k]
|
207 |
+
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
208 |
+
|
209 |
+
@classmethod
|
210 |
+
def build_model(cls, cfg: Data2VecAudioConfig, task=None):
|
211 |
+
"""Build a new model instance."""
|
212 |
+
|
213 |
+
return cls(cfg)
|
214 |
+
|
215 |
+
def apply_mask(
|
216 |
+
self,
|
217 |
+
x,
|
218 |
+
padding_mask,
|
219 |
+
mask_indices=None,
|
220 |
+
mask_channel_indices=None,
|
221 |
+
):
|
222 |
+
B, T, C = x.shape
|
223 |
+
|
224 |
+
if self.mask_channel_prob > 0 and self.mask_channel_before:
|
225 |
+
mask_channel_indices = compute_mask_indices(
|
226 |
+
(B, C),
|
227 |
+
None,
|
228 |
+
self.mask_channel_prob,
|
229 |
+
self.mask_channel_length,
|
230 |
+
self.mask_channel_selection,
|
231 |
+
self.mask_channel_other,
|
232 |
+
no_overlap=self.no_mask_channel_overlap,
|
233 |
+
min_space=self.mask_channel_min_space,
|
234 |
+
)
|
235 |
+
mask_channel_indices = (
|
236 |
+
torch.from_numpy(mask_channel_indices)
|
237 |
+
.to(x.device)
|
238 |
+
.unsqueeze(1)
|
239 |
+
.expand(-1, T, -1)
|
240 |
+
)
|
241 |
+
x[mask_channel_indices] = 0
|
242 |
+
|
243 |
+
if self.mask_prob > 0:
|
244 |
+
if mask_indices is None:
|
245 |
+
mask_indices = compute_mask_indices(
|
246 |
+
(B, T),
|
247 |
+
padding_mask,
|
248 |
+
self.mask_prob,
|
249 |
+
self.mask_length,
|
250 |
+
self.mask_selection,
|
251 |
+
self.mask_other,
|
252 |
+
min_masks=1,
|
253 |
+
no_overlap=self.no_mask_overlap,
|
254 |
+
min_space=self.mask_min_space,
|
255 |
+
require_same_masks=self.cfg.require_same_masks,
|
256 |
+
mask_dropout=self.cfg.mask_dropout,
|
257 |
+
)
|
258 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
259 |
+
x = index_put(x, mask_indices, self.mask_emb)
|
260 |
+
else:
|
261 |
+
mask_indices = None
|
262 |
+
|
263 |
+
if self.mask_channel_prob > 0 and not self.mask_channel_before:
|
264 |
+
if mask_channel_indices is None:
|
265 |
+
mask_channel_indices = compute_mask_indices(
|
266 |
+
(B, C),
|
267 |
+
None,
|
268 |
+
self.mask_channel_prob,
|
269 |
+
self.mask_channel_length,
|
270 |
+
self.mask_channel_selection,
|
271 |
+
self.mask_channel_other,
|
272 |
+
no_overlap=self.no_mask_channel_overlap,
|
273 |
+
min_space=self.mask_channel_min_space,
|
274 |
+
)
|
275 |
+
mask_channel_indices = (
|
276 |
+
torch.from_numpy(mask_channel_indices)
|
277 |
+
.to(x.device)
|
278 |
+
.unsqueeze(1)
|
279 |
+
.expand(-1, T, -1)
|
280 |
+
)
|
281 |
+
x = index_put(x, mask_channel_indices, 0)
|
282 |
+
|
283 |
+
return x, mask_indices
|
284 |
+
|
285 |
+
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
286 |
+
"""
|
287 |
+
Computes the output length of the convolutional layers
|
288 |
+
"""
|
289 |
+
|
290 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
291 |
+
return torch.floor((input_length - kernel_size) / stride + 1)
|
292 |
+
|
293 |
+
conv_cfg_list = eval(self.cfg.conv_feature_layers)
|
294 |
+
|
295 |
+
for i in range(len(conv_cfg_list)):
|
296 |
+
input_lengths = _conv_out_length(
|
297 |
+
input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
|
298 |
+
)
|
299 |
+
|
300 |
+
return input_lengths.to(torch.long)
|
301 |
+
|
302 |
+
def forward(
|
303 |
+
self,
|
304 |
+
source,
|
305 |
+
padding_mask=None,
|
306 |
+
mask=True,
|
307 |
+
features_only=False,
|
308 |
+
layer=None,
|
309 |
+
mask_indices=None,
|
310 |
+
mask_channel_indices=None,
|
311 |
+
padding_count=None,
|
312 |
+
):
|
313 |
+
features = source
|
314 |
+
|
315 |
+
if self.feature_grad_mult > 0:
|
316 |
+
features = self.feature_extractor(features)
|
317 |
+
if self.feature_grad_mult != 1.0:
|
318 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
319 |
+
else:
|
320 |
+
with torch.no_grad():
|
321 |
+
features = self.feature_extractor(features)
|
322 |
+
|
323 |
+
features = features.transpose(1, 2)
|
324 |
+
|
325 |
+
features = self.layer_norm(features)
|
326 |
+
|
327 |
+
orig_padding_mask = padding_mask
|
328 |
+
|
329 |
+
if padding_mask is not None and padding_mask.any():
|
330 |
+
input_lengths = (1 - padding_mask.long()).sum(-1)
|
331 |
+
# apply conv formula to get real output_lengths
|
332 |
+
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
|
333 |
+
|
334 |
+
padding_mask = torch.zeros(
|
335 |
+
features.shape[:2], dtype=features.dtype, device=features.device
|
336 |
+
)
|
337 |
+
|
338 |
+
# these two operations makes sure that all values
|
339 |
+
# before the output lengths indices are attended to
|
340 |
+
padding_mask[
|
341 |
+
(
|
342 |
+
torch.arange(padding_mask.shape[0], device=padding_mask.device),
|
343 |
+
output_lengths - 1,
|
344 |
+
)
|
345 |
+
] = 1
|
346 |
+
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
|
347 |
+
else:
|
348 |
+
padding_mask = None
|
349 |
+
|
350 |
+
if self.post_extract_proj is not None:
|
351 |
+
features = self.post_extract_proj(features)
|
352 |
+
|
353 |
+
pre_encoder_features = None
|
354 |
+
if self.cfg.ema_transformer_only:
|
355 |
+
pre_encoder_features = features.clone()
|
356 |
+
|
357 |
+
features = self.dropout_input(features)
|
358 |
+
|
359 |
+
if mask:
|
360 |
+
x, mask_indices = self.apply_mask(
|
361 |
+
features,
|
362 |
+
padding_mask,
|
363 |
+
mask_indices=mask_indices,
|
364 |
+
mask_channel_indices=mask_channel_indices,
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
x = features
|
368 |
+
mask_indices = None
|
369 |
+
|
370 |
+
x, layer_results = self.encoder(
|
371 |
+
x,
|
372 |
+
padding_mask=padding_mask,
|
373 |
+
layer=layer,
|
374 |
+
)
|
375 |
+
|
376 |
+
if features_only:
|
377 |
+
return {
|
378 |
+
"x": x,
|
379 |
+
"padding_mask": padding_mask,
|
380 |
+
"layer_results": layer_results,
|
381 |
+
}
|
382 |
+
|
383 |
+
result = {
|
384 |
+
"losses": {},
|
385 |
+
}
|
386 |
+
|
387 |
+
with torch.no_grad():
|
388 |
+
self.ema.model.eval()
|
389 |
+
|
390 |
+
if self.cfg.ema_transformer_only:
|
391 |
+
y, layer_results = self.ema.model.extract_features(
|
392 |
+
pre_encoder_features,
|
393 |
+
padding_mask=padding_mask,
|
394 |
+
min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
|
395 |
+
)
|
396 |
+
y = {
|
397 |
+
"x": y,
|
398 |
+
"padding_mask": padding_mask,
|
399 |
+
"layer_results": layer_results,
|
400 |
+
}
|
401 |
+
else:
|
402 |
+
y = self.ema.model.extract_features(
|
403 |
+
source=source,
|
404 |
+
padding_mask=orig_padding_mask,
|
405 |
+
mask=False,
|
406 |
+
)
|
407 |
+
|
408 |
+
target_layer_results = [l[2] for l in y["layer_results"]]
|
409 |
+
|
410 |
+
permuted = False
|
411 |
+
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
412 |
+
target_layer_results = [
|
413 |
+
tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
|
414 |
+
]
|
415 |
+
permuted = True
|
416 |
+
|
417 |
+
if self.cfg.batch_norm_target_layer:
|
418 |
+
target_layer_results = [
|
419 |
+
F.batch_norm(
|
420 |
+
tl.float(), running_mean=None, running_var=None, training=True
|
421 |
+
)
|
422 |
+
for tl in target_layer_results
|
423 |
+
]
|
424 |
+
|
425 |
+
if self.cfg.instance_norm_target_layer:
|
426 |
+
target_layer_results = [
|
427 |
+
F.instance_norm(tl.float()) for tl in target_layer_results
|
428 |
+
]
|
429 |
+
|
430 |
+
if permuted:
|
431 |
+
target_layer_results = [
|
432 |
+
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
433 |
+
]
|
434 |
+
|
435 |
+
if self.cfg.group_norm_target_layer:
|
436 |
+
target_layer_results = [
|
437 |
+
F.layer_norm(tl.float(), tl.shape[-2:])
|
438 |
+
for tl in target_layer_results
|
439 |
+
]
|
440 |
+
|
441 |
+
if self.cfg.layer_norm_target_layer:
|
442 |
+
target_layer_results = [
|
443 |
+
F.layer_norm(tl.float(), tl.shape[-1:])
|
444 |
+
for tl in target_layer_results
|
445 |
+
]
|
446 |
+
|
447 |
+
y = sum(target_layer_results) / len(target_layer_results)
|
448 |
+
|
449 |
+
if self.cfg.layer_norm_targets:
|
450 |
+
y = F.layer_norm(y.float(), y.shape[-1:])
|
451 |
+
|
452 |
+
if self.cfg.instance_norm_targets:
|
453 |
+
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
|
454 |
+
|
455 |
+
if not permuted:
|
456 |
+
y = y.transpose(0, 1)
|
457 |
+
|
458 |
+
y = y[mask_indices]
|
459 |
+
|
460 |
+
x = x[mask_indices]
|
461 |
+
x = self.final_proj(x)
|
462 |
+
|
463 |
+
sz = x.size(-1)
|
464 |
+
|
465 |
+
if self.loss_beta == 0:
|
466 |
+
loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
|
467 |
+
else:
|
468 |
+
loss = F.smooth_l1_loss(
|
469 |
+
x.float(), y.float(), reduction="none", beta=self.loss_beta
|
470 |
+
).sum(dim=-1)
|
471 |
+
|
472 |
+
if self.loss_scale is not None:
|
473 |
+
scale = self.loss_scale
|
474 |
+
else:
|
475 |
+
scale = 1 / math.sqrt(sz)
|
476 |
+
|
477 |
+
result["losses"]["regression"] = loss.sum() * scale
|
478 |
+
|
479 |
+
if "sample_size" not in result:
|
480 |
+
result["sample_size"] = loss.numel()
|
481 |
+
|
482 |
+
with torch.no_grad():
|
483 |
+
result["target_var"] = self.compute_var(y)
|
484 |
+
result["pred_var"] = self.compute_var(x.float())
|
485 |
+
|
486 |
+
if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
|
487 |
+
logger.error(
|
488 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
489 |
+
)
|
490 |
+
raise Exception(
|
491 |
+
f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
|
492 |
+
)
|
493 |
+
if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
|
494 |
+
logger.error(
|
495 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
496 |
+
)
|
497 |
+
raise Exception(
|
498 |
+
f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
|
499 |
+
)
|
500 |
+
|
501 |
+
if self.ema is not None:
|
502 |
+
result["ema_decay"] = self.ema.get_decay() * 1000
|
503 |
+
|
504 |
+
return result
|
505 |
+
|
506 |
+
@staticmethod
|
507 |
+
def compute_var(y):
|
508 |
+
y = y.view(-1, y.size(-1))
|
509 |
+
if dist.is_initialized():
|
510 |
+
zc = torch.tensor(y.size(0)).cuda()
|
511 |
+
zs = y.sum(dim=0)
|
512 |
+
zss = (y ** 2).sum(dim=0)
|
513 |
+
|
514 |
+
dist.all_reduce(zc)
|
515 |
+
dist.all_reduce(zs)
|
516 |
+
dist.all_reduce(zss)
|
517 |
+
|
518 |
+
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
|
519 |
+
return torch.sqrt(var + 1e-6).mean()
|
520 |
+
else:
|
521 |
+
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
522 |
+
|
523 |
+
def extract_features(
|
524 |
+
self, source, padding_mask, mask=False, layer=None
|
525 |
+
):
|
526 |
+
res = self.forward(
|
527 |
+
source,
|
528 |
+
padding_mask,
|
529 |
+
mask=mask,
|
530 |
+
features_only=True,
|
531 |
+
layer=layer,
|
532 |
+
)
|
533 |
+
return res
|
534 |
+
|
535 |
+
def remove_pretraining_modules(self, last_layer=None):
|
536 |
+
self.final_proj = None
|
537 |
+
self.ema = None
|
538 |
+
if last_layer is not None:
|
539 |
+
self.encoder.layers = nn.ModuleList(
|
540 |
+
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
|
541 |
+
)
|
prompting/RepCodec/examples/.ipynb_checkpoints/data2vec_feature_reader-checkpoint.py
ADDED
@@ -0,0 +1,87 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on fairseq (https://github.com/facebookresearch/fairseq)
+
+import logging
+
+import torch
+import torch.nn.functional as F
+from fairseq import tasks
+from fairseq.checkpoint_utils import load_checkpoint_to_cpu
+from fairseq.data.audio.audio_utils import get_features_or_waveform
+from omegaconf import OmegaConf
+
+from data2vec_audio import Data2VecAudioModel
+
+logger = logging.getLogger("dump_feature")
+
+
+class Data2vecFeatureReader(object):
+    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
+        state = load_checkpoint_to_cpu(ckpt_path)
+        cfg = state["cfg"]
+        # load task
+        task = tasks.setup_task(cfg.task, from_checkpoint=True)
+        task.load_state_dict(state["task_state"])
+        # load model config
+        if "layer_type" not in cfg.model:
+            # fix a missing key
+            model_config = {k: v for k, v in cfg.model.items()}
+            model_config["layer_type"] = "transformer"
+            model_config = OmegaConf.create(model_config)
+        else:
+            model_config = cfg.model
+
+        # fix param name in the state
+        state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
+        state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
+        del state["model"]["_ema"]
+
+        # load model
+        model = Data2VecAudioModel.build_model(model_config)
+        model.load_state_dict(
+            state["model"], strict=True, model_cfg=model_config
+        )
+
+        self.device = device
+        logger.info(f"device = {self.device}")
+
+        self.model = model.eval().to(self.device)
+        self.task = task
+        self.layer = layer - 1  # make it 1-based
+        self.max_chunk = max_chunk
+        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
+        logger.info(f" max_chunk = {self.max_chunk}")
+
+    def read_audio(self, path, ref_len=None):
+        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
+        if wav.ndim == 2:
+            wav = wav.mean(-1)
+        assert wav.ndim == 1, wav.ndim
+        if ref_len is not None and abs(ref_len - len(wav)) > 160:
+            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
+        return wav
+
+    def get_feats(self, path, ref_len=None):
+        x = self.read_audio(path, ref_len=ref_len)
+        with torch.no_grad():
+            x = torch.from_numpy(x).float().to(self.device)
+            if self.task.cfg.normalize:
+                x = F.layer_norm(x, x.shape)
+            x = x.view(1, -1)
+
+            feat = []
+            for start in range(0, x.size(1), self.max_chunk):
+                x_chunk = x[:, start: start + self.max_chunk]
+                res = self.model.extract_features(
+                    source=x_chunk,
+                    padding_mask=None,
+                    mask=False,
+                    layer=self.layer,
+                )
+                feat_chunk = res["x"]
+                feat.append(feat_chunk)
+        return torch.cat(feat, 1).squeeze(0)
prompting/RepCodec/examples/Untitled.ipynb
ADDED
@@ -0,0 +1,214 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "72bf1b45-66fd-450d-8d5c-bec9e0b3d08f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data2vec_feature_reader import Data2vecFeatureReader\n",
    "\n",
    "reader = Data2vecFeatureReader(\"./../../models/vox_pretrained.pt\", 18, device=\"cuda:0\", max_chunk=1600000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "84a9d238-048a-4772-a47b-5aadc50f36df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb01bc434d964db08fde7f9f2c90ea3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading dataset shards: 0%| | 0/45 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4adc62013644ed0b16056aa217448a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading dataset shards: 0%| | 0/60 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "cache_dir = \"./../../../cache\"\n",
    "\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cffd49ca-3524-4ac4-8ba5-bc4fcc9e0f53",
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "attempted relative import with no known parent package",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mRepCodec\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m RepCodec\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01myaml\u001b[39;00m\n",
      "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
     ]
    }
   ],
   "source": [
    "from .RepCodec import RepCodec\n",
    "import torch\n",
    "import yaml\n",
    "\n",
    "config = \"./../repcodec/configs/repcodec_dim1024.yaml\"\n",
    "with open(config) as fp:\n",
    "    conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
    "\n",
    "model = RepCodec(**conf)\n",
    "model.load_state_dict(torch.load(\"./../../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
    "model.quantizer.initial()\n",
    "model.eval()\n",
    "model.to(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9a1731e-052c-4af0-a29c-b171a988b300",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "sample = dataset[\"train.clean.100\"][1]\n",
    "\n",
    "x = sample[\"audio\"][\"array\"]\n",
    "\n",
    "with torch.no_grad():\n",
    "    x = torch.from_numpy(x).float().to(reader.device)\n",
    "    if reader.task.cfg.normalize:\n",
    "        x = F.layer_norm(x, x.shape)\n",
    "    x = x.view(1, -1)\n",
    "\n",
    "    feat = []\n",
    "    for start in range(0, x.size(1), reader.max_chunk):\n",
    "        x_chunk = x[:, start: start + reader.max_chunk]\n",
    "        res = reader.model.extract_features(\n",
    "            source=x_chunk,\n",
    "            padding_mask=None,\n",
    "            mask=False,\n",
    "            layer=reader.layer,\n",
    "        )\n",
    "        feat_chunk = res[\"x\"]\n",
    "        feat.append(feat_chunk)\n",
    "\n",
    "    features = torch.cat(feat, 1).permute(0, 2, 1)\n",
    "\n",
    "    x = model.encoder(features)\n",
    "    z = model.projector(x)\n",
    "    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
    "    tokens = idx.cpu().data.numpy().tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1810e6dc-2ece-4aca-a29a-e1933b8ce82a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import tqdm\n",
    "from npy_append_array import NpyAppendArray\n",
    "\n",
    "def get_shard_range(tot, nshard, rank):\n",
    "    assert rank < nshard and rank >= 0, f\"invaid rank/nshard {rank}/{nshard}\"\n",
    "    start = round(tot / nshard * rank)\n",
    "    end = round(tot / nshard * (rank + 1))\n",
    "    assert start < end, f\"start={start}, end={end}\"\n",
    "    logger.info(\n",
    "        f\"rank {rank} of {nshard}, process {end-start} \"\n",
    "        f\"({start}-{end}) out of {tot}\"\n",
    "    )\n",
    "    return start, end\n",
    "\n",
    "def get_path_iterator(tsv, nshard, rank):\n",
    "    with open(tsv, \"r\") as f:\n",
    "        root = f.readline().rstrip()\n",
    "        lines = [line.rstrip() for line in f]\n",
    "        start, end = get_shard_range(len(lines), nshard, rank)\n",
    "        lines = lines[start:end]\n",
    "        def iterate():\n",
    "            for line in lines:\n",
    "                subpath, nsample = line.split(\"\\t\")\n",
    "                yield f\"{root}/{subpath}\", int(nsample)\n",
    "        return iterate, len(lines)\n",
    "\n",
    "def dump_feature(reader, generator, num, nshard, rank, feat_dir):\n",
    "    iterator = generator()\n",
    "\n",
    "    feat_path = f\"{feat_dir}/{rank}_{nshard}.npy\"\n",
    "    leng_path = f\"{feat_dir}/{rank}_{nshard}.len\"\n",
    "\n",
    "    os.makedirs(feat_dir, exist_ok=True)\n",
    "    if os.path.exists(feat_path):\n",
    "        os.remove(feat_path)\n",
    "\n",
    "    feat_f = NpyAppendArray(feat_path)\n",
    "    with open(leng_path, \"w\") as leng_f:\n",
    "        for path, nsample in tqdm.tqdm(iterator, total=num):\n",
    "            feat = reader.get_feats(path, nsample)\n",
    "            feat_f.append(feat.cpu().numpy())\n",
    "            leng_f.write(f\"{len(feat)}\\n\")\n",
    "    logger.info(\"finished successfully\")\n",
    "\n",
    "generator, num = get_path_iterator(tsv_path, nshard, rank)\n",
    "dump_feature(reader, generator, num, nshard, rank, feat_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
prompting/RepCodec/examples/__pycache__/data2vec_audio.cpython-38.pyc
ADDED
Binary file (12.2 kB).
prompting/RepCodec/examples/__pycache__/data2vec_feature_reader.cpython-38.pyc
ADDED
Binary file (3 kB).
prompting/RepCodec/examples/__pycache__/feature_utils.cpython-38.pyc
ADDED
Binary file (2.23 kB).
prompting/RepCodec/examples/__pycache__/hubert_feature_reader.cpython-38.pyc
ADDED
Binary file (2.22 kB).
prompting/RepCodec/examples/__pycache__/tokenize.cpython-38.pyc
ADDED
Binary file (1.89 kB).
prompting/RepCodec/examples/data2vec_audio.py
ADDED
@@ -0,0 +1,541 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

# ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py

import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations makes sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
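For reference, compute_var above reduces to the mean per-dimension standard deviation of its input. A minimal sanity-check sketch, not part of the upload; the tensor shape is an arbitrary assumption and fairseq must be installed for the import to work:

import torch
from data2vec_audio import Data2VecAudioModel

# standard-normal features should give a value close to 1.0
feats = torch.randn(4, 50, 768)  # assumed (batch, time, dim) shape
print(Data2VecAudioModel.compute_var(feats))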
prompting/RepCodec/examples/data2vec_feature_reader.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging

import torch
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.data.audio.audio_utils import get_features_or_waveform
from omegaconf import OmegaConf

from data2vec_audio import Data2VecAudioModel

logger = logging.getLogger("dump_feature")


class Data2vecFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        state = load_checkpoint_to_cpu(ckpt_path)
        cfg = state["cfg"]
        # load task
        task = tasks.setup_task(cfg.task, from_checkpoint=True)
        task.load_state_dict(state["task_state"])
        # load model config
        if "layer_type" not in cfg.model:
            # fix a missing key
            model_config = {k: v for k, v in cfg.model.items()}
            model_config["layer_type"] = "transformer"
            model_config = OmegaConf.create(model_config)
        else:
            model_config = cfg.model

        # fix param name in the state
        state["model"]["final_proj.weight"] = state["model"].pop("final_proj.0.weight")
        state["model"]["final_proj.bias"] = state["model"].pop("final_proj.0.bias")
        del state["model"]["_ema"]

        # load model
        model = Data2VecAudioModel.build_model(model_config)
        model.load_state_dict(
            state["model"], strict=True, model_cfg=model_config
        )

        self.device = device
        logger.info(f"device = {self.device}")

        self.model = model.eval().to(self.device)
        self.task = task
        self.layer = layer - 1  # make it 1-based
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len=ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(self.device)
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                res = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    layer=self.layer,
                )
                feat_chunk = res["x"]
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
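A minimal usage sketch for the reader above; the checkpoint and audio paths are placeholders, not files shipped with this repo:

from data2vec_feature_reader import Data2vecFeatureReader

# hypothetical paths -- substitute a real data2vec checkpoint and a wav file
reader = Data2vecFeatureReader("/path/to/vox_pretrained.pt", layer=18, device="cuda:0")
feats = reader.get_feats("/path/to/utterance.wav")  # (num_frames, feature_dim) tensor
print(feats.shape)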
prompting/RepCodec/examples/dump_feature.py
ADDED
@@ -0,0 +1,142 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging
import os
import sys

from feature_utils import get_path_iterator, dump_feature

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_feature")


def main(
        model_type: str,
        tsv_path: str,
        ckpt_path: str,
        whisper_root: str,
        whisper_name: str,
        layer: int,
        nshard: int,
        rank: int,
        feat_dir: str,
        max_chunk: int,
        use_cpu: bool = False
):
    device = "cpu" if use_cpu else "cuda"

    # some checks
    if model_type in ["hubert", "data2vec"]:
        assert ckpt_path and os.path.exists(ckpt_path)
    elif model_type in ["whisper"]:
        assert whisper_name and whisper_root
    else:
        raise ValueError(f"Unsupported model type {model_type}")

    reader = None
    if model_type == "hubert":
        from hubert_feature_reader import HubertFeatureReader
        reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "data2vec":
        from data2vec_feature_reader import Data2vecFeatureReader
        reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "whisper":
        from whisper_feature_reader import WhisperFeatureReader
        reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)

    assert reader is not None

    generator, num = get_path_iterator(tsv_path, nshard, rank)
    dump_feature(reader, generator, num, nshard, rank, feat_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        required=True,
        type=str,
        choices=["data2vec", "hubert", "whisper"],
        help="the type of the speech encoder."
    )
    parser.add_argument(
        "--tsv_path",
        required=True,
        type=str,
        help="the path to the tsv file."
    )
    parser.add_argument(
        "--ckpt_path",
        required=False,
        type=str,
        default=None,
        help="path to the speech model. must provide for HuBERT and data2vec"
    )
    parser.add_argument(
        "--whisper_root",
        required=False,
        type=str,
        default=None,
        help="root dir to download/store whisper model. must provide for whisper model."
    )
    parser.add_argument(
        "--whisper_name",
        required=False,
        type=str,
        default=None,
        help="name of whisper model. e.g., large-v2. must provide for whisper model."
    )
    parser.add_argument(
        "--layer",
        required=True,
        type=int,
        help="which layer of the model. this is 1-based."
    )
    parser.add_argument(
        "--feat_dir",
        required=True,
        type=str,
        help="the output dir to save the representations."
    )
    parser.add_argument(
        "--nshard",
        required=False,
        type=int,
        default=1,
        help="total number of shards."
    )
    parser.add_argument(
        "--rank",
        required=False,
        type=int,
        default=0,
        help="shard id of this process."
    )
    parser.add_argument(
        "--max_chunk",
        type=int,
        default=1600000,
        help="max number of frames of each batch."
    )
    parser.add_argument(
        "--use_cpu",
        default=False,
        action="store_true",
        help="whether use cpu instead of gpu."
    )
    args = parser.parse_args()
    logger.info(args)

    main(**vars(args))
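Because main is importable, the dumper can also be driven without the command line; a sketch with placeholder paths (these are assumptions and must exist on your machine):

from dump_feature import main

main(
    model_type="data2vec",                   # or "hubert" / "whisper"
    tsv_path="/data/librispeech/train.tsv",  # hypothetical manifest
    ckpt_path="/models/vox_pretrained.pt",   # hypothetical checkpoint
    whisper_root=None,
    whisper_name=None,
    layer=18,
    nshard=1,
    rank=0,
    feat_dir="/data/features",
    max_chunk=1600000,
)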
prompting/RepCodec/examples/feature_utils.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

# ref: https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/simple_kmeans/feature_utils.py

import logging
import os
import sys

import tqdm
from npy_append_array import NpyAppendArray


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("feature_utils")


def get_shard_range(tot, nshard, rank):
    assert rank < nshard and rank >= 0, f"invaid rank/nshard {rank}/{nshard}"
    start = round(tot / nshard * rank)
    end = round(tot / nshard * (rank + 1))
    assert start < end, f"start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    return start, end


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        start, end = get_shard_range(len(lines), nshard, rank)
        lines = lines[start:end]
        def iterate():
            for line in lines:
                subpath, nsample = line.split("\t")
                yield f"{subpath}", int(nsample)
        return iterate, len(lines)


def dump_feature(reader, generator, num, nshard, rank, feat_dir):
    iterator = generator()

    feat_path = f"{feat_dir}/{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")
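As a quick illustration of the sharding helper above, 100 items over 4 shards split into contiguous, non-overlapping ranges:

from feature_utils import get_shard_range

# prints (0, 25), (25, 50), (50, 75), (75, 100) for ranks 0..3
for rank in range(4):
    print(get_shard_range(100, nshard=4, rank=rank))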
prompting/RepCodec/examples/hubert_feature_reader.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging

import fairseq
import torch
import torch.nn.functional as F

from fairseq.data.audio.audio_utils import get_features_or_waveform

logger = logging.getLogger("dump_feature")


class HubertFeatureReader(object):
    def __init__(self, ckpt_path: str, layer: int, device: str, max_chunk=1600000):
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])

        self.device = device
        logger.info(f"device = {self.device}")

        self.model = model[0].eval().to(self.device)
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate)
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len=ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(self.device)
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                feat_chunk, _ = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    output_layer=self.layer,
                )
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)
|
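
A rough usage sketch for this reader (placeholder checkpoint and audio paths; layer 9 is the layer commonly used for HuBERT-base unit extraction, and fairseq must be installed):

# Sketch only: paths are placeholders.
from hubert_feature_reader import HubertFeatureReader

reader = HubertFeatureReader("checkpoints/hubert_base_ls960.pt", layer=9, device="cuda:0")
feats = reader.get_feats("audio/sample.flac")   # shape: (num_frames, hidden_dim)
print(feats.shape)
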
prompting/RepCodec/examples/some_run.py
ADDED
@@ -0,0 +1,66 @@
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd

cache_dir = "./../../../cache"

dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

from repcodec.RepCodec import RepCodec
import torch
import yaml

config = "./../repcodec/configs/repcodec_dim1024.yaml"
with open(config) as fp:
    conf = yaml.load(fp, Loader=yaml.FullLoader)

model = RepCodec(**conf)
model.load_state_dict(torch.load("./../../models/data2vec_large_l18.pkl", map_location="cuda:0")["model"]["repcodec"])
model.quantizer.initial()
model.eval()
model.to("cuda:0")

from data2vec_feature_reader import Data2vecFeatureReader

reader = Data2vecFeatureReader("./../../models/vox_pretrained.pt", 18, device="cuda:0", max_chunk=1600000)

import torch.nn.functional as F
import numpy as np

for split in dataset.keys():

    tokens = []

    for idx in tqdm(range(len(dataset[split]))):

        sample = dataset[split][idx]

        x = sample["audio"]["array"]

        with torch.no_grad():
            x = torch.from_numpy(x).float().to(reader.device)
            if reader.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), reader.max_chunk):
                x_chunk = x[:, start: start + reader.max_chunk]
                res = reader.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    layer=reader.layer,
                )
                feat_chunk = res["x"]
                feat.append(feat_chunk)

            features = torch.cat(feat, 1).permute(0, 2, 1)

            x = model.encoder(features)
            z = model.projector(x)
            _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))
            tkn = idx.detach().cpu().data.numpy()[0]

        tokens.append(tkn)
    np.savez(f"./tkns/{split}.npz", *tokens)
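
The .npz archives added below are the outputs of this script: one token array per utterance, stored under numpy's default arr_0, arr_1, ... keys because np.savez is called with positional arguments. A small sketch for reading them back:

# Sketch only: inspect one of the saved token archives.
import numpy as np

archive = np.load("./tkns/validation.clean.npz")
print(len(archive.files), "utterances")
first = archive["arr_0"]                # discrete RepCodec token ids for utterance 0
print(first.shape, first[:10])
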
prompting/RepCodec/examples/tkns/test.clean.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b2ee3dbd59f84e1903c251bdce1e05a8ed6d1f30bf492ab368d4873dc0e713b
size 8415186
prompting/RepCodec/examples/tkns/test.other.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:188625227fc71069d6607920a733b76b0c025f6d75bd698abf87bbcfd6089d2c
size 8403370
prompting/RepCodec/examples/tkns/train.clean.100.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f5966bdb4154a433ffe605d4cb152fdda7349ab6456747f0e2275f154739c004
size 151819656
prompting/RepCodec/examples/tkns/train.clean.360.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ed187075b7d515f3a1162df967548bf0d0c384aa31419ac12e7e5b95616ab2c
size 549056222
prompting/RepCodec/examples/tkns/train.other.500.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:14441ffb0f6e0ca5edd1772ddf8fbc6065f8b3e70796b4244dd62a5c286467a7
size 751972686
prompting/RepCodec/examples/tkns/validation.clean.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5bb0560fa0169e387353d30f13a72eb40704a6518c0b8f72bdfbccada9a8660
size 8412602
prompting/RepCodec/examples/tkns/validation.other.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:314083f9013d459db663631f237f4ae68860cf25a3137dd7c9d42377e7efb8c1
size 8067914
prompting/RepCodec/examples/tokens/data2vec_base_l6_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/tokens/data2vec_large_l18_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/tokens/hubert_base_l9_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/tokens/hubert_large_l18_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/tokens/whisper_large_l32_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/tokens/whisper_medium_l24_dev-clean.tokens
ADDED
The diff for this file is too large to render.
prompting/RepCodec/examples/whisper_feature_reader.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

import io
import logging
import os
from typing import Optional, Union

import soundfile as sf
import torch
from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models
from whisper.audio import log_mel_spectrogram
from whisper.model import ModelDimensions

from whisper_model import Whisper_

logger = logging.getLogger("dump_feature")


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we will load a `Whisper_` model for feature extraction.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


class WhisperFeatureReader(object):
    def __init__(self, root, ckpt, layer, device):
        self.device = device
        logger.info(f"device = {self.device}")

        self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
        self.model.decoder = None  # to save some memory by deleting the decoder
        self.layer = layer  # one-based

    def read_audio(self, path, ref_len=None):
        wav, sample_rate = sf.read(path)
        assert sample_rate == 16000, sample_rate
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
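
Usage mirrors the HuBERT reader. A sketch with placeholder paths: the audio must be 16 kHz (the assert in read_audio enforces this), the layer argument is one-based, and features come out at 50 frames per second (one frame per 320 samples). "medium" with layer 24 matches the whisper_medium_l24 tokens shipped in this upload.

# Sketch only: the audio path is a placeholder; "medium" is resolved through
# whisper's model registry and cached under the given root.
from whisper_feature_reader import WhisperFeatureReader

reader = WhisperFeatureReader(root="./models", ckpt="medium", layer=24, device="cuda:0")
feats = reader.get_feats("audio/sample_16k.flac")   # shape: (num_frames, n_audio_state)
print(feats.shape)
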
prompting/RepCodec/examples/whisper_model.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq) and
# Whisper (https://github.com/openai/whisper/)

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions


class AudioEncoder_(AudioEncoder):
    def __init__(self, *args, **kwargs):
        super(AudioEncoder_, self).__init__(*args, **kwargs)

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        length_x = x.shape[1]
        if length_x > self.positional_embedding.shape[0]:
            self.register_buffer("positional_embedding", sinusoids(length_x, self.positional_embedding.shape[1]))
            self.positional_embedding = self.positional_embedding.to(x.device)
        x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        for block in self.blocks[:target_layer]:
            x = block(x)

        return x


class Whisper_(Whisper):
    def __init__(self, dims: ModelDimensions):
        super(Whisper_, self).__init__(dims)
        # replace audio encoder with our audio encoder
        self.encoder = AudioEncoder_(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        return self.encoder.extract_feature(mel, target_layer)
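
The subclass changes two things relative to stock Whisper: extract_feature stops at target_layer instead of running the full encoder, and the positional embedding is regenerated with sinusoids when an input exceeds the pretrained 1500-frame context. A quick sketch of calling it directly, assuming the load_model helper from whisper_feature_reader.py above and 30 s of silence just to show the shapes:

# Sketch only: "tiny" keeps the download small; any registered Whisper model works.
import torch
from whisper.audio import log_mel_spectrogram

from whisper_feature_reader import load_model

model = load_model("tiny", device="cpu")
mel = log_mel_spectrogram(torch.zeros(16000 * 30))    # (n_mels, 3000) for 30 s at 16 kHz
hidden = model.extract_features(mel.unsqueeze(0), target_layer=2)
print(hidden.shape)                                   # (1, 1500, n_audio_state)
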