{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997bed07-1181-4562-962a-cb8aa18e1d16",
   "metadata": {},
   "outputs": [],
   "source": [
    "from repcodec.RepCodec import RepCodec\n",
    "import torch\n",
    "import yaml\n",
    "\n",
    "config = \"repcodec/configs/repcodec_dim1024.yaml\"\n",
    "with open(config) as fp:\n",
    "    conf = yaml.load(fp, Loader=yaml.FullLoader)\n",
    "\n",
    "model = RepCodec(**conf)\n",
    "model.load_state_dict(torch.load(\"./../models/data2vec_large_l18.pkl\", map_location=\"cuda:0\")[\"model\"][\"repcodec\"])\n",
    "model.quantizer.initial()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c5516f1-3565-4080-8612-d5ce52ea2a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# input shape: (batch size, hidden dim, sequence length)\n",
    "random_features = torch.randn(size=(1, 1024, 100))\n",
    "with torch.no_grad():\n",
    "    x = model.encoder(random_features)\n",
    "    z = model.projector(x)\n",
    "    _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
    "    tokens = idx.cpu().data.numpy().tolist()[0]"
   ]
  },
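  {
   "cell_type": "markdown",
   "id": "0a1b2c3d-0000-4000-8000-0000000000a2",
   "metadata": {},
   "source": [
    "The cell above inlines the three tokenization steps (encode, project, quantize). A minimal helper that wraps them into one call (a sketch, not part of the RepCodec API; it assumes `features` has shape `(batch, hidden dim, frames)` and is on the same device as the model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-0000-4000-8000-0000000000a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_tokens(model, features):\n",
    "    # features: (batch size, hidden dim, sequence length)\n",
    "    with torch.no_grad():\n",
    "        x = model.encoder(features)\n",
    "        z = model.projector(x)\n",
    "        # forward_index returns (quantized, indices); keep only the indices\n",
    "        _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))\n",
    "    return idx.cpu().data.numpy().tolist()\n",
    "\n",
    "tokens = extract_tokens(model, random_features)[0]"
   ]
  },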
  {
   "cell_type": "markdown",
   "id": "439ecea7-f0d4-4a61-80c2-729138beee32",
   "metadata": {},
   "source": [
    "## Dump Representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efa1891-0810-4cfb-9552-764297209e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "python3 examples/dump_feature.py --model_type data2vec --tsv_path \"./files/train.clean.100.tsv\" --ckpt_path \"./../models/vox_pretrained.pt\" --layer 18 --feat_dir \"./features/train.clean.100\""
   ]
  },
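  {
   "cell_type": "markdown",
   "id": "0a1b2c3d-0000-4000-8000-0000000000a4",
   "metadata": {},
   "source": [
    "To sanity-check the dump, the features can be read back. This is a hedged sketch: it assumes `dump_feature.py` writes fairseq-style output (one `.npy` of concatenated frames per shard plus a `.len` file with per-utterance frame counts), and the shard file names are guesses; adjust them to whatever the script actually produced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-0000-4000-8000-0000000000a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Assumed fairseq-style layout: all frames concatenated in one .npy,\n",
    "# with a .len file listing how many frames belong to each utterance.\n",
    "# The shard suffix (_0_1 = rank 0 of 1 shard) is an assumption.\n",
    "feat_dir = \"./features/train.clean.100\"\n",
    "feats = np.load(f\"{feat_dir}/train.clean.100_0_1.npy\", mmap_mode=\"r\")\n",
    "with open(f\"{feat_dir}/train.clean.100_0_1.len\") as fp:\n",
    "    lengths = [int(line) for line in fp]\n",
    "\n",
    "offsets = np.cumsum([0] + lengths)\n",
    "first_utt = feats[offsets[0]:offsets[1]]  # (num_frames, 1024)\n",
    "print(first_utt.shape)"
   ]
  },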
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd1c550-0606-4217-ac65-55ae92843f19",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "cache_dir = \"./../../cache\"\n",
    "\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "# for split in dataset.keys():\n",
    "#     data = dataset[split]\n",
    "#     num_frames = []\n",
    "#     for idx in tqdm(range(len(data))):\n",
    "#         audio = data[idx][\"audio\"]\n",
    "#         num_frames.append(int(len(audio[\"array\"]) * 16000 // audio[\"sampling_rate\"]))\n",
    "        \n",
    "#     df = pd.DataFrame.from_dict({\n",
    "#         \"file_path\": list(data[\"file\"]),\n",
    "#         \"num_frames\": num_frames\n",
    "#     })\n",
    "#     df.to_csv(f\"./files/{split}.tsv\", sep=\"\\t\", index=False)"
   ]
  },
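  {
   "cell_type": "markdown",
   "id": "0a1b2c3d-0000-4000-8000-0000000000a6",
   "metadata": {},
   "source": [
    "A quick look at one generated manifest (assumes the commented block above has been run once to produce the TSV files):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-0000-4000-8000-0000000000a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "pd.read_csv(\"./files/train.clean.100.tsv\", sep=\"\\t\").head()"
   ]
  },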
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b4af1af-5726-4899-8272-dfe867cb48a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[\"train.clean.100\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae6a0ef4-8c0a-4f6e-9a81-a9c3350e1266",
   "metadata": {},
   "source": [
    "## Prepare the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1247988-5eaa-492a-a3ab-2b11505126a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset, load_dataset\n",
    "from collections import defaultdict\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import string\n",
    "\n",
    "cache_dir = \"./../../cache\"\n",
    "\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "def process(text):\n",
    "\n",
    "    # Lower case every letter\n",
    "    text = text.lower()\n",
    "\n",
    "    # Remove punctuation\n",
    "    punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
    "    translation_table = str.maketrans('', '', punctuation_to_remove)\n",
    "    text = text.translate(translation_table)\n",
    "\n",
    "    # Remove whitespaces from front and behind\n",
    "    while text[0] == ' ' or text[-1] == ' ':\n",
    "        if text[0] == ' ':\n",
    "            text = text[1:]\n",
    "        if text[-1] == ' ':\n",
    "            text = text[:-1]\n",
    "    \n",
    "    return text\n",
    "\n",
    "dataset = dataset.remove_columns([\"audio\", \"speaker_id\", \"chapter_id\"])\n",
    "\n",
    "tokenized_ds = defaultdict(lambda: [])\n",
    "\n",
    "for split in dataset.keys():\n",
    "\n",
    "    texts = []\n",
    "    tokens = []\n",
    "    tkns = np.load(f\"./examples/tkns/{split}.npz\")\n",
    "\n",
    "    for idx, key in enumerate(tqdm(tkns.files)):\n",
    "        tokens.append(list(tkns[key]))\n",
    "        texts.append(process(dataset[split][idx][\"text\"]))\n",
    "\n",
    "    tokenized_ds[split] = Dataset.from_dict({\n",
    "        \"text\": texts,\n",
    "        \"audio_tokens\": tokens\n",
    "    })"
   ]
  },
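  {
   "cell_type": "markdown",
   "id": "0a1b2c3d-0000-4000-8000-0000000000a8",
   "metadata": {},
   "source": [
    "A quick check that `process` behaves as intended: lower-casing, dropping punctuation except apostrophes, and trimming surrounding whitespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-0000-4000-8000-0000000000a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert process(\"  Hello, World!  \") == \"hello world\"\n",
    "assert process(\"Don't stop believin'\") == \"don't stop believin'\"\n",
    "print(process(\"  MR. SMITH'S HOUSE,  \"))  # -> mr smith's house"
   ]
  },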
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfc82444-3081-4138-aa06-6fb0b7cbc6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import dataset_dict, DatasetDict\n",
    "\n",
    "tds = DatasetDict(tokenized_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "006171b9-d479-4462-9642-d126f77edfc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tds.save_to_disk(\"librispeech_tokenized.hf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "12970376-4f6f-4926-a954-29c32043b64c",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset\n\u001b[0;32m----> 3\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./librispeech_tokenized.hf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2594\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m   2589\u001b[0m verification_mode \u001b[38;5;241m=\u001b[39m VerificationMode(\n\u001b[1;32m   2590\u001b[0m     (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mBASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS\n\u001b[1;32m   2591\u001b[0m )\n\u001b[1;32m   2593\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2594\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2595\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2596\u001b[0m \u001b[43m    \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2597\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2598\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2599\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2600\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2601\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2602\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2603\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2604\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2605\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2606\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2607\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2608\u001b[0m \u001b[43m    
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2609\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2611\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m   2612\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:2266\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m   2264\u001b[0m     download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m   2265\u001b[0m     download_config\u001b[38;5;241m.\u001b[39mstorage_options\u001b[38;5;241m.\u001b[39mupdate(storage_options)\n\u001b[0;32m-> 2266\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2267\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2268\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2269\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2270\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2272\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2273\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2274\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2275\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_require_default_config_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2276\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_require_custom_configs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2277\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2278\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m   2279\u001b[0m builder_kwargs \u001b[38;5;241m=\u001b[39m dataset_module\u001b[38;5;241m.\u001b[39mbuilder_kwargs\n",
      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1825\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m   1818\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m LocalDatasetModuleFactoryWithScript(\n\u001b[1;32m   1819\u001b[0m         combined_path,\n\u001b[1;32m   1820\u001b[0m         download_mode\u001b[38;5;241m=\u001b[39mdownload_mode,\n\u001b[1;32m   1821\u001b[0m         dynamic_modules_path\u001b[38;5;241m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m   1822\u001b[0m         trust_remote_code\u001b[38;5;241m=\u001b[39mtrust_remote_code,\n\u001b[1;32m   1823\u001b[0m     )\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m   1824\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39misdir(path):\n\u001b[0;32m-> 1825\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mLocalDatasetModuleFactoryWithoutScript\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1826\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\n\u001b[1;32m   1827\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1828\u001b[0m \u001b[38;5;66;03m# Try remotely\u001b[39;00m\n\u001b[1;32m   1829\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_relative_path(path) \u001b[38;5;129;01mand\u001b[39;00m path\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:1040\u001b[0m, in \u001b[0;36mLocalDatasetModuleFactoryWithoutScript.get_module\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1034\u001b[0m     patterns \u001b[38;5;241m=\u001b[39m get_data_patterns(base_path)\n\u001b[1;32m   1035\u001b[0m data_files \u001b[38;5;241m=\u001b[39m DataFilesDict\u001b[38;5;241m.\u001b[39mfrom_patterns(\n\u001b[1;32m   1036\u001b[0m     patterns,\n\u001b[1;32m   1037\u001b[0m     base_path\u001b[38;5;241m=\u001b[39mbase_path,\n\u001b[1;32m   1038\u001b[0m     allowed_extensions\u001b[38;5;241m=\u001b[39mALL_ALLOWED_EXTENSIONS,\n\u001b[1;32m   1039\u001b[0m )\n\u001b[0;32m-> 1040\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43minfer_module_for_data_files\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1041\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1042\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1043\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1044\u001b[0m data_files \u001b[38;5;241m=\u001b[39m data_files\u001b[38;5;241m.\u001b[39mfilter_extensions(_MODULE_TO_EXTENSIONS[module_name])\n\u001b[1;32m   1045\u001b[0m \u001b[38;5;66;03m# Collect metadata files if the module supports them\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/datasets/load.py:596\u001b[0m, in \u001b[0;36minfer_module_for_data_files\u001b[0;34m(data_files, path, download_config)\u001b[0m\n\u001b[1;32m    594\u001b[0m module_name, default_builder_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(split_modules\u001b[38;5;241m.\u001b[39mvalues()))\n\u001b[1;32m    595\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m((module_name, default_builder_kwargs) \u001b[38;5;241m!=\u001b[39m split_module \u001b[38;5;28;01mfor\u001b[39;00m split_module \u001b[38;5;129;01min\u001b[39;00m split_modules\u001b[38;5;241m.\u001b[39mvalues()):\n\u001b[0;32m--> 596\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt infer the same data file format for all splits. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msplit_modules\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    597\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m module_name:\n\u001b[1;32m    598\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m DataFilesNotFoundError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo (supported) data files found\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m path \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
      "\u001b[0;31mValueError\u001b[0m: Couldn't infer the same data file format for all splits. Got {NamedSplit('train'): ('arrow', {}), NamedSplit('validation'): ('json', {}), NamedSplit('test'): ('json', {})}"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"./librispeech_tokenized.hf\")"
   ]
  },
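  {
   "cell_type": "markdown",
   "id": "0a1b2c3d-0000-4000-8000-0000000000b1",
   "metadata": {},
   "source": [
    "`load_dataset` fails on this directory because `save_to_disk` writes an Arrow dataset (Arrow shards plus JSON metadata such as `dataset_info.json`), not loose data files, so format inference sees a mix of Arrow and JSON across splits. Datasets saved with `save_to_disk` must be loaded with `load_from_disk` (or `DatasetDict.load_from_disk`), as in the next cell."
   ]
  },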
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b3ba0d58-b788-43b5-87a7-726aaa12dbbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import dataset_dict, DatasetDict, Dataset\n",
    "\n",
    "dataset = DatasetDict.load_from_disk(\"./librispeech_tokenized.hf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b7239186-73ae-407a-b9f6-b5a16f3a7ddc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "726"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset[\"train.clean.100\"][0][\"audio_tokens\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}