mshukor committed
Commit · 80b588a
1 Parent(s): 1aa8f27

clean

Changed files:
- Audio_Captioning.ipynb +0 -0
- Captioning.ipynb +0 -0
- Image_gen.ipynb +0 -301
- README_EncouragingLoss.md +0 -34
- VG.ipynb +0 -0
- VQA.ipynb +0 -0
- Video_Captioning.ipynb +0 -0
- checkpoints.md +0 -36
- checkpoints_cn.md +0 -82
- colab.md +0 -9
- datasets.md +0 -44
- evaluate.py +0 -239
- modelscope.md +0 -23
- ofa_test.ipynb +0 -2499
- prompt_tuning.md +0 -66
- spaces.md +0 -8
- test.py +0 -101
- train.py +0 -729
- trainer.py +0 -1569
- transformers.md +0 -69
Audio_Captioning.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
Captioning.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
Image_gen.ipynb
DELETED
@@ -1,301 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "399f2fcf-9241-4910-a30d-6ca19880d0ad",
"metadata": {},
"source": [
"## Import"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "97e68340-0096-475e-8ed8-22f5d627e3ad",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from fairseq import utils, tasks\n",
"from fairseq import checkpoint_utils\n",
"from utils.eval_utils import eval_step\n",
"from tasks.mm_tasks import ImageGenTask\n",
"from models.unival import UnIVALModel\n",
"from PIL import Image\n",
"from torchvision import transforms\n",
"import time\n",
"\n",
"\n",
"# turn on cuda if GPU is available\n",
"use_cuda = torch.cuda.is_available()\n",
"# use fp16 only when GPU is available\n",
"use_fp16 = True if use_cuda else False"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "719cef65-c00c-4c9c-90b2-e660b386c3d5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register caption task\n",
"tasks.register_task('image_gen', ImageGenTask)\n"
]
},
{
"cell_type": "markdown",
"id": "cc9c1d7b-898b-4ac4-adf3-832891d9e4be",
"metadata": {},
"source": [
"### Load model "
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "568bb6ea-eef9-4024-98e6-35e74b5ffeec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"self.sample_patch_num 784\n",
"self.sample_audio_patch_num None\n",
"self.sample_video_patch_num None\n",
"self.with_cls False\n",
"Frozen image bn <class 'models.ofa.frozen_bn.FrozenBatchNorm2d'>\n",
"Loading: all_resnext101\n",
"use bn: <class 'torch.nn.modules.batchnorm.BatchNorm3d'>\n",
"load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n",
"_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n",
"load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n",
"<All keys matched successfully>\n",
"RAM memory % used: 10.5\n",
"RAM Used (GB): 19.574349824\n",
"encoder\n",
"RAM memory % used: 10.5\n",
"decoder\n",
"RAM memory % used: 10.5\n",
"ofa\n",
"Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n"
]
}
],
"source": [
"# Load pretrained ckpt & config\n",
"clip_model_path='/data/mshukor/data/ofa/clip/ViT-B-16.pt'\n",
"vqgan_model_path='/data/mshukor/data/ofa/vqgan/last.ckpt'\n",
"vqgan_config_path='/data/mshukor/data/ofa/vqgan/model.yaml'\n",
"\n",
"# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
"# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt'\n",
"# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_base_best.pt'\n",
"# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_large_best.pt'\n",
"\n",
"# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
"checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'\n",
"\n",
"\n",
"\n",
"video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n",
"resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n",
"\n",
"gen_images_path='results/image_gen/'\n",
"\n",
"overrides = {\"bpe_dir\": \"utils/BPE\",\n",
" \"eval_cider\": False,\n",
" \"beam\": 24,\n",
" \"max_len_b\": 1024,\n",
" \"max_len_a\": 0,\n",
" \"min_len\": 1024,\n",
" \"sampling_topk\": 256,\n",
" \"constraint_range\": \"50265,58457\",\n",
" \"clip_model_path\": clip_model_path,\n",
" \"vqgan_model_path\": vqgan_model_path,\n",
" \"vqgan_config_path\": vqgan_config_path,\n",
" \"seed\": 42,\n",
" \"video_model_path\": video_model_path, \n",
" \"resnet_model_path\": resnet_model_path,\n",
" \"gen_images_path\":gen_images_path,\n",
" \"patch_image_size\": 256,\n",
" \"temperature\": 1.5,\n",
" }\n",
"\n",
"models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
" utils.split_paths(checkpoint_path),\n",
" arg_overrides=overrides\n",
")\n",
"\n",
"task.cfg.sampling_times = 2\n",
"# Move models to GPU\n",
"for model in models:\n",
" model.eval()\n",
" if use_fp16:\n",
" model.half()\n",
" if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
" model.cuda()\n",
" model.prepare_for_inference_(cfg)\n",
"\n",
"# Initialize generator\n",
"generator = task.build_generator(models, cfg.generation)\n",
"\n",
"# Text preprocess\n",
"bos_item = torch.LongTensor([task.src_dict.bos()])\n",
"eos_item = torch.LongTensor([task.src_dict.eos()])\n",
"pad_idx = task.src_dict.pad()"
]
},
{
"cell_type": "markdown",
"id": "5e4a45ec-bce1-495b-8033-3b574367b360",
"metadata": {},
"source": [
"### Preprocess"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9f2e7e32-c9a0-43b3-bf86-2419d9f7dfe0",
"metadata": {},
"outputs": [],
"source": [
"def encode_text(text, length=None, append_bos=False, append_eos=False):\n",
" s = task.tgt_dict.encode_line(\n",
" line=task.bpe.encode(text),\n",
" add_if_not_exist=False,\n",
" append_eos=False\n",
" ).long()\n",
" if length is not None:\n",
" s = s[:length]\n",
" if append_bos:\n",
" s = torch.cat([bos_item, s])\n",
" if append_eos:\n",
" s = torch.cat([s, eos_item])\n",
" return s\n",
"\n",
"\n",
"# Construct input for image generation task\n",
"def construct_sample(query: str):\n",
" code_mask = torch.tensor([True])\n",
" src_text = encode_text(\" what is the complete image? caption: {}\".format(query), append_bos=True,\n",
" append_eos=True).unsqueeze(0)\n",
" src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n",
" sample = {\n",
" \"id\": np.array(['42']),\n",
" \"net_input\": {\n",
" \"src_tokens\": src_text,\n",
" \"src_lengths\": src_length,\n",
" \"code_masks\": code_mask\n",
" }\n",
" }\n",
" return sample\n",
"\n",
"\n",
"# Function to turn FP32 to FP16\n",
"def apply_half(t):\n",
" if t.dtype is torch.float32:\n",
" return t.to(dtype=torch.half)\n",
" return t\n",
"\n",
"\n",
"# Function for image generation\n",
"def image_generation(caption):\n",
" sample = construct_sample(caption)\n",
" sample = utils.move_to_cuda(sample) if use_cuda else sample\n",
" sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n",
" print('|Start|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
" with torch.no_grad():\n",
" result, scores = eval_step(task, generator, models, sample)\n",
"\n",
" # return top-4 results (ranked by clip)\n",
" images = [result[i]['image'] for i in range(4)]\n",
" pic_size = 256\n",
" retImage = Image.new('RGB', (pic_size * 2, pic_size * 2))\n",
" print('|FINISHED|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
" for i in range(4):\n",
" loc = ((i % 2) * pic_size, int(i / 2) * pic_size)\n",
" retImage.paste(images[i], loc)\n",
" return retImage"
]
},
{
"cell_type": "markdown",
"id": "44dec799-c5c2-4d22-8b08-7a7ca2cdf3c9",
"metadata": {},
"source": [
"### Inference"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "02d5cd7a-8d63-4fa4-9da1-d4b79ec01445",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"|Start| 2023-06-29 12:57:39 A brown horse in the street\n",
"|FINISHED| 2023-06-29 12:59:03 A brown horse in the street\n"
]
}
],
"source": [
"query = \"A brown horse in the street\"\n",
"# query = \"Cattle grazing on grass near a lake surrounded by mountain.\"\n",
"# query = 'A street scene with a double-decker bus on the road.'\n",
"# query = 'A path.'\n",
"\n",
"\n",
"retImage = image_generation(query)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a8a1654-1f17-41c7-b410-c7491a96dcee",
"metadata": {},
"outputs": [],
"source": [
"retImage.save(f'{query}.png')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ofa",
"language": "python",
"name": "ofa"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
README_EncouragingLoss.md
DELETED
@@ -1,34 +0,0 @@
# Finetuning with Encouraging Loss (EL)
Below we provide methods for finetuning with the label smoothed encouraging loss proposed in [_Well-classified Examples are Underestimated in Classification with Deep Neural Networks_](https://arxiv.org/pdf/2110.06537.pdf) on different downstream tasks.
The implementation is in [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
You can set `--criterion` to `adjust_label_smoothed_encouraging_loss` to use it. This criterion has a hyper-parameter `--log-end`.
`--log-end < 1` results in an approximate and conservative version of the full encouraging loss.
A higher log_end more strongly counteracts gradient vanishing, enhances the modeling of the data, and increases the growth rate of the margin, but it also yields a larger gradient norm, which can challenge the existing optimization setup.
We recommend a higher log_end for cases with higher performance, and 0.75 or 0.5 as your first try.
## Image Captioning
We provide procedures for image captioning with EL below. The preprocessing is identical to the default setting.

<details>
<summary><b>Finetuning</b></summary>
<p>
We provide two scripts for stage 1.
</p>
<pre>
cd run_scripts/caption
nohup sh train_caption_stage1_el.sh > train_stage1_el.out & # stage 1, train with encouraging loss, expected cider 1.403
nohup sh train_caption_stage1_el_db.sh > train_stage1_el.out & # stage 1, train with encouraging loss, and drop best examples, expected cider 1.404
</pre>
</details>

## Referring Expression Comprehension
We provide procedures for referring expression comprehension with EL below. The preprocessing is identical to the default setting.
<details>
<summary><b>Finetuning</b></summary>
<pre>
cd run_scripts/refcoco
nohup sh train_refcoco_el.sh > train_refcoco_el.out & # finetune for refcoco
nohup sh train_refcocoplus_el.sh > train_refcocoplus_el.out & # finetune for refcoco+
nohup sh train_refcocog_el.sh > train_refcocog_el.out & # finetune for refcocog
</pre>
</details>
Evaluation is also the same as the default setting.
VG.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
VQA.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
Video_Captioning.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
checkpoints.md
DELETED
@@ -1,36 +0,0 @@
# Checkpoints

We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.

## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a> (~470M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"> Pre-trained checkpoint (OFA-Base) </a> (~180M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt"> Pre-trained checkpoint (OFA-Medium) </a> (~93M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt"> Pre-trained checkpoint (OFA-Tiny) </a> (~33M parameters)

## Finetuning (OFA-Huge)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_huge_best.pt"> Finetuned checkpoint for Caption on COCO </a>

## Finetuning (OFA-Large)

* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_stage1_best.pt"> Finetuned checkpoint for Caption on COCO During Stage1 Finetuning </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_large_best.pt"> Finetuned checkpoint for VQAv2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_large_best.pt"> Finetuned checkpoint for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_large_best.zip"> Finetuned checkpoint for Text-to-Image Generation on COCO && CLIP checkpoint && VQGAN checkpoint </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/imagenet_1k_large_best.pt"> Finetuned checkpoint for ImageNet-1K </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/gigaword_large_best.pt"> Finetuned checkpoint for Gigaword </a>


## Finetuning (OFA-Base)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_base_best.pt"> Finetuned base checkpoint for Caption on COCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_base_best.pt"> Finetuned base checkpoint for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_base_best.pt"> Finetuned base checkpoint for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_base_best.pt"> Finetuned base checkpoint for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_base_best.pt"> Finetuned base checkpoint for VQAv2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_base_best.pt"> Finetuned base checkpoint for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_base_best.pt"> Finetuned base checkpoint for Text-to-Image Generation on COCO </a>
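As a minimal sketch of how a downloaded checkpoint is loaded, following the fairseq pattern used in the notebooks deleted in this commit (e.g. Image_gen.ipynb above): the local path below is a placeholder, and actual image generation needs the fuller set of overrides (CLIP/VQGAN paths, sampling settings) shown in that notebook.

```python
# Sketch: load a downloaded finetuned checkpoint with fairseq, as in the
# repository notebooks; the checkpoint path is a placeholder.
from fairseq import tasks, utils
from fairseq import checkpoint_utils
from tasks.mm_tasks import ImageGenTask

tasks.register_task('image_gen', ImageGenTask)  # register the task before loading
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    utils.split_paths('checkpoints/image_gen_base_best.pt'),  # placeholder local path
    arg_overrides={"bpe_dir": "utils/BPE"},
)
for model in models:
    model.eval()
```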
checkpoints_cn.md
DELETED
@@ -1,82 +0,0 @@
# Checkpoints (OFA-CN)

We provide checkpoints of OFA-CN, which is the Chinese version of OFA. We provide Base-size and Large-size models, including pretrained and finetuned models on image captioning and referring expression comprehension. Note that we translated the texts in the RefCOCO(-/+/g) datasets and finetuned OFA-CN on them. We plan to release the related new datasets in the near future.
<br>

## Checkpoints
Below we provide the links for downloading the Chinese OFA checkpoints.

### Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt"> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)

### Finetuning (OFA-Large)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>

### Finetuning (OFA-Base)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
<br>

## Model Card
Below we provide the basic information of the base-size and large-size OFA-CN.

<table border="1" width="100%">
<tr align="center">
<th>Model</th><th>#Params</th><th>Backbone</th><th>Hidden Size</th><th>Intermediate Size</th><th>#Heads</th><th>#Enc. Layers</th><th>#Dec. Layers</th>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>160M</td><td>ResNet101</td><td>768</td><td>3072</td><td>12</td><td>6</td><td>6</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td>443M</td><td>ResNet152</td><td>1024</td><td>4096</td><td>16</td><td>12</td><td>12</td>
</tr>
</table>
<br>

## Results
Below we provide the results of OFA-CN and the baselines for comparison.

### [MUGE Caption](https://tianchi.aliyun.com/muge)
<table border="1" width="100%">
<tr align="center">
<td>Model</td><td>BLEU@4</td><td>ROUGE-L</td><td>CIDEr-D</td>
</tr>
<tr align="center">
<td>Trm</td><td>7.33</td><td>51.51</td><td>11.00</td>
</tr>
<tr align="center">
<td>M6</td><td>16.19</td><td>55.06</td><td>30.75</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>26.23</td><td>58.95</td><td>50.70</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td><b>27.32</b></td><td><b>59.20</b></td><td><b>53.51</b></td>
</tr>
</table>

### RefCOCO-CN Series
<table border="1" width="100%">
<tr align="center">
<td>Model</td><td>RefCOCO(val/testA/testB)</td><td>RefCOCO+(val/testA/testB)</td><td>RefCOCOg(val/test-u)</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub>(random-init)</td><td>30.13/35.07/25.03</td><td>17.89/20.90/15.83</td><td>20.30/20.45</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>82.18/86.07/<b>76.68</b></td><td>69.38/77.26/60.14</td><td><b>73.57/72.53</b></td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td><b>82.84/86.54</b>/76.50</td><td><b>71.30/78.56/61.85</b></td><td>71.96/71.30</td>
</tr>
</table>
<br>
colab.md
DELETED
@@ -1,9 +0,0 @@
# Colab Notebooks

We provide Colab notebooks for different downstream tasks so you can try out OFA. See below.

* [Image Captioning in Huggingface Transformers](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
* [Generic Interface](https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m?usp=sharing#scrollTo=s9Vni6YUZOpC) (use different instructions to perform various tasks with just one model)
* [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
* [Referring Expression Comprehension](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
* [Open-Domain Visual Question Answering](https://colab.research.google.com/drive/14v6OQe_MxV_HMnsiKfnEeMR1UMqhzZNb?usp=sharing)
datasets.md
DELETED
@@ -1,44 +0,0 @@
# Datasets

We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will soon provide scripts for you to do so.

## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip"> A small subset of the pretraining data </a>

The pretraining datasets used in OFA are all publicly available. Here we provide the public links to these data; we recommend downloading the data from these links first and then processing it into a format similar to the examples we provide.
- _CC12M_: https://github.com/google-research-datasets/conceptual-12m
- _CC3M_: https://github.com/google-research-datasets/conceptual-captions
- _SBU_: https://www.cs.virginia.edu/~vicente/sbucaptions
- _COCO_: https://cocodataset.org/#home
- _VG_: https://visualgenome.org/
- _VQAv2_: https://visualqa.org/
- _GQA_: https://cs.stanford.edu/people/dorarad/gqa/about.html
- _RefCOCO_/_RefCOCO+_/_RefCOCOg_: https://github.com/lichengunc/refer
- _OpenImages_: https://storage.googleapis.com/openimages/web/index.html
- _Object365_: https://www.objects365.org/overview.html
- _YFCC100M (subset)_: https://github.com/openai/CLIP/blob/main/data/yfcc100m.md
- _ImageNet-21K_: https://image-net.org/index.php
- _Pile_: https://pile.eleuther.ai

## Vision & Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/vqa_data/vqa_data.zip"> Dataset for VQAv2 </a> (we have also provided chunked parts of the dataset files for more convenient downloading, please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/snli_ve_data/snli_ve_data.zip"> Dataset for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen.zip"> Dataset for Text-to-Image Generation </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen_origin_id.zip"> Dataset for Text-to-Image Generation (with original id) </a>

## Vision Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/imagenet_1k_data/imagenet_1k_data.zip"> Dataset for ImageNet-1K </a>

## Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/cola_data.zip"> Dataset for COLA </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mnli_data.zip"> Dataset for MNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mrpc_data.zip"> Dataset for MRPC </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qnli_data.zip"> Dataset for QNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qqp_data.zip"> Dataset for QQP </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/rte_data.zip"> Dataset for RTE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/sst2_data.zip"> Dataset for SST2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/gigaword_data/gigaword_data.zip"> Dataset for Gigaword </a>
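The preprocessed datasets are TSV files. As a short sketch of how they are read column-by-column, following the pattern used in ofa_test.ipynb later in this commit (the local path and column indices below are placeholders that depend on the dataset you downloaded):

```python
# Sketch: read selected columns from a downloaded TSV; path and columns are
# placeholders, the reading pattern mirrors ofa_test.ipynb below.
import csv
import sys
from tqdm import tqdm

csv.field_size_limit(sys.maxsize)  # some columns (e.g. base64 images) are very large

path_data = 'caption_data/caption_example.tsv'   # hypothetical local path
selected_cols = '0,1,2'                          # depends on the dataset
selected_col_ids = [int(col_id) for col_id in selected_cols.split(",")]

rows = []
with open(path_data) as f:
    for line in tqdm(csv.reader(f, delimiter='\t')):
        rows.append([line[i] for i in selected_col_ids])
```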
evaluate.py
DELETED
@@ -1,239 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import os
import sys

import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")

from utils.utils import print_trainable_params_percentage, setup_for_distributed


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig, **kwargs):
    utils.import_user_module(cfg.common)

    setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))

    reset_logging()
    # logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    overrides = eval(cfg.common_eval.model_overrides)
    # Deal with beam-search / all-candidate VQA eval
    if cfg.task._name == "vqa_gen":
        overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"

    logger.info("loading model(s) from {}".format(cfg.common_eval.path))

    # print("cfg", cfg)
    # print(kwargs)
    # cfg.model.num_frames = kwargs["num_frames"]
    # cfg.model.patch_frame_size = kwargs["patch_frame_size"]
    # print("cfg.model", cfg.model)
    # strict = getattr(kwargs, 'strict', True)
    strict = kwargs['strict']
    logger.info('load checkpoint, strict:{}'.format(strict))

    if kwargs["zero_shot"]:
        for arg_name, arg_val in overrides.items():
            cfg.task[arg_name] = arg_val
        # print("Zero-shot eval", cfg.task, cfg)

        if hasattr(cfg.task, "add_caption"):
            cfg.task.add_caption = False
        print("cfg.task", cfg.task)
        task = tasks.setup_task(cfg.task)
        # cfg.criterion.sample_patch_num = 776

        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )
        for m in models:
            m.encoder.sample_patch_num = 776
        saved_cfg.task = cfg.task
        # print("saved_cfg", saved_cfg)
    else:
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )

    # task.cfg['evaluate_cfg'] = cfg.task
    # print(task.cfg)
    kwargs['evaluate_cfg'] = cfg.task
    # print(kwargs)
    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Move models to GPU
    for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
        if kwargs['ema_eval']:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
        model.eval()
        print("use fp16", use_fp16)
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    score_sum = torch.FloatTensor([0]).cuda()
    score_cnt = torch.FloatTensor([0]).cuda()

    score_sum_list = []
    score_cnt_list = []
    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
        with torch.no_grad():
            if kwargs["zero_shot"] and kwargs['noconstraints']:
                result, scores = zero_shot_step(task, generator, models, sample)
            else:
                result, scores = eval_step(task, generator, models, sample, **kwargs)
                ### else refcoco res, score, other_scores

        # print(scores)
        scalar = False
        if isinstance(scores, list):
            if not isinstance(scores[0], list):
                try:
                    tmp = sum(scores[0])
                    scalar = False
                except:
                    scalar = True
        # print(scalar)
        # print(sum(scores[0]))
        if isinstance(scores, list) and not scalar:
            names = result[0]
            result = result[1]
            if len(score_sum_list) == 0:
                score_sum_list = [torch.FloatTensor([0]).cuda() for i in range(len(scores))]
                score_cnt_list = [torch.FloatTensor([0]).cuda() for i in range(len(scores))]

                for i in range(len(scores)):
                    score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
                    score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
            else:
                for i in range(len(scores)):
                    score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
                    score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
        else:
            score_sum += sum(scores) if scores is not None else 0
            score_cnt += len(scores) if scores is not None else 0
        results += result
        progress.log({"sentences": sample["nsentences"]})

    ### merge per metric
    if len(score_sum_list) > 0:
        print(names, len(score_sum_list))
        for i in range(len(score_sum_list)):
            print(names[i])
            merge_results(task, cfg, logger, score_cnt_list[i], score_sum_list[i], results)
    else:
        merge_results(task, cfg, logger, score_cnt, score_sum, results)


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
    parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
    parser.add_argument("--zero-shot", action='store_true')
    parser.add_argument("--strict", action='store_false')
    parser.add_argument("--noconstraints", action='store_true')
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(
        cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval,
        zero_shot=args.zero_shot, strict=args.strict, noconstraints=args.noconstraints
    )


if __name__ == "__main__":
    cli_main()
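For context, evaluate.py is driven by fairseq's generation parser plus the extra flags defined in cli_main() above. A hypothetical programmatic invocation might look like the sketch below; the data path, task name, checkpoint path, and generation settings are placeholders, not values taken from this commit (the repository's run scripts pass equivalent flags from the shell).

```python
# Hypothetical invocation sketch of evaluate.cli_main(); all paths and
# generation settings are placeholders for illustration only.
import sys
import evaluate

sys.argv = [
    "evaluate.py", "path/to/caption_test.tsv",   # positional data argument
    "--task", "caption",
    "--path", "checkpoints/caption_base_best.pt",
    "--gen-subset", "test",
    "--batch-size", "16",
    "--beam", "5",
    "--ema-eval",                                # custom flag defined above
]
evaluate.cli_main()
```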
modelscope.md
DELETED
@@ -1,23 +0,0 @@
# ModelScope

ModelScope is a new platform that provides "Model-As-A-Service", where users can use state-of-the-art models with as little effort as possible. We have released:
* The pretrained and finetuned **OFA** models
* **Chinese CLIP** (CLIP pretrained on Chinese data, which was previously released in our organization)

on the platform, including the English and Chinese ones. Feel free to check these models and use them with ModelScope, and also feel free to send us feedback to help us improve the product.

## Chinese
* Chinese CLIP \[[Base](https://www.modelscope.cn/#/models/damo/multi-modal_clip-vit-base-patch16_zh/summary) | [Large](https://www.modelscope.cn/#/models/damo/multi-modal_clip-vit-large-patch14_zh/summary)\]
* Finetuned OFA on Visual Grounding (RefCOCO) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_large_zh/summary)\]

## English
* Finetuned OFA on Image Captioning \[[Large](https://www.modelscope.cn/#/models/damo/ofa_image-caption_coco_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_image-caption_coco_distilled_en/summary)\]
* Finetuned OFA on Text-to-Image Generation \[[Large](https://www.modelscope.cn/#/models/damo/ofa_text-to-image-synthesis_coco_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_distilled_en/summary)\]
* Finetuned OFA on Visual Question Answering \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-question-answering_pretrain_large_en/summary)\]
* Finetuned OFA on Visual Grounding (RefCOCO) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_large_en/summary)\]
* Finetuned OFA on Visual Entailment \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-entailment_snli-ve_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_visual-entailment_snli-ve_distilled_v2_en/summary)\]
* Finetuned OFA on Summarization (Gigaword) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_summarization_gigaword_large_en/summary)\]
* Finetuned OFA on Natural Language Entailment (MNLI, can be used to finetune on the GLUE benchmark) \[[Large](https://modelscope.cn/#/models/damo/ofa_text-classification_mnli_large_en/summary)\]
* Finetuned OFA on Image Classification (ImageNet-1k) \[[Large](https://modelscope.cn/#/models/damo/ofa_image-classification_imagenet_large_en/summary)\]
ofa_test.ipynb
DELETED
@@ -1,2499 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json \n",
"import torch\n",
"# import clip\n",
"from PIL import Image\n",
"# import sng_parser\n",
"from tqdm import tqdm \n",
"import codecs\n",
"import numpy as np\n",
"import csv\n",
"import sys\n",
"\n",
"from io import BytesIO\n",
"import base64"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100it [00:00, 14325.30it/s]\n"
]
}
],
"source": [
"csv.field_size_limit(sys.maxsize)\n",
"\n",
"# path_data = '/data/mshukor/data/ofa/pretrain_example/vision_language_examples.tsv'\n",
"# selected_cols='0,1,2,3,4,5,6,7'\n",
"\n",
"# path_data = '/data/mshukor/data/ofa/pretrain_example/detection_examples.tsv'\n",
"# selected_cols='0,1,2'\n",
"\n",
"# path_data = '/data/mshukor/data/ofa/pretrain_example/image_examples.tsv'\n",
"# selected_cols='0,1,2'\n",
"\n",
"path_data = '/data/mshukor/data/ofa/pretrain_example/text_examples.tsv'\n",
"selected_cols='0,1'\n",
"\n",
"data_example = []\n",
"\n",
"selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
"\n",
"with open(path_data) as file:\n",
" tsv_file = csv.reader(file, delimiter='\\t')\n",
" for line in tqdm(tsv_file):\n",
"\n",
" d = [line[i] for i in selected_col_ids]\n",
"# print(d)\n",
" data_example.append(d)\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['100',\n",
" '...please depart this field clean unless you might be answering the question. do not ask questions you already know the answer to. thanks.retrieved from \" \" ad blocker interference detected! wikia is a single-to-usefulness web site that makes cash from promoting. we\\'ve a experience for viewers using ad blockers wikia shouldn\\'t be if youve made further modifications. take away the custom ad blocker (s) and the page leave timber as expected. categories : un-answered questionsadd class cancelsave per the reddit twine, flac files will be synced to an ios gadget via icloud impel, then accessed through thenew information utility , which will allow for local playback of the excessive-high quality audio files straight by the side of the device. if , it could stamp the first time that apple has offered help for the favored flac format an ios gadget.']"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"line"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['7',\n",
" 'perhaps the clearest indication of who won and lost came quickly on the heels of the event itself: the democratic post-debate message was that joe biden scored a clear win; the republican message was that joe biden was too mean to paul ryan. the former is a boast of success; the latter is an excuse for failure. in the larger context, it’s hard to overstate how much democrats needed a shot in the arm like this. the surface-level goals of any vice presidential debate is for the candidates to demonstrate a capacity to step up in the event of a crisis, while defending their ticket’s agenda and knocking their rivals’ agenda. but for biden, the overarching benefit was about the basic morale of his party with less than four weeks to go until election day: he wanted to give democratic voters something to feel good about, and he did. who the hell am i! i’m a liberal that is extreme in some ways and not in others. i support president obama and make no apologies for it. i think he has done a phenomenal job, especially when you consider that he inherited a huge mess and has faced unprecedented opposition from a lazy & desperate republican party. i’m a film producer/director/editor, adjunct professor, technician, media critic and photographer when i’m not reading left wing blogs and typing on this one. – on twitter @extremeliberal or email at liberalforreal (at) gmail.com own an important part of american history! cicely tyson narrates this award winning documentary that tells the story of african american migration from the old south to the prosperous north. winner of 5 awards including \"best film\" at the astoria international film festival, the \"paul robeson award\" at the newark black film festival and \"best film relating to the black experience\" at the xxv international black cinema berlin/germany!']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_[6]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tasks = set()\n",
"datasets = set()\n",
"\n",
"for d in data:\n",
" tasks.add(d[-1])\n",
" datasets.add(d[-2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"int(data[10][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# len(data[0][2:][0].split(' '))\n",
"# len(data[0][1])\n",
"text = data[10][1]\n",
"print(len(text.split(' ')))\n",
"print(len(text))\n",
"from nltk.tokenize import word_tokenize\n",
"len(word_tokenize(text))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nltk.tokenize.treebank import TreebankWordDetokenizer\n",
"TreebankWordDetokenizer().detokenize(word_tokenize(text))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"key = 'refcoco_train'\n",
"index = -2\n",
"for d in data:\n",
" if d[index] == key:\n",
" print(d[2:])\n",
"# break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"d[4].split(',')\n",
"str([287.0, 127.0, 340.0, 162.0])\n",
"'{:.2f},{:.2f},{:.2f},{:.2f}'.format(287.0, 127.0, 340.0, 162.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(len(data))\n",
"data[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_captions_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/all_captions.txt'\n",
"all_objects_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/object.txt'\n",
"\n",
"all_object_list = [\n",
" row.strip() for row in open(all_objects_path) if row.strip() != ''\n",
"]\n",
"all_caption_list = [\n",
" row.strip() for row in open(all_captions_path) if row.strip() != ''\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(all_object_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(all_caption_list)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_object_list[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_caption_list[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"json_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/type2ans.json'\n",
"type2ans = json.load(open(json_path,'r'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"type2ans.keys()\n",
"# type2ans['what color is the']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"181767it [00:02, 70482.72it/s]\n"
]
}
],
"source": [
"# path_data = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
"# selected_cols='0,1'\n",
"\n",
"path_data = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
"selected_cols='0,1,2'\n",
"\n",
"# path_data = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
"# selected_cols='0,1,2,3,4,5,6,7'\n",
"\n",
"# path_data = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
"# selected_cols='0,1,2'\n",
"\n",
"data = []\n",
"\n",
"selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
"\n",
"with open(path_data) as file:\n",
" tsv_file = csv.reader(file, delimiter='\\t')\n",
" for line in tqdm(tsv_file):\n",
"\n",
" d = [line[i] for i in selected_col_ids]\n",
"# print(d)\n",
" data.append(d)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# new_data = []\n",
"# for d in data:\n",
"# label_list = d[2].strip().split('&&')\n",
"# new_label_list = []\n",
"# for label in label_list:\n",
"# lab = label.strip().split(',', 5)[:4] # x0, y0, x1, y1, cat_id, cat\n",
" \n",
"# if any([\"&\" in l for l in lab]):\n",
"# lab = [remove_special(l) for l in lab]\n",
" \n",
"# print(lab)\n",
"# lab_ = lab + label.strip().split(',', 5)[4:]\n",
"# lab_ = ','.join(lab_)\n",
"# new_label_list.append(lab_)\n",
"# new_label_list = ['&&'.join(new_label_list)]\n",
"# new_data.append(d[:2]+new_label_list)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['&40.000', '155.000', '44.000', '164.000']\n"
]
}
],
"source": [
"for d in data:\n",
" label_list = d[2].strip().split('&&')\n",
" new_label_list = []\n",
" for label in label_list:\n",
" lab = label.strip().split(',', 5)[:4] # x0, y0, x1, y1, cat_id, cat\n",
" \n",
" if any([\"&\" in l for l in lab]):\n",
" print(lab)\n",
" # lab = [remove_special(l) for l in lab]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['0',\n",
" 'coco/train2014/COCO_train2014_000000057870.jpg',\n",
" '1.020,279.960,534.110,480.000,67,dining table&&90.670,271.490,262.510,480.000,62,chair&&233.290,270.450,403.610,473.810,62,chair&&367.820,264.270,506.970,480.000,62,chair&&476.760,261.030,596.490,462.740,62,chair&&263.030,174.370,417.670,299.400,64,potted plant&&539.330,290.160,640.000,469.210,62,chair&&10.790,260.030,125.120,384.070,62,chair&&560.800,413.950,639.090,479.200,67,dining table&&20.540,376.760,103.780,431.890,62,chair&&1.080,373.210,32.360,480.000,62,chair&&298.200,235.170,381.210,269.250,86,vase&&152.170,256.670,230.580,285.780,62,chair&&364.400,256.570,417.060,283.210,62,chair&&296.780,277.790,329.260,289.780,84,book&&292.800,289.310,314.210,300.650,84,book&&285.800,257.460,299.770,273.600,62,chair']"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['1.020,279.960,534.110,480.000,67,dining table',\n",
" '90.670,271.490,262.510,480.000,62,chair',\n",
" '233.290,270.450,403.610,473.810,62,chair',\n",
" '367.820,264.270,506.970,480.000,62,chair',\n",
" '476.760,261.030,596.490,462.740,62,chair',\n",
" '263.030,174.370,417.670,299.400,64,potted plant',\n",
" '539.330,290.160,640.000,469.210,62,chair',\n",
" '10.790,260.030,125.120,384.070,62,chair',\n",
|
452 |
-
" '560.800,413.950,639.090,479.200,67,dining table',\n",
|
453 |
-
" '20.540,376.760,103.780,431.890,62,chair',\n",
|
454 |
-
" '1.080,373.210,32.360,480.000,62,chair',\n",
|
455 |
-
" '298.200,235.170,381.210,269.250,86,vase',\n",
|
456 |
-
" '152.170,256.670,230.580,285.780,62,chair',\n",
|
457 |
-
" '364.400,256.570,417.060,283.210,62,chair',\n",
|
458 |
-
" '296.780,277.790,329.260,289.780,84,book',\n",
|
459 |
-
" '292.800,289.310,314.210,300.650,84,book',\n",
|
460 |
-
" '285.800,257.460,299.770,273.600,62,chair']"
|
461 |
-
]
|
462 |
-
},
|
463 |
-
"execution_count": 8,
|
464 |
-
"metadata": {},
|
465 |
-
"output_type": "execute_result"
|
466 |
-
}
|
467 |
-
],
|
468 |
-
"source": [
|
469 |
-
"label_list"
|
470 |
-
]
|
471 |
-
},
|
472 |
-
{
|
473 |
-
"cell_type": "code",
|
474 |
-
"execution_count": 12,
|
475 |
-
"metadata": {},
|
476 |
-
"outputs": [],
|
477 |
-
"source": [
|
478 |
-
"def remove_special(input_string):\n",
|
479 |
-
" final_string = \"\"\n",
|
480 |
-
" for character in input_string:\n",
|
481 |
-
" if character == \" \":\n",
|
482 |
-
" final_string = final_string + character\n",
|
483 |
-
" else:\n",
|
484 |
-
" if(character.isalnum()):\n",
|
485 |
-
" final_string = final_string + character\n",
|
486 |
-
" return final_string"
|
487 |
-
]
|
488 |
-
},
|
489 |
-
{
|
490 |
-
"cell_type": "code",
|
491 |
-
"execution_count": 60,
|
492 |
-
"metadata": {},
|
493 |
-
"outputs": [
|
494 |
-
{
|
495 |
-
"name": "stderr",
|
496 |
-
"output_type": "stream",
|
497 |
-
"text": [
|
498 |
-
"100%|█| 5593207/5593207 [00:35<00:0\n"
|
499 |
-
]
|
500 |
-
}
|
501 |
-
],
|
502 |
-
"source": [
|
503 |
-
"for d in tqdm(data):\n",
|
504 |
-
" label = d[2]\n",
|
505 |
-
" d[2] = remove_special(caption)\n"
|
506 |
-
]
|
507 |
-
},
|
508 |
-
{
|
509 |
-
"cell_type": "code",
|
510 |
-
"execution_count": 4,
|
511 |
-
"metadata": {},
|
512 |
-
"outputs": [
|
513 |
-
{
|
514 |
-
"data": {
|
515 |
-
"text/plain": [
|
516 |
-
"['0',\n",
|
517 |
-
" 'coco/train2014/COCO_train2014_000000057870.jpg',\n",
|
518 |
-
" '1.020,279.960,534.110,480.000,67,dining table&&90.670,271.490,262.510,480.000,62,chair&&233.290,270.450,403.610,473.810,62,chair&&367.820,264.270,506.970,480.000,62,chair&&476.760,261.030,596.490,462.740,62,chair&&263.030,174.370,417.670,299.400,64,potted plant&&539.330,290.160,640.000,469.210,62,chair&&10.790,260.030,125.120,384.070,62,chair&&560.800,413.950,639.090,479.200,67,dining table&&20.540,376.760,103.780,431.890,62,chair&&1.080,373.210,32.360,480.000,62,chair&&298.200,235.170,381.210,269.250,86,vase&&152.170,256.670,230.580,285.780,62,chair&&364.400,256.570,417.060,283.210,62,chair&&296.780,277.790,329.260,289.780,84,book&&292.800,289.310,314.210,300.650,84,book&&285.800,257.460,299.770,273.600,62,chair']"
|
519 |
-
]
|
520 |
-
},
|
521 |
-
"execution_count": 4,
|
522 |
-
"metadata": {},
|
523 |
-
"output_type": "execute_result"
|
524 |
-
}
|
525 |
-
],
|
526 |
-
"source": [
|
527 |
-
"data[0]"
|
528 |
-
]
|
529 |
-
},
|
530 |
-
{
|
531 |
-
"cell_type": "code",
|
532 |
-
"execution_count": 78,
|
533 |
-
"metadata": {},
|
534 |
-
"outputs": [
|
535 |
-
{
|
536 |
-
"name": "stderr",
|
537 |
-
"output_type": "stream",
|
538 |
-
"text": [
|
539 |
-
"100%|█| 181767/181767 [00:00<00:00,\n"
|
540 |
-
]
|
541 |
-
}
|
542 |
-
],
|
543 |
-
"source": [
|
544 |
-
"for d in tqdm(data):\n",
|
545 |
-
" d[2] = d[2].replace('\\\"', '')\n"
|
546 |
-
]
|
547 |
-
},
|
548 |
-
{
|
549 |
-
"cell_type": "code",
|
550 |
-
"execution_count": 49,
|
551 |
-
"metadata": {},
|
552 |
-
"outputs": [],
|
553 |
-
"source": [
|
554 |
-
"data_ = []\n",
|
555 |
-
"with open(path_data) as file:\n",
|
556 |
-
" for i in tqdm(range(6458670)):\n",
|
557 |
-
" column_l = file.readline().rstrip(\"\\n\").split(\"\\t\")\n",
|
558 |
-
" data_.append(column_l)\n",
|
559 |
-
" if len(column_l) < 2:\n",
|
560 |
-
" break"
|
561 |
-
]
|
562 |
-
},
|
563 |
-
{
|
564 |
-
"cell_type": "code",
|
565 |
-
"execution_count": 64,
|
566 |
-
"metadata": {},
|
567 |
-
"outputs": [
|
568 |
-
{
|
569 |
-
"name": "stderr",
|
570 |
-
"output_type": "stream",
|
571 |
-
"text": [
|
572 |
-
"5593207it [00:03, 1463300.52it/s]\n"
|
573 |
-
]
|
574 |
-
}
|
575 |
-
],
|
576 |
-
"source": [
|
577 |
-
"\n",
|
578 |
-
"data_example = []\n",
|
579 |
-
"fp = open('/data/mshukor/data/ofa/pretrain_ours/vision_language_mini_.tsv', \"r\")\n",
|
580 |
-
"data_example = []\n",
|
581 |
-
"for line in tqdm(fp):\n",
|
582 |
-
" data_example.append(line)"
|
583 |
-
]
|
584 |
-
},
|
585 |
-
{
|
586 |
-
"cell_type": "code",
|
587 |
-
"execution_count": 74,
|
588 |
-
"metadata": {},
|
589 |
-
"outputs": [
|
590 |
-
{
|
591 |
-
"name": "stdout",
|
592 |
-
"output_type": "stream",
|
593 |
-
"text": [
|
594 |
-
"2796604\tcc3m/train/8/2d0d96e4ecb8e2e959a3bf10d59b9d05ac114aea.jpg\tthe residential development under construction in district\t\t\t\tcc3m\tcaption\n",
|
595 |
-
"\n"
|
596 |
-
]
|
597 |
-
}
|
598 |
-
],
|
599 |
-
"source": [
|
600 |
-
"print(data_example[2796604])"
|
601 |
-
]
|
602 |
-
},
|
603 |
-
{
|
604 |
-
"cell_type": "code",
|
605 |
-
"execution_count": 73,
|
606 |
-
"metadata": {},
|
607 |
-
"outputs": [
|
608 |
-
{
|
609 |
-
"name": "stdout",
|
610 |
-
"output_type": "stream",
|
611 |
-
"text": [
|
612 |
-
"/val2014/COCO_val2014_000000329789.jpg\tA young man is eating a slice of pizza in his room\t\t\t\tcoco_karp\tcaption\n",
|
613 |
-
"\n"
|
614 |
-
]
|
615 |
-
}
|
616 |
-
],
|
617 |
-
"source": [
|
618 |
-
"data_example[2796604]\n",
|
619 |
-
"fp.seek(2796604)\n",
|
620 |
-
"for l in fp:\n",
|
621 |
-
" print(l)\n",
|
622 |
-
" break"
|
623 |
-
]
|
624 |
-
},
|
625 |
-
{
|
626 |
-
"cell_type": "code",
|
627 |
-
"execution_count": 46,
|
628 |
-
"metadata": {},
|
629 |
-
"outputs": [
|
630 |
-
{
|
631 |
-
"name": "stdout",
|
632 |
-
"output_type": "stream",
|
633 |
-
"text": [
|
634 |
-
"2317 2317\n",
|
635 |
-
"2510 2514 2510\n"
|
636 |
-
]
|
637 |
-
}
|
638 |
-
],
|
639 |
-
"source": [
|
640 |
-
"len(data_[10].rstrip(\"\\n\").split(\"\\t\")[1])# len(line)\n",
|
641 |
-
"# len(data_[10].rstrip(\"\\n\").split(\"\\t\")[1].encode('utf-8'))\n",
|
642 |
-
"print(len(data_example[10]), len(data_example[10].encode('utf-8')))\n",
|
643 |
-
"print(len(data_[10]), len(data_[10].encode('utf-8')), len(data_[10].encode('utf-8').decode(\"utf-8\")))\n",
|
644 |
-
"\n"
|
645 |
-
]
|
646 |
-
},
|
647 |
-
{
|
648 |
-
"cell_type": "code",
|
649 |
-
"execution_count": null,
|
650 |
-
"metadata": {},
|
651 |
-
"outputs": [],
|
652 |
-
"source": [
|
653 |
-
"print(data_[10].encode('utf-8'))\n",
|
654 |
-
"print(data_[10])\n",
|
655 |
-
"\n",
|
656 |
-
"print(data_example[10].encode('utf-8'))\n",
|
657 |
-
"print(data_example[10])"
|
658 |
-
]
|
659 |
-
},
|
660 |
-
{
|
661 |
-
"cell_type": "code",
|
662 |
-
"execution_count": 4,
|
663 |
-
"metadata": {},
|
664 |
-
"outputs": [
|
665 |
-
{
|
666 |
-
"name": "stderr",
|
667 |
-
"output_type": "stream",
|
668 |
-
"text": [
|
669 |
-
"6458670it [01:45, 61129.36it/s] \n"
|
670 |
-
]
|
671 |
-
}
|
672 |
-
],
|
673 |
-
"source": [
|
674 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
|
675 |
-
"\n",
|
676 |
-
"fp = open(output_path, \"r\")\n",
|
677 |
-
"data_ = []\n",
|
678 |
-
"for line in tqdm(fp):\n",
|
679 |
-
" data_.append(line)\n",
|
680 |
-
" "
|
681 |
-
]
|
682 |
-
},
|
683 |
-
{
|
684 |
-
"cell_type": "code",
|
685 |
-
"execution_count": 12,
|
686 |
-
"metadata": {},
|
687 |
-
"outputs": [
|
688 |
-
{
|
689 |
-
"name": "stderr",
|
690 |
-
"output_type": "stream",
|
691 |
-
"text": [
|
692 |
-
"6458670it [04:08, 25941.37it/s]\n"
|
693 |
-
]
|
694 |
-
}
|
695 |
-
],
|
696 |
-
"source": [
|
697 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
|
698 |
-
"\n",
|
699 |
-
"start_id = 0 \n",
|
700 |
-
"num_max_characters = 2500\n",
|
701 |
-
"\n",
|
702 |
-
"with open(output_path, 'w', newline='\\n') as f_output:\n",
|
703 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
704 |
-
"\n",
|
705 |
-
" for i, t in tqdm(enumerate(data)):\n",
|
706 |
-
" text = t[1]\n",
|
707 |
-
" item = [start_id, text]\n",
|
708 |
-
" csv_output.writerow(item)\n",
|
709 |
-
" start_id+=1"
|
710 |
-
]
|
711 |
-
},
|
712 |
-
{
|
713 |
-
"cell_type": "code",
|
714 |
-
"execution_count": 28,
|
715 |
-
"metadata": {},
|
716 |
-
"outputs": [
|
717 |
-
{
|
718 |
-
"name": "stderr",
|
719 |
-
"output_type": "stream",
|
720 |
-
"text": [
|
721 |
-
"100%|█████████████████████████████████████████████████████████| 181767/181767 [00:03<00:00, 51934.09it/s]\n"
|
722 |
-
]
|
723 |
-
}
|
724 |
-
],
|
725 |
-
"source": [
|
726 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
|
727 |
-
"\n",
|
728 |
-
"with open(output_path, 'w', newline='\\n') as f_output:\n",
|
729 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
730 |
-
"\n",
|
731 |
-
" for t in tqdm(data):\n",
|
732 |
-
" csv_output.writerow(t)"
|
733 |
-
]
|
734 |
-
},
|
735 |
-
{
|
736 |
-
"cell_type": "code",
|
737 |
-
"execution_count": 63,
|
738 |
-
"metadata": {},
|
739 |
-
"outputs": [
|
740 |
-
{
|
741 |
-
"data": {
|
742 |
-
"text/plain": [
|
743 |
-
"['5593206',\n",
|
744 |
-
" 'train2014/COCO_train2014_000000524286.jpg',\n",
|
745 |
-
" '',\n",
|
746 |
-
" 'Is that a laptop?',\n",
|
747 |
-
" '1.0|!+yes',\n",
|
748 |
-
" '',\n",
|
749 |
-
" 'vqa_train',\n",
|
750 |
-
" 'qa']"
|
751 |
-
]
|
752 |
-
},
|
753 |
-
"execution_count": 63,
|
754 |
-
"metadata": {},
|
755 |
-
"output_type": "execute_result"
|
756 |
-
}
|
757 |
-
],
|
758 |
-
"source": [
|
759 |
-
"t"
|
760 |
-
]
|
761 |
-
},
|
762 |
-
{
|
763 |
-
"cell_type": "markdown",
|
764 |
-
"metadata": {},
|
765 |
-
"source": [
|
766 |
-
"## Create data tsv"
|
767 |
-
]
|
768 |
-
},
|
769 |
-
{
|
770 |
-
"cell_type": "code",
|
771 |
-
"execution_count": null,
|
772 |
-
"metadata": {},
|
773 |
-
"outputs": [],
|
774 |
-
"source": [
|
775 |
-
"def convert_img_to_str(file_name):\n",
|
776 |
-
" img = Image.open(file_name) # path to file\n",
|
777 |
-
" img_buffer = BytesIO()\n",
|
778 |
-
" img.save(img_buffer, format=img.format)\n",
|
779 |
-
" byte_data = img_buffer.getvalue()\n",
|
780 |
-
" base64_str = base64.b64encode(byte_data) # bytes\n",
|
781 |
-
" base64_str = base64_str.decode(\"utf-8\") # str\n",
|
782 |
-
" return base64_str"
|
783 |
-
]
|
784 |
-
},
|
785 |
-
{
|
786 |
-
"cell_type": "markdown",
|
787 |
-
"metadata": {},
|
788 |
-
"source": [
|
789 |
-
"### Create VL tsv"
|
790 |
-
]
|
791 |
-
},
|
792 |
-
{
|
793 |
-
"cell_type": "markdown",
|
794 |
-
"metadata": {},
|
795 |
-
"source": [
|
796 |
-
"#### Caption"
|
797 |
-
]
|
798 |
-
},
|
799 |
-
{
|
800 |
-
"cell_type": "code",
|
801 |
-
"execution_count": null,
|
802 |
-
"metadata": {},
|
803 |
-
"outputs": [],
|
804 |
-
"source": [
|
805 |
-
"original_data_path = '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json'\n",
|
806 |
-
"original_data = json.load(open(original_data_path,'r'))"
|
807 |
-
]
|
808 |
-
},
|
809 |
-
{
|
810 |
-
"cell_type": "code",
|
811 |
-
"execution_count": null,
|
812 |
-
"metadata": {},
|
813 |
-
"outputs": [],
|
814 |
-
"source": [
|
815 |
-
"from preprocess.utils import get_tsv_data_from_jsons\n",
|
816 |
-
" \n",
|
817 |
-
"datasets = [\n",
|
818 |
-
" '/data/mshukor/data/our_albef_data/json_pretrain/coco_karp.json',\n",
|
819 |
-
" '/data/mshukor/data/our_albef_data/json_pretrain/vg_albef.json',\n",
|
820 |
-
" '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json',\n",
|
821 |
-
" '/data/mshukor/data/our_albef_data/json_pretrain/cc3m.json', \n",
|
822 |
-
" \n",
|
823 |
-
" ['/data/mshukor/data/refcoco/refcoco+/refs(unc).p', '/data/mshukor/data/refcoco/refcoco+/instances.json'],\n",
|
824 |
-
" \n",
|
825 |
-
" '/data/mshukor/data/our_albef_data/data/vqa_train.json',\n",
|
826 |
-
"]\n",
|
827 |
-
"\n",
|
828 |
-
"start_id = 0\n",
|
829 |
-
"task_types = ['caption',\n",
|
830 |
-
" 'caption',\n",
|
831 |
-
" 'caption',\n",
|
832 |
-
" 'caption',\n",
|
833 |
-
" 'visual_grounding',\n",
|
834 |
-
" 'qa',]\n",
|
835 |
-
"\n",
|
836 |
-
"tsvs = get_tsv_data_from_jsons(datasets, start_id, task_types, convert_images=False)\n"
|
837 |
-
]
|
838 |
-
},
|
839 |
-
{
|
840 |
-
"cell_type": "code",
|
841 |
-
"execution_count": null,
|
842 |
-
"metadata": {},
|
843 |
-
"outputs": [],
|
844 |
-
"source": [
|
845 |
-
"len(tsvs)\n",
|
846 |
-
"# tsvs[-10000]\n",
|
847 |
-
"tsvs[-1000000]"
|
848 |
-
]
|
849 |
-
},
|
850 |
-
{
|
851 |
-
"cell_type": "code",
|
852 |
-
"execution_count": null,
|
853 |
-
"metadata": {},
|
854 |
-
"outputs": [],
|
855 |
-
"source": [
|
856 |
-
"import csv\n",
|
857 |
-
"from io import StringIO\n",
|
858 |
-
"\n",
|
859 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
|
860 |
-
"\n",
|
861 |
-
"with open(output_path, 'w', newline='') as f_output:\n",
|
862 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
863 |
-
"\n",
|
864 |
-
" for t in tqdm(tsvs):\n",
|
865 |
-
" csv_output.writerow(t)"
|
866 |
-
]
|
867 |
-
},
|
868 |
-
{
|
869 |
-
"cell_type": "code",
|
870 |
-
"execution_count": null,
|
871 |
-
"metadata": {},
|
872 |
-
"outputs": [],
|
873 |
-
"source": [
|
874 |
-
"csv.field_size_limit(sys.maxsize)\n",
|
875 |
-
"\n",
|
876 |
-
"\n",
|
877 |
-
"out_data = []\n",
|
878 |
-
"selected_cols='0,1,2,3,4,5,6,7'\n",
|
879 |
-
"\n",
|
880 |
-
"selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
|
881 |
-
"\n",
|
882 |
-
"with open(output_path) as file:\n",
|
883 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
884 |
-
" for line in tqdm(tsv_file):\n",
|
885 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
886 |
-
" out_data.append(d)\n",
|
887 |
-
" "
|
888 |
-
]
|
889 |
-
},
|
890 |
-
{
|
891 |
-
"cell_type": "code",
|
892 |
-
"execution_count": null,
|
893 |
-
"metadata": {},
|
894 |
-
"outputs": [],
|
895 |
-
"source": [
|
896 |
-
"out_data[-1]"
|
897 |
-
]
|
898 |
-
},
|
899 |
-
{
|
900 |
-
"cell_type": "markdown",
|
901 |
-
"metadata": {},
|
902 |
-
"source": [
|
903 |
-
"#### VQA"
|
904 |
-
]
|
905 |
-
},
|
906 |
-
{
|
907 |
-
"cell_type": "code",
|
908 |
-
"execution_count": null,
|
909 |
-
"metadata": {},
|
910 |
-
"outputs": [],
|
911 |
-
"source": [
|
912 |
-
"original_data_path = '/data/mshukor/data/our_albef_data/data/vqa_train.json'\n",
|
913 |
-
"original_data = json.load(open(original_data_path,'r'))\n"
|
914 |
-
]
|
915 |
-
},
|
916 |
-
{
|
917 |
-
"cell_type": "code",
|
918 |
-
"execution_count": null,
|
919 |
-
"metadata": {},
|
920 |
-
"outputs": [],
|
921 |
-
"source": [
|
922 |
-
"original_data[100]"
|
923 |
-
]
|
924 |
-
},
|
925 |
-
{
|
926 |
-
"cell_type": "code",
|
927 |
-
"execution_count": null,
|
928 |
-
"metadata": {},
|
929 |
-
"outputs": [],
|
930 |
-
"source": [
|
931 |
-
"# 1.0|!+horizontal&&0.3|!+south&&0.3|!+straight&&0.3|!+vertical\n",
|
932 |
-
"\n",
|
933 |
-
"from preprocess.utils import get_tsv_vqa_data_from_json\n",
|
934 |
-
"\n",
|
935 |
-
"\n",
|
936 |
-
"start_id = 0\n",
|
937 |
-
"dataset_name = 'vqav2'\n",
|
938 |
-
"task_type = 'qa'\n",
|
939 |
-
"\n",
|
940 |
-
"image_root = '/data/mshukor/data/coco'\n",
|
941 |
-
"tmp = get_tsv_vqa_data_from_json(original_data, start_id, dataset_name, task_type, image_root=image_root, convert_images=False)"
|
942 |
-
]
|
943 |
-
},
|
944 |
-
{
|
945 |
-
"cell_type": "code",
|
946 |
-
"execution_count": null,
|
947 |
-
"metadata": {},
|
948 |
-
"outputs": [],
|
949 |
-
"source": [
|
950 |
-
"tmp[10]"
|
951 |
-
]
|
952 |
-
},
|
953 |
-
{
|
954 |
-
"cell_type": "code",
|
955 |
-
"execution_count": null,
|
956 |
-
"metadata": {},
|
957 |
-
"outputs": [],
|
958 |
-
"source": []
|
959 |
-
},
|
960 |
-
{
|
961 |
-
"cell_type": "markdown",
|
962 |
-
"metadata": {},
|
963 |
-
"source": [
|
964 |
-
"#### Visual Grounding "
|
965 |
-
]
|
966 |
-
},
|
967 |
-
{
|
968 |
-
"cell_type": "code",
|
969 |
-
"execution_count": null,
|
970 |
-
"metadata": {},
|
971 |
-
"outputs": [],
|
972 |
-
"source": [
|
973 |
-
"original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+_train.json'\n",
|
974 |
-
"original_data = json.load(open(original_data_path,'r'))\n",
|
975 |
-
"\n",
|
976 |
-
"original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+/dets.json'\n",
|
977 |
-
"det_file = json.load(open(original_data_path,'r'))\n",
|
978 |
-
"\n",
|
979 |
-
"original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+/cocos.json'\n",
|
980 |
-
"coco_file = json.load(open(original_data_path,'r'))"
|
981 |
-
]
|
982 |
-
},
|
983 |
-
{
|
984 |
-
"cell_type": "code",
|
985 |
-
"execution_count": null,
|
986 |
-
"metadata": {},
|
987 |
-
"outputs": [],
|
988 |
-
"source": [
|
989 |
-
"list(det_file.keys())[:10]"
|
990 |
-
]
|
991 |
-
},
|
992 |
-
{
|
993 |
-
"cell_type": "code",
|
994 |
-
"execution_count": null,
|
995 |
-
"metadata": {},
|
996 |
-
"outputs": [],
|
997 |
-
"source": [
|
998 |
-
"original_data_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
|
999 |
-
"original_data = json.load(open(original_data_path,'r'))"
|
1000 |
-
]
|
1001 |
-
},
|
1002 |
-
{
|
1003 |
-
"cell_type": "code",
|
1004 |
-
"execution_count": null,
|
1005 |
-
"metadata": {},
|
1006 |
-
"outputs": [],
|
1007 |
-
"source": [
|
1008 |
-
"import pickle\n",
|
1009 |
-
"\n",
|
1010 |
-
"ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
|
1011 |
-
"refs = pickle.load(open(ref_path, 'rb'))"
|
1012 |
-
]
|
1013 |
-
},
|
1014 |
-
{
|
1015 |
-
"cell_type": "code",
|
1016 |
-
"execution_count": null,
|
1017 |
-
"metadata": {},
|
1018 |
-
"outputs": [],
|
1019 |
-
"source": [
|
1020 |
-
"for i, ref in tqdm(enumerate(refs)):\n",
|
1021 |
-
" \n",
|
1022 |
-
" "
|
1023 |
-
]
|
1024 |
-
},
|
1025 |
-
{
|
1026 |
-
"cell_type": "code",
|
1027 |
-
"execution_count": null,
|
1028 |
-
"metadata": {},
|
1029 |
-
"outputs": [],
|
1030 |
-
"source": [
|
1031 |
-
"len(refs)"
|
1032 |
-
]
|
1033 |
-
},
|
1034 |
-
{
|
1035 |
-
"cell_type": "code",
|
1036 |
-
"execution_count": null,
|
1037 |
-
"metadata": {},
|
1038 |
-
"outputs": [],
|
1039 |
-
"source": [
|
1040 |
-
"refs[500]"
|
1041 |
-
]
|
1042 |
-
},
|
1043 |
-
{
|
1044 |
-
"cell_type": "code",
|
1045 |
-
"execution_count": null,
|
1046 |
-
"metadata": {},
|
1047 |
-
"outputs": [],
|
1048 |
-
"source": [
|
1049 |
-
"id_to_annot = {}\n",
|
1050 |
-
"for annot in original_data['annotations']:\n",
|
1051 |
-
" id_to_annot[annot['id']] = annot\n",
|
1052 |
-
" \n",
|
1053 |
-
" "
|
1054 |
-
]
|
1055 |
-
},
|
1056 |
-
{
|
1057 |
-
"cell_type": "code",
|
1058 |
-
"execution_count": null,
|
1059 |
-
"metadata": {},
|
1060 |
-
"outputs": [],
|
1061 |
-
"source": [
|
1062 |
-
"id_to_images = {}\n",
|
1063 |
-
"for annot in tqdm(original_data['images']):\n",
|
1064 |
-
" id_to_images[annot['id']] = annot"
|
1065 |
-
]
|
1066 |
-
},
|
1067 |
-
{
|
1068 |
-
"cell_type": "code",
|
1069 |
-
"execution_count": null,
|
1070 |
-
"metadata": {},
|
1071 |
-
"outputs": [],
|
1072 |
-
"source": [
|
1073 |
-
"id_to_images[576457]"
|
1074 |
-
]
|
1075 |
-
},
|
1076 |
-
{
|
1077 |
-
"cell_type": "code",
|
1078 |
-
"execution_count": null,
|
1079 |
-
"metadata": {},
|
1080 |
-
"outputs": [],
|
1081 |
-
"source": [
|
1082 |
-
"list(id_to_annot.keys())[:10]\n",
|
1083 |
-
"id_to_annot[1640859]['bbox']\n",
|
1084 |
-
"for r in tqdm(id_to_annot.values()):\n",
|
1085 |
-
" if r['bbox'][0] > 0:\n",
|
1086 |
-
" print(r['bbox'])"
|
1087 |
-
]
|
1088 |
-
},
|
1089 |
-
{
|
1090 |
-
"cell_type": "code",
|
1091 |
-
"execution_count": null,
|
1092 |
-
"metadata": {},
|
1093 |
-
"outputs": [],
|
1094 |
-
"source": []
|
1095 |
-
},
|
1096 |
-
{
|
1097 |
-
"cell_type": "code",
|
1098 |
-
"execution_count": null,
|
1099 |
-
"metadata": {},
|
1100 |
-
"outputs": [],
|
1101 |
-
"source": [
|
1102 |
-
"list(original_data.keys())[:10]"
|
1103 |
-
]
|
1104 |
-
},
|
1105 |
-
{
|
1106 |
-
"cell_type": "code",
|
1107 |
-
"execution_count": null,
|
1108 |
-
"metadata": {},
|
1109 |
-
"outputs": [],
|
1110 |
-
"source": [
|
1111 |
-
"ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
|
1112 |
-
"instances_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
|
1113 |
-
"start_id = 0\n",
|
1114 |
-
"dataset_name='refcoco_train'\n",
|
1115 |
-
"task_type='visual_grounding'\n",
|
1116 |
-
"convert_images=False\n",
|
1117 |
-
"split='train'\n",
|
1118 |
-
"\n",
|
1119 |
-
"tmp = get_tsv_from_refcoco(ref_path, instances_path, start_id, dataset_name=dataset_name, task_type=task_type, convert_images=convert_images, split=split)"
|
1120 |
-
]
|
1121 |
-
},
|
1122 |
-
{
|
1123 |
-
"cell_type": "code",
|
1124 |
-
"execution_count": null,
|
1125 |
-
"metadata": {},
|
1126 |
-
"outputs": [],
|
1127 |
-
"source": [
|
1128 |
-
"tmp[-1]"
|
1129 |
-
]
|
1130 |
-
},
|
1131 |
-
{
|
1132 |
-
"cell_type": "code",
|
1133 |
-
"execution_count": null,
|
1134 |
-
"metadata": {},
|
1135 |
-
"outputs": [],
|
1136 |
-
"source": [
|
1137 |
-
"Image.open('/data/mshukor/data/coco/train2014/COCO_train2014_000000000072.jpg').convert('RGB')"
|
1138 |
-
]
|
1139 |
-
},
|
1140 |
-
{
|
1141 |
-
"cell_type": "code",
|
1142 |
-
"execution_count": null,
|
1143 |
-
"metadata": {},
|
1144 |
-
"outputs": [],
|
1145 |
-
"source": [
|
1146 |
-
"original_data['images'][:10]"
|
1147 |
-
]
|
1148 |
-
},
|
1149 |
-
{
|
1150 |
-
"cell_type": "code",
|
1151 |
-
"execution_count": null,
|
1152 |
-
"metadata": {},
|
1153 |
-
"outputs": [],
|
1154 |
-
"source": [
|
1155 |
-
"# ['third book starting from left', '', '29.1,11.72,66.81,343.41', '', 'refcoco_train', 'visual_grounding']\n",
|
1156 |
-
"\n",
|
1157 |
-
"original_data['categories']"
|
1158 |
-
]
|
1159 |
-
},
|
1160 |
-
{
|
1161 |
-
"cell_type": "markdown",
|
1162 |
-
"metadata": {},
|
1163 |
-
"source": [
|
1164 |
-
"### Imagenet"
|
1165 |
-
]
|
1166 |
-
},
|
1167 |
-
{
|
1168 |
-
"cell_type": "code",
|
1169 |
-
"execution_count": null,
|
1170 |
-
"metadata": {},
|
1171 |
-
"outputs": [],
|
1172 |
-
"source": [
|
1173 |
-
"# image-id and image base64 string .txt file \n",
|
1174 |
-
"# id, image, code in tsv final \n",
|
1175 |
-
"\n",
|
1176 |
-
"from preprocesss.utils import create_imagenet_txt_files\n",
|
1177 |
-
"\n",
|
1178 |
-
"\n",
|
1179 |
-
"path_data = '/data/mshukor/data/imagenet/val'\n",
|
1180 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_val.txt'\n",
|
1181 |
-
"\n",
|
1182 |
-
"\n",
|
1183 |
-
"create_imagenet_txt_files(path_data, output_path)"
|
1184 |
-
]
|
1185 |
-
},
|
1186 |
-
{
|
1187 |
-
"cell_type": "code",
|
1188 |
-
"execution_count": null,
|
1189 |
-
"metadata": {},
|
1190 |
-
"outputs": [],
|
1191 |
-
"source": [
|
1192 |
-
"start_id\n",
|
1193 |
-
"len(data)\n",
|
1194 |
-
"data[0]"
|
1195 |
-
]
|
1196 |
-
},
|
1197 |
-
{
|
1198 |
-
"cell_type": "code",
|
1199 |
-
"execution_count": null,
|
1200 |
-
"metadata": {},
|
1201 |
-
"outputs": [],
|
1202 |
-
"source": [
|
1203 |
-
"\n",
|
1204 |
-
"code_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train_codes.tsv'\n",
|
1205 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
|
1206 |
-
"\n",
|
1207 |
-
"def save_image_only_tsv_from_code_files(code_path, output_path, start_id=0):\n",
|
1208 |
-
" selected_col_ids = [0,1]\n",
|
1209 |
-
" out_data = []\n",
|
1210 |
-
" with open(code_path) as file:\n",
|
1211 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1212 |
-
" for line in tqdm(tsv_file):\n",
|
1213 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
1214 |
-
" d = [start_id]+d\n",
|
1215 |
-
" out_data.append(d)\n",
|
1216 |
-
"\n",
|
1217 |
-
"\n",
|
1218 |
-
" with open(output_path, 'w', newline='') as f_output:\n",
|
1219 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
1220 |
-
"\n",
|
1221 |
-
" for t in tqdm(out_data):\n",
|
1222 |
-
" csv_output.writerow(t)\n",
|
1223 |
-
"\n",
|
1224 |
-
"save_image_only_tsv_from_code_files(code_path, output_path, start_id=0)"
|
1225 |
-
]
|
1226 |
-
},
|
1227 |
-
{
|
1228 |
-
"cell_type": "code",
|
1229 |
-
"execution_count": null,
|
1230 |
-
"metadata": {},
|
1231 |
-
"outputs": [],
|
1232 |
-
"source": [
|
1233 |
-
"selected_col_ids = [0,1,2]\n",
|
1234 |
-
"out_data = []\n",
|
1235 |
-
"with open(output_path) as file:\n",
|
1236 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1237 |
-
" for line in tqdm(tsv_file):\n",
|
1238 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
1239 |
-
" out_data.append(d)\n",
|
1240 |
-
" break"
|
1241 |
-
]
|
1242 |
-
},
|
1243 |
-
{
|
1244 |
-
"cell_type": "code",
|
1245 |
-
"execution_count": null,
|
1246 |
-
"metadata": {},
|
1247 |
-
"outputs": [],
|
1248 |
-
"source": [
|
1249 |
-
"len(out_data[0][2].split(' '))"
|
1250 |
-
]
|
1251 |
-
},
|
1252 |
-
{
|
1253 |
-
"cell_type": "markdown",
|
1254 |
-
"metadata": {},
|
1255 |
-
"source": [
|
1256 |
-
"#### Fix image paths"
|
1257 |
-
]
|
1258 |
-
},
|
1259 |
-
{
|
1260 |
-
"cell_type": "code",
|
1261 |
-
"execution_count": 33,
|
1262 |
-
"metadata": {},
|
1263 |
-
"outputs": [
|
1264 |
-
{
|
1265 |
-
"name": "stderr",
|
1266 |
-
"output_type": "stream",
|
1267 |
-
"text": [
|
1268 |
-
"1281167it [00:16, 79250.80it/s]\n"
|
1269 |
-
]
|
1270 |
-
}
|
1271 |
-
],
|
1272 |
-
"source": [
|
1273 |
-
"\n",
|
1274 |
-
"path_data = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
|
1275 |
-
"selected_cols='0,1,2'\n",
|
1276 |
-
"\n",
|
1277 |
-
"data = []\n",
|
1278 |
-
"\n",
|
1279 |
-
"selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
|
1280 |
-
"\n",
|
1281 |
-
"with open(path_data) as file:\n",
|
1282 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1283 |
-
" for line in tqdm(tsv_file):\n",
|
1284 |
-
"\n",
|
1285 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
1286 |
-
"# print(d)\n",
|
1287 |
-
" data.append(d)"
|
1288 |
-
]
|
1289 |
-
},
|
1290 |
-
{
|
1291 |
-
"cell_type": "code",
|
1292 |
-
"execution_count": 44,
|
1293 |
-
"metadata": {},
|
1294 |
-
"outputs": [
|
1295 |
-
{
|
1296 |
-
"name": "stderr",
|
1297 |
-
"output_type": "stream",
|
1298 |
-
"text": [
|
1299 |
-
"1281167it [00:16, 76760.12it/s]\n",
|
1300 |
-
"1281167it [00:01, 671149.72it/s]\n",
|
1301 |
-
"100%|█████| 1281167/1281167 [00:01<00:00, 947543.73it/s]\n"
|
1302 |
-
]
|
1303 |
-
}
|
1304 |
-
],
|
1305 |
-
"source": [
|
1306 |
-
"# from imge-id img-path to \n",
|
1307 |
-
"def replace_image_id_by_path(input_tsv, output_tsv, mapping_file):\n",
|
1308 |
-
" selected_cols='0,1,2'\n",
|
1309 |
-
" data = []\n",
|
1310 |
-
" selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
|
1311 |
-
" with open(input_tsv) as file:\n",
|
1312 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1313 |
-
" for line in tqdm(tsv_file):\n",
|
1314 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
1315 |
-
" data.append(d)\n",
|
1316 |
-
" \n",
|
1317 |
-
" im_id_to_path = {}\n",
|
1318 |
-
" with open(mapping_file) as file:\n",
|
1319 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1320 |
-
" for line in tqdm(tsv_file):\n",
|
1321 |
-
" d = [line[i] for i in [0, 1]]\n",
|
1322 |
-
" im_id_to_path[d[0]] = d[1]\n",
|
1323 |
-
" \n",
|
1324 |
-
" for d in tqdm(data):\n",
|
1325 |
-
" im_id = d[1].split('/')[-1].split('.')[0]\n",
|
1326 |
-
" im_path = im_id_to_path[im_id]\n",
|
1327 |
-
" d[1] = im_path\n",
|
1328 |
-
" \n",
|
1329 |
-
" with open(output_tsv, 'w', newline='') as f_output:\n",
|
1330 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
1331 |
-
"\n",
|
1332 |
-
" for t in tqdm(data):\n",
|
1333 |
-
" csv_output.writerow(t)\n",
|
1334 |
-
" \n",
|
1335 |
-
" return data\n",
|
1336 |
-
"\n",
|
1337 |
-
"input_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
|
1338 |
-
"output_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
|
1339 |
-
"mapping_file = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train.txt'\n",
|
1340 |
-
"\n",
|
1341 |
-
"tmp = replace_image_id_by_path(input_tsv, output_tsv, mapping_file)"
|
1342 |
-
]
|
1343 |
-
},
|
1344 |
-
{
|
1345 |
-
"cell_type": "code",
|
1346 |
-
"execution_count": 45,
|
1347 |
-
"metadata": {},
|
1348 |
-
"outputs": [
|
1349 |
-
{
|
1350 |
-
"data": {
|
1351 |
-
"text/plain": [
|
1352 |
-
"['0',\n",
|
1353 |
-
" 'RawImages/train/n03146219/n03146219_8050.JPEG',\n",
|
1354 |
-
" '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
|
1355 |
-
]
|
1356 |
-
},
|
1357 |
-
"execution_count": 45,
|
1358 |
-
"metadata": {},
|
1359 |
-
"output_type": "execute_result"
|
1360 |
-
}
|
1361 |
-
],
|
1362 |
-
"source": [
|
1363 |
-
"tmp[0]"
|
1364 |
-
]
|
1365 |
-
},
|
1366 |
-
{
|
1367 |
-
"cell_type": "code",
|
1368 |
-
"execution_count": 36,
|
1369 |
-
"metadata": {},
|
1370 |
-
"outputs": [
|
1371 |
-
{
|
1372 |
-
"name": "stderr",
|
1373 |
-
"output_type": "stream",
|
1374 |
-
"text": [
|
1375 |
-
"100%|█████| 1281167/1281167 [00:03<00:00, 336250.44it/s]\n"
|
1376 |
-
]
|
1377 |
-
}
|
1378 |
-
],
|
1379 |
-
"source": [
|
1380 |
-
"# imgage_dir = 'imagenet/RawImages/train/'\n",
|
1381 |
-
"# for d in tqdm(data):\n",
|
1382 |
-
"# im_id = d[1]\n",
|
1383 |
-
"# im_dir = im_id.split('_')[0]\n",
|
1384 |
-
"# im_path = os.path.join(im_dir, im_id+'.JPEG')\n",
|
1385 |
-
"# d[1] = os.path.join(imgage_dir, im_path)"
|
1386 |
-
]
|
1387 |
-
},
|
1388 |
-
{
|
1389 |
-
"cell_type": "code",
|
1390 |
-
"execution_count": 39,
|
1391 |
-
"metadata": {},
|
1392 |
-
"outputs": [
|
1393 |
-
{
|
1394 |
-
"data": {
|
1395 |
-
"text/plain": [
|
1396 |
-
"['0',\n",
|
1397 |
-
" 'imagenet/RawImages/train/n03146219/n03146219_8050.JPEG',\n",
|
1398 |
-
" '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
|
1399 |
-
]
|
1400 |
-
},
|
1401 |
-
"execution_count": 39,
|
1402 |
-
"metadata": {},
|
1403 |
-
"output_type": "execute_result"
|
1404 |
-
}
|
1405 |
-
],
|
1406 |
-
"source": [
|
1407 |
-
"data[0]"
|
1408 |
-
]
|
1409 |
-
},
|
1410 |
-
{
|
1411 |
-
"cell_type": "code",
|
1412 |
-
"execution_count": 40,
|
1413 |
-
"metadata": {},
|
1414 |
-
"outputs": [
|
1415 |
-
{
|
1416 |
-
"name": "stderr",
|
1417 |
-
"output_type": "stream",
|
1418 |
-
"text": [
|
1419 |
-
"100%|██████| 1281167/1281167 [00:27<00:00, 46704.02it/s]\n"
|
1420 |
-
]
|
1421 |
-
}
|
1422 |
-
],
|
1423 |
-
"source": [
|
1424 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
|
1425 |
-
"with open(output_path, 'w', newline='') as f_output:\n",
|
1426 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
1427 |
-
"\n",
|
1428 |
-
" for t in tqdm(data):\n",
|
1429 |
-
" csv_output.writerow(t)\n"
|
1430 |
-
]
|
1431 |
-
},
|
1432 |
-
{
|
1433 |
-
"cell_type": "markdown",
|
1434 |
-
"metadata": {},
|
1435 |
-
"source": [
|
1436 |
-
"### Object detection"
|
1437 |
-
]
|
1438 |
-
},
|
1439 |
-
{
|
1440 |
-
"cell_type": "markdown",
|
1441 |
-
"metadata": {},
|
1442 |
-
"source": [
|
1443 |
-
"#### COCO"
|
1444 |
-
]
|
1445 |
-
},
|
1446 |
-
{
|
1447 |
-
"cell_type": "code",
|
1448 |
-
"execution_count": null,
|
1449 |
-
"metadata": {},
|
1450 |
-
"outputs": [],
|
1451 |
-
"source": [
|
1452 |
-
"# '505.856,189.994,799.744,450.016,/m/07j7r,tree&&753.664,384.00600000000003,827.392,446.572,/m/0c9ph5,flower'\n",
|
1453 |
-
"\n",
|
1454 |
-
"path_json = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
|
1455 |
-
"\n",
|
1456 |
-
"data = json.load(open(path_json,'r'))"
|
1457 |
-
]
|
1458 |
-
},
|
1459 |
-
{
|
1460 |
-
"cell_type": "code",
|
1461 |
-
"execution_count": null,
|
1462 |
-
"metadata": {},
|
1463 |
-
"outputs": [],
|
1464 |
-
"source": [
|
1465 |
-
"def get_tsv_from_coco_detection(instances_path, start_id, convert_images=True, split='train'):\n",
|
1466 |
-
"\n",
|
1467 |
-
" instances = json.load(open(instances_path,'r'))\n",
|
1468 |
-
" imgid_to_annot = {}\n",
|
1469 |
-
" for annot in tqdm(instances['annotations']):\n",
|
1470 |
-
" if annot['image_id'] not in imgid_to_annot:\n",
|
1471 |
-
" imgid_to_annot[annot['image_id']] = [annot]\n",
|
1472 |
-
" else:\n",
|
1473 |
-
" imgid_to_annot[annot['image_id']].append(annot)\n",
|
1474 |
-
"\n",
|
1475 |
-
" id_to_category = {}\n",
|
1476 |
-
" for annot in tqdm(instances['categories']):\n",
|
1477 |
-
" id_to_category[annot['id']] = annot['name']\n",
|
1478 |
-
"\n",
|
1479 |
-
" tsv_data = []\n",
|
1480 |
-
" missied = []\n",
|
1481 |
-
" for ref in tqdm(instances['images']):\n",
|
1482 |
-
" ref_split = split\n",
|
1483 |
-
" image_id = ref['id']\n",
|
1484 |
-
" file_name = ref['file_name']\n",
|
1485 |
-
"\n",
|
1486 |
-
" if ref_split == 'train':\n",
|
1487 |
-
" file_name = os.path.join('coco/train2014', file_name)\n",
|
1488 |
-
"\n",
|
1489 |
-
" if convert_images:\n",
|
1490 |
-
" img_path = os.path.join('/data/mshukor/data/', file_name)\n",
|
1491 |
-
" img = convert_img_to_str(img_path)\n",
|
1492 |
-
" else:\n",
|
1493 |
-
" img_path = file_name.replace('/data/mshukor/data/', '')\n",
|
1494 |
-
" img = img_path\n",
|
1495 |
-
"\n",
|
1496 |
-
" # ann_id = ref['id']\n",
|
1497 |
-
" # annot = id_to_annot[ann_id]\n",
|
1498 |
-
" if image_id not in imgid_to_annot:\n",
|
1499 |
-
" missied.append(image_id)\n",
|
1500 |
-
" continue\n",
|
1501 |
-
" annots = imgid_to_annot[image_id]\n",
|
1502 |
-
" detections = []\n",
|
1503 |
-
" areas = []\n",
|
1504 |
-
" for annot in annots:\n",
|
1505 |
-
" bbox = annot['bbox'] # x,y,w,h bottom left\n",
|
1506 |
-
" area = bbox[2]*bbox[3]\n",
|
1507 |
-
" x1, y1, x2, y2 = bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3] # top left, bottom right \n",
|
1508 |
-
" # box = '{:.3f},{:.3f},{:.3f},{:.3f}'.format(x1, y1, x2, y2)\n",
|
1509 |
-
"\n",
|
1510 |
-
" object_id = annot['category_id']\n",
|
1511 |
-
" category = id_to_category[object_id]\n",
|
1512 |
-
"\n",
|
1513 |
-
" tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
|
1514 |
-
" areas.append(area)\n",
|
1515 |
-
" detections.append(tmp)\n",
|
1516 |
-
"\n",
|
1517 |
-
" sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
|
1518 |
-
" detections = [detections[k] for k in sorted_indices]\n",
|
1519 |
-
" detections = '&&'.join(detections)\n",
|
1520 |
-
" t = [start_id, img, detections]\n",
|
1521 |
-
"\n",
|
1522 |
-
" tsv_data.append(t)\n",
|
1523 |
-
" start_id+=1\n",
|
1524 |
-
"\n",
|
1525 |
-
" return tsv_data\n",
|
1526 |
-
"\n",
|
1527 |
-
"instances_path = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
|
1528 |
-
"start_id = 0\n",
|
1529 |
-
"tmp = get_tsv_from_coco_detection(instances_path, start_id, convert_images=False, split='train')"
|
1530 |
-
]
|
1531 |
-
},
|
1532 |
-
{
|
1533 |
-
"cell_type": "code",
|
1534 |
-
"execution_count": null,
|
1535 |
-
"metadata": {},
|
1536 |
-
"outputs": [],
|
1537 |
-
"source": [
|
1538 |
-
"list(imgid_to_annot.keys())[:10]\n",
|
1539 |
-
"len(missied)"
|
1540 |
-
]
|
1541 |
-
},
|
1542 |
-
{
|
1543 |
-
"cell_type": "markdown",
|
1544 |
-
"metadata": {},
|
1545 |
-
"source": [
|
1546 |
-
"#### VG"
|
1547 |
-
]
|
1548 |
-
},
|
1549 |
-
{
|
1550 |
-
"cell_type": "code",
|
1551 |
-
"execution_count": null,
|
1552 |
-
"metadata": {},
|
1553 |
-
"outputs": [],
|
1554 |
-
"source": [
|
1555 |
-
"def get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=True, split='train'):\n",
|
1556 |
-
" \n",
|
1557 |
-
" instances = json.load(open(instances_path,'r'))\n",
|
1558 |
-
" \n",
|
1559 |
-
" id_to_objects = {}\n",
|
1560 |
-
" for d in instances:\n",
|
1561 |
-
" id_to_objects[d['id']] = d\n",
|
1562 |
-
"\n",
|
1563 |
-
"\n",
|
1564 |
-
" \n",
|
1565 |
-
" id_to_image_path = {}\n",
|
1566 |
-
" for root, dirs, files, in os.walk(path_images):\n",
|
1567 |
-
" for d in dirs:\n",
|
1568 |
-
" dir_path = os.path.join(root, d)\n",
|
1569 |
-
" for _, _, dir_files in os.walk(dir_path):\n",
|
1570 |
-
" for f in dir_files:\n",
|
1571 |
-
" file_path = os.path.join(dir_path, f)\n",
|
1572 |
-
" file_path = '/'.join(file_path.split('/')[-4:])\n",
|
1573 |
-
" image_id = f.split('.')[0]\n",
|
1574 |
-
" id_to_image_path[image_id] = file_path\n",
|
1575 |
-
"\n",
|
1576 |
-
" \n",
|
1577 |
-
"\n",
|
1578 |
-
"\n",
|
1579 |
-
" tsv_data = []\n",
|
1580 |
-
" missied = []\n",
|
1581 |
-
" negs = []\n",
|
1582 |
-
" for ref in tqdm(id_to_image_path.keys()):\n",
|
1583 |
-
" ref_split = split\n",
|
1584 |
-
" \n",
|
1585 |
-
" image_id = ref\n",
|
1586 |
-
" \n",
|
1587 |
-
" file_name = id_to_image_path[image_id]\n",
|
1588 |
-
" if convert_images:\n",
|
1589 |
-
" img_path = os.path.join('/data/mshukor/data/', file_name)\n",
|
1590 |
-
" img = convert_img_to_str(img_path)\n",
|
1591 |
-
" else:\n",
|
1592 |
-
" img_path = file_name.replace('/data/mshukor/data/', '')\n",
|
1593 |
-
" img = img_path\n",
|
1594 |
-
"\n",
|
1595 |
-
" \n",
|
1596 |
-
" if int(image_id) in id_to_objects:\n",
|
1597 |
-
" objects = id_to_objects[int(image_id)]['objects']\n",
|
1598 |
-
" else:\n",
|
1599 |
-
" missied.append(image_id)\n",
|
1600 |
-
" continue\n",
|
1601 |
-
" \n",
|
1602 |
-
" if len(objects) == 0:\n",
|
1603 |
-
" missied.append(image_id)\n",
|
1604 |
-
" continue\n",
|
1605 |
-
" \n",
|
1606 |
-
" \n",
|
1607 |
-
" areas = []\n",
|
1608 |
-
" detections = []\n",
|
1609 |
-
" for annot in objects:\n",
|
1610 |
-
" x,y,w,h = annot['x'], annot['y'], annot['w'], annot['h'] # x,y,w,h bottom left\n",
|
1611 |
-
" \n",
|
1612 |
-
" area = w*h\n",
|
1613 |
-
" \n",
|
1614 |
-
" x1, y1, x2, y2 = x, y, x + w, y + h # top left, bottom right \n",
|
1615 |
-
" \n",
|
1616 |
-
" if x1 < 0 or x2 < 0:\n",
|
1617 |
-
" negs.append(annot)\n",
|
1618 |
-
" x1 = max(0, x1)\n",
|
1619 |
-
" x2 = max(0, x2)\n",
|
1620 |
-
" \n",
|
1621 |
-
" \n",
|
1622 |
-
" category = ','.join(annot['names']).replace('\\x00','')\n",
|
1623 |
-
" object_id = annot['id']\n",
|
1624 |
-
" \n",
|
1625 |
-
" \n",
|
1626 |
-
" tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
|
1627 |
-
" detections.append(tmp)\n",
|
1628 |
-
" areas.append(area)\n",
|
1629 |
-
"\n",
|
1630 |
-
" sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
|
1631 |
-
" detections = [detections[k] for k in sorted_indices]\n",
|
1632 |
-
" \n",
|
1633 |
-
" detections = '&&'.join(detections)\n",
|
1634 |
-
" t = [start_id, img, detections]\n",
|
1635 |
-
"\n",
|
1636 |
-
" tsv_data.append(t)\n",
|
1637 |
-
" start_id+=1\n",
|
1638 |
-
" print('missed images:', len(missied), 'negs', len(negs))\n",
|
1639 |
-
" return tsv_data\n",
|
1640 |
-
"\n",
|
1641 |
-
"\n",
|
1642 |
-
"instances_path = '/data/mshukor/data/visual_genome/annotations/objects.json'\n",
|
1643 |
-
"path_images = '/data/mshukor/data/visual_genome/images'\n",
|
1644 |
-
"start_id = 0\n",
|
1645 |
-
"\n",
|
1646 |
-
"tmp = get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=False, split='train')"
|
1647 |
-
]
|
1648 |
-
},
|
1649 |
-
{
|
1650 |
-
"cell_type": "code",
|
1651 |
-
"execution_count": null,
|
1652 |
-
"metadata": {},
|
1653 |
-
"outputs": [],
|
1654 |
-
"source": [
|
1655 |
-
"image_root = '/data/mshukor/data/'\n",
|
1656 |
-
"\n",
|
1657 |
-
"Image.open(image_root+id_to_image_path['1087']).convert('RGB')"
|
1658 |
-
]
|
1659 |
-
},
|
1660 |
-
{
|
1661 |
-
"cell_type": "markdown",
|
1662 |
-
"metadata": {},
|
1663 |
-
"source": [
|
1664 |
-
"#### OpenImagesV5"
|
1665 |
-
]
|
1666 |
-
},
|
1667 |
-
{
|
1668 |
-
"cell_type": "code",
|
1669 |
-
"execution_count": null,
|
1670 |
-
"metadata": {},
|
1671 |
-
"outputs": [],
|
1672 |
-
"source": [
|
1673 |
-
"# data_path = '/data/mshukor/data/OpenImagesV5/train-annotations-bbox.csv'\n",
|
1674 |
-
"# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable.csv'\n",
|
1675 |
-
"# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable-with-rotation.csv'\n",
|
1676 |
-
"data_path = '/data/mshukor/data/OpenImagesV5/class-descriptions-boxable.csv'\n",
|
1677 |
-
"\n",
|
1678 |
-
"\n",
|
1679 |
-
"\n",
|
1680 |
-
"\n",
|
1681 |
-
"selected_col_ids = [0,1,2]\n",
|
1682 |
-
"out_data = []\n",
|
1683 |
-
"with open(data_path) as file:\n",
|
1684 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1685 |
-
" for i, line in tqdm(enumerate(tsv_file)):\n",
|
1686 |
-
" # d = [line[i] for i in selected_col_ids]\n",
|
1687 |
-
" out_data.append(line)\n",
|
1688 |
-
"# print(line)\n",
|
1689 |
-
"# if i > 2:\n",
|
1690 |
-
"# break\n",
|
1691 |
-
" "
|
1692 |
-
]
|
1693 |
-
},
|
1694 |
-
{
|
1695 |
-
"cell_type": "code",
|
1696 |
-
"execution_count": null,
|
1697 |
-
"metadata": {},
|
1698 |
-
"outputs": [],
|
1699 |
-
"source": [
|
1700 |
-
"def get_tsv_from_openimages_detection(instances_path, path_images, start_id, convert_images=False, split='train')\n",
|
1701 |
-
"\n",
|
1702 |
-
" id_to_image_path = {}\n",
|
1703 |
-
" for root, dirs, files, in os.walk(path_images):\n",
|
1704 |
-
" for d in dirs:\n",
|
1705 |
-
" dir_path = os.path.join(root, d)\n",
|
1706 |
-
" for _, _, dir_files in os.walk(dir_path):\n",
|
1707 |
-
" for f in dir_files:\n",
|
1708 |
-
" file_path = os.path.join(dir_path, f)\n",
|
1709 |
-
" file_path = '/'.join(file_path.split('/')[-4:])\n",
|
1710 |
-
" image_id = f.split('.')[0]\n",
|
1711 |
-
" id_to_image_path[image_id] = file_path\n",
|
1712 |
-
"\n",
|
1713 |
-
" image_root = '/gpfsdswork/dataset'\n",
|
1714 |
-
"\n",
|
1715 |
-
" def imagepath_to_image_size(path):\n",
|
1716 |
-
" w, h = Image.open(path).size\n",
|
1717 |
-
"\n",
|
1718 |
-
" id_to_annot = {}\n",
|
1719 |
-
" with open(instances_path) as file:\n",
|
1720 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1721 |
-
" for i, line in tqdm(enumerate(tsv_file)):\n",
|
1722 |
-
" img_id = line[0].split(',')[0]\n",
|
1723 |
-
" if img_id in id_to_annot:\n",
|
1724 |
-
" id_to_annot[img_id].append(line)\n",
|
1725 |
-
" else:\n",
|
1726 |
-
" id_to_annot[img_id] = [line]\n",
|
1727 |
-
"\n",
|
1728 |
-
" classid_to_class = {}\n",
|
1729 |
-
"\n",
|
1730 |
-
" with open(class_path) as file:\n",
|
1731 |
-
" tsv_file = csv.reader(file, delimiter=',')\n",
|
1732 |
-
" for i, line in tqdm(enumerate(tsv_file)):\n",
|
1733 |
-
" classid_to_class[line[0]] = line[1]\n",
|
1734 |
-
"\n",
|
1735 |
-
" tsv_data = []\n",
|
1736 |
-
" for img_id in id_to_annot.keys():\n",
|
1737 |
-
" annots = id_to_annot[img_id]\n",
|
1738 |
-
" img_path = id_to_image_path[img_id]\n",
|
1739 |
-
" orig_img_path = os.path.join(image_root, img_path)\n",
|
1740 |
-
" w, h = imagepath_to_image_size(path)\n",
|
1741 |
-
"\n",
|
1742 |
-
" if convert_images:\n",
|
1743 |
-
" img = convert_img_to_str(orig_img_path)\n",
|
1744 |
-
" else:\n",
|
1745 |
-
" img = img_path\n",
|
1746 |
-
"\n",
|
1747 |
-
" areas = []\n",
|
1748 |
-
" detections = []\n",
|
1749 |
-
" for d in annots:\n",
|
1750 |
-
" d = d[0].split(',')\n",
|
1751 |
-
"\n",
|
1752 |
-
" x1, x2, y1, y2 = d[4:8]\n",
|
1753 |
-
" x1, x2, y1, y2 = x1*w, x2*w, y1*h, y2*h\n",
|
1754 |
-
" box_w, box_h = x2 - x1, y2 - y1\n",
|
1755 |
-
" area = box_w*box_h\n",
|
1756 |
-
" areas.append(area)\n",
|
1757 |
-
"\n",
|
1758 |
-
" object_id = d[2]\n",
|
1759 |
-
" category = classid_to_class[object_id]\n",
|
1760 |
-
"\n",
|
1761 |
-
" tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
|
1762 |
-
" detections.append(tmp)\n",
|
1763 |
-
"\n",
|
1764 |
-
"\n",
|
1765 |
-
" sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
|
1766 |
-
" detections = [detections[k] for k in sorted_indices]\n",
|
1767 |
-
"\n",
|
1768 |
-
" detections = '&&'.join(detections)\n",
|
1769 |
-
" t = [start_id, img, detections]\n",
|
1770 |
-
"\n",
|
1771 |
-
" tsv_data.append(t)\n",
|
1772 |
-
" start_id+=1\n",
|
1773 |
-
" \n",
|
1774 |
-
" return tsv_data\n",
|
1775 |
-
"\n",
|
1776 |
-
" "
|
1777 |
-
]
|
1778 |
-
},
|
1779 |
-
{
|
1780 |
-
"cell_type": "code",
|
1781 |
-
"execution_count": null,
|
1782 |
-
"metadata": {},
|
1783 |
-
"outputs": [],
|
1784 |
-
"source": [
|
1785 |
-
"e39871fd9fd74f55"
|
1786 |
-
]
|
1787 |
-
},
|
1788 |
-
{
|
1789 |
-
"cell_type": "markdown",
|
1790 |
-
"metadata": {},
|
1791 |
-
"source": [
|
1792 |
-
"### Text"
|
1793 |
-
]
|
1794 |
-
},
|
1795 |
-
{
|
1796 |
-
"cell_type": "markdown",
|
1797 |
-
"metadata": {},
|
1798 |
-
"source": [
|
1799 |
-
"#### En Wikipedia"
|
1800 |
-
]
|
1801 |
-
},
|
1802 |
-
{
|
1803 |
-
"cell_type": "code",
|
1804 |
-
"execution_count": null,
|
1805 |
-
"metadata": {},
|
1806 |
-
"outputs": [],
|
1807 |
-
"source": [
|
1808 |
-
"from datasets import load_dataset"
|
1809 |
-
]
|
1810 |
-
},
|
1811 |
-
{
|
1812 |
-
"cell_type": "code",
|
1813 |
-
"execution_count": null,
|
1814 |
-
"metadata": {},
|
1815 |
-
"outputs": [],
|
1816 |
-
"source": [
|
1817 |
-
"%env http_proxy='http://192.168.0.100:3128' \n",
|
1818 |
-
"%env https_proxy='http://192.168.0.100:3128'\n",
|
1819 |
-
"\n",
|
1820 |
-
"%env HF_DATASETS_CACHE=\"/data/mshukor/data/.cache\"\n",
|
1821 |
-
"%env HF_DATASETS_OFFLINE=1"
|
1822 |
-
]
|
1823 |
-
},
|
1824 |
-
{
|
1825 |
-
"cell_type": "code",
|
1826 |
-
"execution_count": null,
|
1827 |
-
"metadata": {},
|
1828 |
-
"outputs": [],
|
1829 |
-
"source": [
|
1830 |
-
"tmp = load_dataset(\"wikipedia\", \"20220301.en\", cache_dir=\"/data/mshukor/data/.cache\")"
|
1831 |
-
]
|
1832 |
-
},
|
1833 |
-
{
|
1834 |
-
"cell_type": "code",
|
1835 |
-
"execution_count": null,
|
1836 |
-
"metadata": {},
|
1837 |
-
"outputs": [],
|
1838 |
-
"source": [
|
1839 |
-
"len(tmp['train'][0]['text'])\n",
|
1840 |
-
"tmp['train'][0]['text'][:512]"
|
1841 |
-
]
|
1842 |
-
},
|
1843 |
-
{
|
1844 |
-
"cell_type": "code",
|
1845 |
-
"execution_count": null,
|
1846 |
-
"metadata": {},
|
1847 |
-
"outputs": [],
|
1848 |
-
"source": [
|
1849 |
-
"def remove_special(input_string):\n",
|
1850 |
-
" final_string = \"\"\n",
|
1851 |
-
" for character in input_string:\n",
|
1852 |
-
" if character == \" \":\n",
|
1853 |
-
" final_string = final_string + character\n",
|
1854 |
-
" else:\n",
|
1855 |
-
" if(character.isalnum()):\n",
|
1856 |
-
" final_string = final_string + character\n",
|
1857 |
-
" return final_string\n",
|
1858 |
-
"\n",
|
1859 |
-
"def get_tsv_from_text_data(data_name=\"wikipedia\", data_subname=\"20220301.en\", \n",
|
1860 |
-
" output_path, cache_dir=\"/data/mshukor/data/.cache\", start_id=0, num_max_characters=2500):\n",
|
1861 |
-
" from datasets import load_dataset\n",
|
1862 |
-
" tmp = load_dataset(data_name, data_subname, cache_dir=cache_dir)\n",
|
1863 |
-
"\n",
|
1864 |
-
" with open(output_path, 'w', newline='') as f_output:\n",
|
1865 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
1866 |
-
"\n",
|
1867 |
-
" for i, t in tqdm(enumerate(tmp['train'])):\n",
|
1868 |
-
" text = t['text'][:num_max_characters].replace('\\t', ' ').replace(\"\\n\", ' ').replace('\\\"', '')\n",
|
1869 |
-
" text = remove_special(text)\n",
|
1870 |
-
" item = [start_id, text]\n",
|
1871 |
-
" csv_output.writerow(item)\n",
|
1872 |
-
" start_id+=1\n",
|
1873 |
-
"\n",
|
1874 |
-
" "
|
1875 |
-
]
|
1876 |
-
},
|
1877 |
-
{
|
1878 |
-
"cell_type": "code",
|
1879 |
-
"execution_count": null,
|
1880 |
-
"metadata": {},
|
1881 |
-
"outputs": [],
|
1882 |
-
"source": []
|
1883 |
-
},
|
1884 |
-
{
|
1885 |
-
"cell_type": "code",
|
1886 |
-
"execution_count": null,
|
1887 |
-
"metadata": {},
|
1888 |
-
"outputs": [],
|
1889 |
-
"source": [
|
1890 |
-
"import csv\n",
|
1891 |
-
"from io import StringIO\n",
|
1892 |
-
"\n",
|
1893 |
-
"output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
|
1894 |
-
"\n",
|
1895 |
-
"start_id = 0 \n",
|
1896 |
-
"num_max_characters = 2500\n",
|
1897 |
-
"\n",
|
1898 |
-
"with open(output_path, 'w', newline='') as f_output:\n",
|
1899 |
-
" csv_output = csv.writer(f_output, delimiter='\\t')\n",
|
1900 |
-
"\n",
|
1901 |
-
" for i, t in tqdm(enumerate(tmp['train'])):\n",
|
1902 |
-
" text = t['text'][:num_max_characters]\n",
|
1903 |
-
" item = [start_id, text]\n",
|
1904 |
-
" csv_output.writerow(item)\n",
|
1905 |
-
" start_id+=1"
|
1906 |
-
]
|
1907 |
-
},
|
1908 |
-
{
|
1909 |
-
"cell_type": "code",
|
1910 |
-
"execution_count": null,
|
1911 |
-
"metadata": {},
|
1912 |
-
"outputs": [],
|
1913 |
-
"source": [
|
1914 |
-
"out_data = []\n",
|
1915 |
-
"selected_cols='0,1'\n",
|
1916 |
-
"\n",
|
1917 |
-
"selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
|
1918 |
-
"\n",
|
1919 |
-
"with open(output_path) as file:\n",
|
1920 |
-
" tsv_file = csv.reader(file, delimiter='\\t')\n",
|
1921 |
-
" for line in tqdm(tsv_file):\n",
|
1922 |
-
" d = [line[i] for i in selected_col_ids]\n",
|
1923 |
-
" out_data.append(d)\n",
|
1924 |
-
" "
|
1925 |
-
]
|
1926 |
-
},
|
1927 |
-
{
|
1928 |
-
"cell_type": "code",
|
1929 |
-
"execution_count": null,
|
1930 |
-
"metadata": {},
|
1931 |
-
"outputs": [],
|
1932 |
-
"source": [
|
1933 |
-
"out_data"
|
1934 |
-
]
|
1935 |
-
},
|
1936 |
-
{
|
1937 |
-
"cell_type": "markdown",
|
1938 |
-
"metadata": {},
|
1939 |
-
"source": [
|
1940 |
-
"### Create from finetuned data"
|
1941 |
-
]
|
1942 |
-
},
|
1943 |
-
{
|
1944 |
-
"cell_type": "code",
|
1945 |
-
"execution_count": null,
|
1946 |
-
"metadata": {},
|
1947 |
-
"outputs": [],
|
1948 |
-
"source": [
|
1949 |
-
"read from tsv and write to tsv directly \n",
|
1950 |
-
"same for vqa v2\n",
|
1951 |
-
"then create ofa_mini 4m, vqa and refcoco for pretraining "
|
1952 |
-
]
|
1953 |
-
},
|
1954 |
-
{
|
1955 |
-
"cell_type": "markdown",
|
1956 |
-
"metadata": {},
|
1957 |
-
"source": [
|
1958 |
-
"# Convert weights"
|
1959 |
-
]
|
1960 |
-
},
|
1961 |
-
{
|
1962 |
-
"cell_type": "code",
|
1963 |
-
"execution_count": 3,
|
1964 |
-
"metadata": {},
|
1965 |
-
"outputs": [],
|
1966 |
-
"source": [
|
1967 |
-
"from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
|
1968 |
-
"from models import ofa_base_architecture, OFAModel\n",
|
1969 |
-
"from transformers.tokenization_utils_base import BatchEncoding"
|
1970 |
-
]
|
1971 |
-
},
|
1972 |
-
{
|
1973 |
-
"cell_type": "markdown",
|
1974 |
-
"metadata": {},
|
1975 |
-
"source": [
|
1976 |
-
"### Explore ofa"
|
1977 |
-
]
|
1978 |
-
},
|
1979 |
-
{
|
1980 |
-
"cell_type": "code",
|
1981 |
-
"execution_count": 4,
|
1982 |
-
"metadata": {},
|
1983 |
-
"outputs": [
|
1984 |
-
{
|
1985 |
-
"name": "stderr",
|
1986 |
-
"output_type": "stream",
|
1987 |
-
"text": [
|
1988 |
-
"2022-11-15 08:52:08 | INFO | tasks.ofa_task | source dictionary: 59457 types\n",
|
1989 |
-
"2022-11-15 08:52:08 | INFO | tasks.ofa_task | target dictionary: 59457 types\n"
|
1990 |
-
]
|
1991 |
-
}
|
1992 |
-
],
|
1993 |
-
"source": [
|
1994 |
-
"import torch\n",
|
1995 |
-
"import numpy as np\n",
|
1996 |
-
"from fairseq import utils, tasks\n",
|
1997 |
-
"from fairseq import checkpoint_utils\n",
|
1998 |
-
"from utils.eval_utils import eval_step\n",
|
1999 |
-
"from tasks.mm_tasks.caption import CaptionTask\n",
|
2000 |
-
"from models.ofa import OFAModel\n",
|
2001 |
-
"from PIL import Image\n",
|
2002 |
-
"\n",
|
2003 |
-
"# Register refcoco task\n",
|
2004 |
-
"tasks.register_task('caption', CaptionTask)\n",
|
2005 |
-
"\n",
|
2006 |
-
"# turn on cuda if GPU is available\n",
|
2007 |
-
"use_cuda = torch.cuda.is_available()\n",
|
2008 |
-
"# use fp16 only when GPU is available\n",
|
2009 |
-
"use_fp16 = False\n",
|
2010 |
-
"\n",
|
2011 |
-
"# Load pretrained ckpt & config\n",
|
2012 |
-
"overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":16, \"no_repeat_ngram_size\":3, \"seed\":7}\n",
|
2013 |
-
"models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
|
2014 |
-
" utils.split_paths('/data/mshukor/logs/ofa/checkpoints/caption/ofa_caption_stage_1/5_0.06_6000/checkpoint_best.pt'),\n",
|
2015 |
-
" arg_overrides=overrides\n",
|
2016 |
-
" )\n",
|
2017 |
-
"\n",
|
2018 |
-
"# Move models to GPU\n",
|
2019 |
-
"for model in models:\n",
|
2020 |
-
" model.eval()\n",
|
2021 |
-
" if use_fp16:\n",
|
2022 |
-
" model.half()\n",
|
2023 |
-
" if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
|
2024 |
-
" model.cuda()\n",
|
2025 |
-
" model.prepare_for_inference_(cfg)\n",
|
2026 |
-
"\n",
|
2027 |
-
"# Initialize generator\n",
|
2028 |
-
"generator = task.build_generator(models, cfg.generation)"
|
2029 |
-
]
|
2030 |
-
},
|
2031 |
-
{
|
2032 |
-
"cell_type": "code",
|
2033 |
-
"execution_count": 5,
|
2034 |
-
"metadata": {},
|
2035 |
-
"outputs": [],
|
2036 |
-
"source": [
|
2037 |
-
"model_ofa = models[0]\n",
|
2038 |
-
"ofa_state = model_ofa.state_dict()"
|
2039 |
-
]
|
2040 |
-
},
|
2041 |
-
{
|
2042 |
-
"cell_type": "code",
|
2043 |
-
"execution_count": null,
|
2044 |
-
"metadata": {},
|
2045 |
-
"outputs": [],
|
2046 |
-
"source": [
|
2047 |
-
"def get_state_given_key(state, key, excluded_keys=None):\n",
|
2048 |
-
" new_state = {}\n",
|
2049 |
-
" for k, v in state.items():\n",
|
2050 |
-
" if key in k:\n",
|
2051 |
-
" if excluded_keys is not None:\n",
|
2052 |
-
" if not any([ek in k for ek in excluded_keys]):\n",
|
2053 |
-
" new_state[k] = v\n",
|
2054 |
-
" else:\n",
|
2055 |
-
" new_state[k] = v\n",
|
2056 |
-
" return new_state\n",
|
2057 |
-
"\n",
|
2058 |
-
"key = 'encoder.layers.0'\n",
|
2059 |
-
"excluded_keys = ['embed', 'image']\n",
|
2060 |
-
"ofa_tmp = get_state_given_key(ofa_state, key, excluded_keys=excluded_keys)"
|
2061 |
-
]
|
2062 |
-
},
|
2063 |
-
{
|
2064 |
-
"cell_type": "code",
|
2065 |
-
"execution_count": null,
|
2066 |
-
"metadata": {},
|
2067 |
-
"outputs": [],
|
2068 |
-
"source": [
|
2069 |
-
"# def get_ofa_args_large(args):\n",
|
2070 |
-
"# args['encoder_embed_path'] = getattr(args, \"encoder_embed_path\", None)\n",
|
2071 |
-
"# args['encoder_embed_dim'] = getattr(args, \"encoder_embed_dim\", 1024)\n",
|
2072 |
-
"# args['encoder_ffn_embed_dim'] = getattr(args, \"encoder_ffn_embed_dim\", 4 * 1024)\n",
|
2073 |
-
"# args['encoder_layers'] = getattr(args, \"encoder_layers\", 12)\n",
|
2074 |
-
"# args['encoder_attention_heads'] = getattr(args, \"encoder_attention_heads\", 16)\n",
|
2075 |
-
"# args['encoder_normalize_before'] = getattr(args, \"encoder_normalize_before\", True)\n",
|
2076 |
-
"# args['encoder_learned_pos'] = getattr(args, \"encoder_learned_pos\", True)\n",
|
2077 |
-
"# args['decoder_embed_path'] = getattr(args, \"decoder_embed_path\", None)\n",
|
2078 |
-
"# args['decoder_embed_dim'] = getattr(args, \"decoder_embed_dim\", args['encoder_embed_dim'])\n",
|
2079 |
-
"# args['decoder_ffn_embed_dim'] = getattr(\n",
|
2080 |
-
"# args, \"decoder_ffn_embed_dim\", args['encoder_ffn_embed_dim']\n",
|
2081 |
-
"# )\n",
|
2082 |
-
"# args['decoder_layers'] = getattr(args, \"decoder_layers\", 12)\n",
|
2083 |
-
"# args['decoder_attention_heads'] = getattr(args, \"decoder_attention_heads\", 16)\n",
|
2084 |
-
"# args['decoder_normalize_before'] = getattr(args, \"decoder_normalize_before\", True)\n",
|
2085 |
-
"# args['decoder_learned_pos'] = getattr(args, \"decoder_learned_pos\", True)\n",
|
2086 |
-
"# args['attention_dropout'] = getattr(args, \"attention_dropout\", 0.0)\n",
|
2087 |
-
"# args['relu_dropout'] = getattr(args, \"relu_dropout\", 0.0)\n",
|
2088 |
-
"# args['dropout'] = getattr(args, \"dropout\", 0.0)\n",
|
2089 |
-
"# args['max_target_positions'] = getattr(args, \"max_target_positions\", 1024)\n",
|
2090 |
-
"# args['max_source_positions'] = getattr(args, \"max_source_positions\", 1024)\n",
|
2091 |
-
"# args['adaptive_softmax_cutoff'] = getattr(args, \"adaptive_softmax_cutoff\", None)\n",
|
2092 |
-
"# args['adaptive_softmax_dropout'] = getattr(args, \"adaptive_softmax_dropout\", 0)\n",
|
2093 |
-
"# args['share_decoder_input_output_embed'] = getattr(\n",
|
2094 |
-
"# args, \"share_decoder_input_output_embed\", True\n",
|
2095 |
-
"# )\n",
|
2096 |
-
"# args['share_all_embeddings'] = getattr(args, \"share_all_embeddings\", True)\n",
|
2097 |
-
"\n",
|
2098 |
-
"# args['decoder_output_dim'] = getattr(\n",
|
2099 |
-
"# args, \"decoder_output_dim\", args['decoder_embed_dim']\n",
|
2100 |
-
"# )\n",
|
2101 |
-
"# args['decoder_input_dim'] = getattr(args, \"decoder_input_dim\", args['decoder_embed_dim'])\n",
|
2102 |
-
"\n",
|
2103 |
-
"# args['no_scale_embedding'] = getattr(args, \"no_scale_embedding\", True)\n",
|
2104 |
-
"# args['layernorm_embedding'] = getattr(args, \"layernorm_embedding\", True)\n",
|
2105 |
-
"\n",
|
2106 |
-
"# args['activation_fn'] = getattr(args, \"activation_fn\", \"gelu\")\n",
|
2107 |
-
"# args['pooler_activation_fn'] = getattr(args, \"pooler_activation_fn\", \"tanh\")\n",
|
2108 |
-
"# args['pooler_dropout'] = getattr(args, \"pooler_dropout\", 0.0)\n",
|
2109 |
-
"# args['pooler_classifier'] = getattr(args, \"pooler_classifier\", \"mlp\")\n",
|
2110 |
-
"\n",
|
2111 |
-
"# args['resnet_drop_path_rate'] = getattr(args, \"resnet_drop_path_rate\", 0.0)\n",
|
2112 |
-
"# args['encoder_drop_path_rate'] = getattr(args, \"encoder_drop_path_rate\", 0.0)\n",
|
2113 |
-
"# args['decoder_drop_path_rate'] = getattr(args, \"decoder_drop_path_rate\", 0.0)\n",
|
2114 |
-
"\n",
|
2115 |
-
"# args['resnet_type'] = getattr(args, \"resnet_type\", \"resnet152\")\n",
|
2116 |
-
"# args['token_bucket_size'] = getattr(args, \"token_bucket_size\", 256)\n",
|
2117 |
-
"# args['image_bucket_size'] = getattr(args, \"image_bucket_size\", 42)\n",
|
2118 |
-
"\n",
|
2119 |
-
"# args['freeze_encoder_embedding'] = getattr(args, \"freeze_encoder_embedding\", False)\n",
|
2120 |
-
"# args['freeze_decoder_embedding'] = getattr(args, \"freeze_decoder_embedding\", False)\n",
|
2121 |
-
"# args['add_type_embedding'] = getattr(args, \"add_type_embedding\", True)\n",
|
2122 |
-
"# args['attn_scale_factor'] = getattr(args, \"attn_scale_factor\", 2)\n",
|
2123 |
-
"\n",
|
2124 |
-
"# args['code_image_size'] = getattr(args, \"code_image_size\", 128)\n",
|
2125 |
-
"# args['patch_layernorm_embedding'] = getattr(args, \"patch_layernorm_embedding\", True)\n",
|
2126 |
-
"# args['code_layernorm_embedding'] = getattr(args, \"code_layernorm_embedding\", True)\n",
|
2127 |
-
"# args['entangle_position_embedding'] = getattr(args, \"entangle_position_embedding\", False)\n",
|
2128 |
-
"# args['disable_entangle'] = getattr(args, \"disable_entangle\", False)\n",
|
2129 |
-
"# args['sync_bn'] = getattr(args, \"sync_bn\", False)\n",
|
2130 |
-
"\n",
|
2131 |
-
"# args['scale_attn'] = getattr(args, \"scale_attn\", False)\n",
|
2132 |
-
"# args['scale_fc'] = getattr(args, \"scale_fc\", False)\n",
|
2133 |
-
"# args['scale_heads'] = getattr(args, \"scale_heads\", False)\n",
|
2134 |
-
"# args['scale_resids'] = getattr(args, \"scale_resids\", False)\n",
|
2135 |
-
"\n",
|
2136 |
-
"# args['orig_patch_image_size'] = getattr(args, \"orig_patch_image_size\", 256)\n",
|
2137 |
-
"\n",
|
2138 |
-
"# return args"
|
2139 |
-
]
|
2140 |
-
},
|
2141 |
-
{
|
2142 |
-
"cell_type": "code",
|
2143 |
-
"execution_count": null,
|
2144 |
-
"metadata": {},
|
2145 |
-
"outputs": [],
|
2146 |
-
"source": [
|
2147 |
-
"# args = {}\n",
|
2148 |
-
"# args = get_ofa_args_large(args)\n",
|
2149 |
-
"# args = BatchEncoding(args)\n",
|
2150 |
-
"# ofa_base_architecture(args)\n",
|
2151 |
-
"# data_dir = '/data/mshukor/data/ofa/pretrain_example'\n",
|
2152 |
-
"\n",
|
2153 |
-
"# cfg.task.neg_sample_dir = data_dir+'/negative_sample'"
|
2154 |
-
]
|
2155 |
-
},
|
2156 |
-
{
|
2157 |
-
"cell_type": "markdown",
|
2158 |
-
"metadata": {},
|
2159 |
-
"source": [
|
2160 |
-
"### convert t5 weights"
|
2161 |
-
]
|
2162 |
-
},
|
2163 |
-
{
|
2164 |
-
"cell_type": "code",
|
2165 |
-
"execution_count": 6,
|
2166 |
-
"metadata": {},
|
2167 |
-
"outputs": [],
|
2168 |
-
"source": [
|
2169 |
-
"model_t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")"
|
2170 |
-
]
|
2171 |
-
},
|
2172 |
-
{
|
2173 |
-
"cell_type": "code",
|
2174 |
-
"execution_count": null,
|
2175 |
-
"metadata": {},
|
2176 |
-
"outputs": [],
|
2177 |
-
"source": [
|
2178 |
-
"model_t5"
|
2179 |
-
]
|
2180 |
-
},
|
2181 |
-
{
|
2182 |
-
"cell_type": "code",
|
2183 |
-
"execution_count": 7,
|
2184 |
-
"metadata": {},
|
2185 |
-
"outputs": [],
|
2186 |
-
"source": [
|
2187 |
-
"t5_state = model_t5.state_dict()"
|
2188 |
-
]
|
2189 |
-
},
|
2190 |
-
{
|
2191 |
-
"cell_type": "code",
|
2192 |
-
"execution_count": 56,
|
2193 |
-
"metadata": {},
|
2194 |
-
"outputs": [],
|
2195 |
-
"source": [
|
2196 |
-
"import re\n",
|
2197 |
-
"# line = re.sub(r\"</?\\[\\d+>\", \"\", line)\n",
|
2198 |
-
"\n",
|
2199 |
-
"mapping_dict = {\n",
|
2200 |
-
" ## encoder\n",
|
2201 |
-
" 'block': 'layers', \n",
|
2202 |
-
" 'layer.[0-9]+.SelfAttention': 'self_attn', \n",
|
2203 |
-
" '.q.': '.q_proj.', \n",
|
2204 |
-
" '.k.weight': '.k_proj.weight', \n",
|
2205 |
-
" '.v.': '.v_proj.', \n",
|
2206 |
-
" # '.o.weight': '.out_proj.weight', \n",
|
2207 |
-
" 'layer.0.layer_norm.': 'self_attn_layer_norm.', \n",
|
2208 |
-
" 'layer.[0-9]+.DenseReluDense.': '', \n",
|
2209 |
-
" '.wi.': '.fc1.', \n",
|
2210 |
-
" '.wo.': '.fc2.', \n",
|
2211 |
-
" \n",
|
2212 |
-
" \n",
|
2213 |
-
" # decoder\n",
|
2214 |
-
" 'layer.[0-9]+.EncDecAttention': 'encoder_attn', \n",
|
2215 |
-
" # 'layer.1.layer_norm.': 'encoder_attn_layer_norm.', \n",
|
2216 |
-
" \n",
|
2217 |
-
" \n",
|
2218 |
-
"}\n",
|
2219 |
-
"\n",
|
2220 |
-
"encoder_mapping = {\n",
|
2221 |
-
" 'layer.1.layer_norm.': 'final_layer_norm.', \n",
|
2222 |
-
"}\n",
|
2223 |
-
"\n",
|
2224 |
-
"decoder_mapping = {\n",
|
2225 |
-
" 'layer.1.layer_norm.': 'encoder_attn_layer_norm.', \n",
|
2226 |
-
" 'layer.2.layer_norm.': 'final_layer_norm.', \n",
|
2227 |
-
"}\n",
|
2228 |
-
"\n",
|
2229 |
-
"\n",
|
2230 |
-
"simple_replace_mapping = {\n",
|
2231 |
-
" \n",
|
2232 |
-
" '.o.weight': '.out_proj.weight', \n",
|
2233 |
-
"}\n",
|
2234 |
-
"def modify_state(state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping):\n",
|
2235 |
-
" # orig_keys = ['block', 'layer.[0-9]+.SelfAttention', '.q.', '.k.', '.v.', '.o.', '0.layer_norm.', '.DenseReluDense.wi.', '.DenseReluDense.wo.', '.1.layer_norm.']\n",
|
2236 |
-
" # new_keys = ['layers', 'layer.self_attn', '.q_proj.', '.k_proj.', '.v_proj.', '.out_proj.', '.self_attn_layer_norm.', '.fc1.', '.fc2.', '.final_layer_norm.']\n",
|
2237 |
-
" \n",
|
2238 |
-
" new_state = state.copy()\n",
|
2239 |
-
" old_keys = []\n",
|
2240 |
-
" for k, v in state.items():\n",
|
2241 |
-
" \n",
|
2242 |
-
" new_key = '%s' % k \n",
|
2243 |
-
" for old, new in simple_replace_mapping.items():\n",
|
2244 |
-
" new_key = new_key.replace(old, new)\n",
|
2245 |
-
" \n",
|
2246 |
-
" for old, new in mapping_dict.items():\n",
|
2247 |
-
" new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
|
2248 |
-
" \n",
|
2249 |
-
" if 'encoder' in new_key:\n",
|
2250 |
-
" for old, new in encoder_mapping.items():\n",
|
2251 |
-
" new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
|
2252 |
-
" \n",
|
2253 |
-
" if 'decoder' in new_key:\n",
|
2254 |
-
" for old, new in decoder_mapping.items():\n",
|
2255 |
-
" new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
|
2256 |
-
" \n",
|
2257 |
-
" new_state[new_key] = v\n",
|
2258 |
-
" old_keys.append(k)\n",
|
2259 |
-
" \n",
|
2260 |
-
" \n",
|
2261 |
-
" \n",
|
2262 |
-
" \n",
|
2263 |
-
" for k in old_keys:\n",
|
2264 |
-
" del new_state[k]\n",
|
2265 |
-
" \n",
|
2266 |
-
" final_state = {}\n",
|
2267 |
-
" final_state['model'] = new_state\n",
|
2268 |
-
" return final_state\n",
|
2269 |
-
" \n",
|
2270 |
-
"new_state = modify_state(t5_state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping)\n",
|
2271 |
-
"\n"
|
2272 |
-
]
|
2273 |
-
},
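The cell above renames T5 state-dict keys into the fairseq/OFA naming scheme by applying regex substitutions in sequence. As a quick illustration of the idea, the snippet below applies a small subset of the same substitution patterns to two example T5-style keys; the patterns mirror `mapping_dict` above, while the keys themselves are only illustrative.

```python
import re

# a few of the substitutions used above, applied in order
patterns = [
    (r"block", "layers"),
    (r"layer\.[0-9]+\.SelfAttention", "self_attn"),
    (r"\.q\.", ".q_proj."),
    (r"layer\.0\.layer_norm\.", "self_attn_layer_norm."),
]

for key in ["encoder.block.0.layer.0.SelfAttention.q.weight",
            "encoder.block.0.layer.0.layer_norm.weight"]:
    new_key = key
    for old, new in patterns:
        new_key = re.sub(old, new, new_key)
    print(key, "->", new_key)

# encoder.block.0.layer.0.SelfAttention.q.weight -> encoder.layers.0.self_attn.q_proj.weight
# encoder.block.0.layer.0.layer_norm.weight      -> encoder.layers.0.self_attn_layer_norm.weight
```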
|
2274 |
-
{
|
2275 |
-
"cell_type": "code",
|
2276 |
-
"execution_count": null,
|
2277 |
-
"metadata": {},
|
2278 |
-
"outputs": [],
|
2279 |
-
"source": [
|
2280 |
-
"new_state['model'].keys()"
|
2281 |
-
]
|
2282 |
-
},
|
2283 |
-
{
|
2284 |
-
"cell_type": "code",
|
2285 |
-
"execution_count": null,
|
2286 |
-
"metadata": {},
|
2287 |
-
"outputs": [],
|
2288 |
-
"source": [
|
2289 |
-
"def compare_states(state1, state2):\n",
|
2290 |
-
" different = []\n",
|
2291 |
-
" for k1 in state1.keys():\n",
|
2292 |
-
" if k1 not in state2:\n",
|
2293 |
-
" different.append(k1)\n",
|
2294 |
-
" return different\n",
|
2295 |
-
" \n",
|
2296 |
-
"tmp = compare_states(new_state, ofa_state)"
|
2297 |
-
]
|
2298 |
-
},
|
2299 |
-
{
|
2300 |
-
"cell_type": "code",
|
2301 |
-
"execution_count": 35,
|
2302 |
-
"metadata": {},
|
2303 |
-
"outputs": [],
|
2304 |
-
"source": [
|
2305 |
-
"output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
|
2306 |
-
"torch.save(new_state, output_path)"
|
2307 |
-
]
|
2308 |
-
},
|
2309 |
-
{
|
2310 |
-
"cell_type": "code",
|
2311 |
-
"execution_count": 51,
|
2312 |
-
"metadata": {},
|
2313 |
-
"outputs": [],
|
2314 |
-
"source": [
|
2315 |
-
"output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
|
2316 |
-
"\n",
|
2317 |
-
"tmp_state = torch.load(output_path)"
|
2318 |
-
]
|
2319 |
-
},
|
2320 |
-
{
|
2321 |
-
"cell_type": "code",
|
2322 |
-
"execution_count": null,
|
2323 |
-
"metadata": {},
|
2324 |
-
"outputs": [],
|
2325 |
-
"source": [
|
2326 |
-
"\n",
|
2327 |
-
"model_ofa.load_state_dict(tmp_state['model'], strict=False)"
|
2328 |
-
]
|
2329 |
-
},
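With `strict=False`, `load_state_dict` does not raise on mismatches; it returns the keys it could not match, which is a cheap sanity check after a conversion like the one above. A self-contained illustration on a toy module (not the OFA model) follows.

```python
import torch
import torch.nn as nn

toy = nn.Linear(4, 2)

# pretend this came from a converted checkpoint: one matching key, one stray key
state = {"weight": torch.zeros(2, 4), "unexpected.weight": torch.zeros(1)}

result = toy.load_state_dict(state, strict=False)
print("missing keys:   ", result.missing_keys)     # ['bias']
print("unexpected keys:", result.unexpected_keys)  # ['unexpected.weight']
```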
|
2330 |
-
{
|
2331 |
-
"cell_type": "code",
|
2332 |
-
"execution_count": null,
|
2333 |
-
"metadata": {},
|
2334 |
-
"outputs": [],
|
2335 |
-
"source": [
|
2336 |
-
"tmp_state['model'].keys()"
|
2337 |
-
]
|
2338 |
-
},
|
2339 |
-
{
|
2340 |
-
"cell_type": "code",
|
2341 |
-
"execution_count": 18,
|
2342 |
-
"metadata": {},
|
2343 |
-
"outputs": [],
|
2344 |
-
"source": [
|
2345 |
-
"tmp_state = torch.load('/data/mshukor/logs/ofa/pretrained_models/ofa_base.pt')\n"
|
2346 |
-
]
|
2347 |
-
},
|
2348 |
-
{
|
2349 |
-
"cell_type": "code",
|
2350 |
-
"execution_count": 19,
|
2351 |
-
"metadata": {},
|
2352 |
-
"outputs": [
|
2353 |
-
{
|
2354 |
-
"data": {
|
2355 |
-
"text/plain": [
|
2356 |
-
"dict_keys(['args', 'cfg', 'model', 'criterion', 'optimizer_history', 'task_state', 'extra_state', 'last_optimizer_state'])"
|
2357 |
-
]
|
2358 |
-
},
|
2359 |
-
"execution_count": 19,
|
2360 |
-
"metadata": {},
|
2361 |
-
"output_type": "execute_result"
|
2362 |
-
}
|
2363 |
-
],
|
2364 |
-
"source": [
|
2365 |
-
"tmp_state.keys()"
|
2366 |
-
]
|
2367 |
-
},
|
2368 |
-
{
|
2369 |
-
"cell_type": "code",
|
2370 |
-
"execution_count": null,
|
2371 |
-
"metadata": {},
|
2372 |
-
"outputs": [],
|
2373 |
-
"source": []
|
2374 |
-
},
|
2375 |
-
{
|
2376 |
-
"cell_type": "code",
|
2377 |
-
"execution_count": null,
|
2378 |
-
"metadata": {},
|
2379 |
-
"outputs": [],
|
2380 |
-
"source": [
|
2381 |
-
"model_t5.encoder.block[0]"
|
2382 |
-
]
|
2383 |
-
},
|
2384 |
-
{
|
2385 |
-
"cell_type": "code",
|
2386 |
-
"execution_count": null,
|
2387 |
-
"metadata": {},
|
2388 |
-
"outputs": [],
|
2389 |
-
"source": [
|
2390 |
-
"model_ofa.encoder.layers[0]"
|
2391 |
-
]
|
2392 |
-
},
|
2393 |
-
{
|
2394 |
-
"cell_type": "markdown",
|
2395 |
-
"metadata": {},
|
2396 |
-
"source": [
|
2397 |
-
"### convert BART weights"
|
2398 |
-
]
|
2399 |
-
},
|
2400 |
-
{
|
2401 |
-
"cell_type": "code",
|
2402 |
-
"execution_count": 7,
|
2403 |
-
"metadata": {},
|
2404 |
-
"outputs": [],
|
2405 |
-
"source": [
|
2406 |
-
"weights_path = '/data/mshukor/logs/ofa/pretrained_models/bart.base/model.pt'\n",
|
2407 |
-
"bart_state = torch.load(weights_path, map_location=torch.device('cpu'))"
|
2408 |
-
]
|
2409 |
-
},
|
2410 |
-
{
|
2411 |
-
"cell_type": "code",
|
2412 |
-
"execution_count": 13,
|
2413 |
-
"metadata": {},
|
2414 |
-
"outputs": [
|
2415 |
-
{
|
2416 |
-
"data": {
|
2417 |
-
"text/plain": [
|
2418 |
-
"<All keys matched successfully>"
|
2419 |
-
]
|
2420 |
-
},
|
2421 |
-
"execution_count": 13,
|
2422 |
-
"metadata": {},
|
2423 |
-
"output_type": "execute_result"
|
2424 |
-
}
|
2425 |
-
],
|
2426 |
-
"source": [
|
2427 |
-
"model_ofa.load_state_dict(bart_state['model'], strict=True)"
|
2428 |
-
]
|
2429 |
-
},
|
2430 |
-
{
|
2431 |
-
"cell_type": "code",
|
2432 |
-
"execution_count": 9,
|
2433 |
-
"metadata": {},
|
2434 |
-
"outputs": [
|
2435 |
-
{
|
2436 |
-
"data": {
|
2437 |
-
"text/plain": [
|
2438 |
-
"odict_keys(['encoder.version', 'encoder.embed_tokens.weight', 'encoder.embed_positions.weight', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.v_proj.bias', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.q_proj.bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.self_attn_layer_norm.bias', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc1.bias', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.fc2.bias', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.0.final_layer_norm.bias', 'encoder.layers.1.self_attn.k_proj.weight', 'encoder.layers.1.self_attn.k_proj.bias', 'encoder.layers.1.self_attn.v_proj.weight', 'encoder.layers.1.self_attn.v_proj.bias', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.q_proj.bias', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn.out_proj.bias', 'encoder.layers.1.self_attn_layer_norm.weight', 'encoder.layers.1.self_attn_layer_norm.bias', 'encoder.layers.1.fc1.weight', 'encoder.layers.1.fc1.bias', 'encoder.layers.1.fc2.weight', 'encoder.layers.1.fc2.bias', 'encoder.layers.1.final_layer_norm.weight', 'encoder.layers.1.final_layer_norm.bias', 'encoder.layers.2.self_attn.k_proj.weight', 'encoder.layers.2.self_attn.k_proj.bias', 'encoder.layers.2.self_attn.v_proj.weight', 'encoder.layers.2.self_attn.v_proj.bias', 'encoder.layers.2.self_attn.q_proj.weight', 'encoder.layers.2.self_attn.q_proj.bias', 'encoder.layers.2.self_attn.out_proj.weight', 'encoder.layers.2.self_attn.out_proj.bias', 'encoder.layers.2.self_attn_layer_norm.weight', 'encoder.layers.2.self_attn_layer_norm.bias', 'encoder.layers.2.fc1.weight', 'encoder.layers.2.fc1.bias', 'encoder.layers.2.fc2.weight', 'encoder.layers.2.fc2.bias', 'encoder.layers.2.final_layer_norm.weight', 'encoder.layers.2.final_layer_norm.bias', 'encoder.layers.3.self_attn.k_proj.weight', 'encoder.layers.3.self_attn.k_proj.bias', 'encoder.layers.3.self_attn.v_proj.weight', 'encoder.layers.3.self_attn.v_proj.bias', 'encoder.layers.3.self_attn.q_proj.weight', 'encoder.layers.3.self_attn.q_proj.bias', 'encoder.layers.3.self_attn.out_proj.weight', 'encoder.layers.3.self_attn.out_proj.bias', 'encoder.layers.3.self_attn_layer_norm.weight', 'encoder.layers.3.self_attn_layer_norm.bias', 'encoder.layers.3.fc1.weight', 'encoder.layers.3.fc1.bias', 'encoder.layers.3.fc2.weight', 'encoder.layers.3.fc2.bias', 'encoder.layers.3.final_layer_norm.weight', 'encoder.layers.3.final_layer_norm.bias', 'encoder.layers.4.self_attn.k_proj.weight', 'encoder.layers.4.self_attn.k_proj.bias', 'encoder.layers.4.self_attn.v_proj.weight', 'encoder.layers.4.self_attn.v_proj.bias', 'encoder.layers.4.self_attn.q_proj.weight', 'encoder.layers.4.self_attn.q_proj.bias', 'encoder.layers.4.self_attn.out_proj.weight', 'encoder.layers.4.self_attn.out_proj.bias', 'encoder.layers.4.self_attn_layer_norm.weight', 'encoder.layers.4.self_attn_layer_norm.bias', 'encoder.layers.4.fc1.weight', 'encoder.layers.4.fc1.bias', 'encoder.layers.4.fc2.weight', 'encoder.layers.4.fc2.bias', 'encoder.layers.4.final_layer_norm.weight', 'encoder.layers.4.final_layer_norm.bias', 'encoder.layers.5.self_attn.k_proj.weight', 'encoder.layers.5.self_attn.k_proj.bias', 'encoder.layers.5.self_attn.v_proj.weight', 'encoder.layers.5.self_attn.v_proj.bias', 'encoder.layers.5.self_attn.q_proj.weight', 
'encoder.layers.5.self_attn.q_proj.bias', 'encoder.layers.5.self_attn.out_proj.weight', 'encoder.layers.5.self_attn.out_proj.bias', 'encoder.layers.5.self_attn_layer_norm.weight', 'encoder.layers.5.self_attn_layer_norm.bias', 'encoder.layers.5.fc1.weight', 'encoder.layers.5.fc1.bias', 'encoder.layers.5.fc2.weight', 'encoder.layers.5.fc2.bias', 'encoder.layers.5.final_layer_norm.weight', 'encoder.layers.5.final_layer_norm.bias', 'encoder.layernorm_embedding.weight', 'encoder.layernorm_embedding.bias', 'decoder.version', 'decoder.embed_tokens.weight', 'decoder.embed_positions.weight', 'decoder.layers.0.self_attn.k_proj.weight', 'decoder.layers.0.self_attn.k_proj.bias', 'decoder.layers.0.self_attn.v_proj.weight', 'decoder.layers.0.self_attn.v_proj.bias', 'decoder.layers.0.self_attn.q_proj.weight', 'decoder.layers.0.self_attn.q_proj.bias', 'decoder.layers.0.self_attn.out_proj.weight', 'decoder.layers.0.self_attn.out_proj.bias', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.0.self_attn_layer_norm.bias', 'decoder.layers.0.encoder_attn.k_proj.weight', 'decoder.layers.0.encoder_attn.k_proj.bias', 'decoder.layers.0.encoder_attn.v_proj.weight', 'decoder.layers.0.encoder_attn.v_proj.bias', 'decoder.layers.0.encoder_attn.q_proj.weight', 'decoder.layers.0.encoder_attn.q_proj.bias', 'decoder.layers.0.encoder_attn.out_proj.weight', 'decoder.layers.0.encoder_attn.out_proj.bias', 'decoder.layers.0.encoder_attn_layer_norm.weight', 'decoder.layers.0.encoder_attn_layer_norm.bias', 'decoder.layers.0.fc1.weight', 'decoder.layers.0.fc1.bias', 'decoder.layers.0.fc2.weight', 'decoder.layers.0.fc2.bias', 'decoder.layers.0.final_layer_norm.weight', 'decoder.layers.0.final_layer_norm.bias', 'decoder.layers.1.self_attn.k_proj.weight', 'decoder.layers.1.self_attn.k_proj.bias', 'decoder.layers.1.self_attn.v_proj.weight', 'decoder.layers.1.self_attn.v_proj.bias', 'decoder.layers.1.self_attn.q_proj.weight', 'decoder.layers.1.self_attn.q_proj.bias', 'decoder.layers.1.self_attn.out_proj.weight', 'decoder.layers.1.self_attn.out_proj.bias', 'decoder.layers.1.self_attn_layer_norm.weight', 'decoder.layers.1.self_attn_layer_norm.bias', 'decoder.layers.1.encoder_attn.k_proj.weight', 'decoder.layers.1.encoder_attn.k_proj.bias', 'decoder.layers.1.encoder_attn.v_proj.weight', 'decoder.layers.1.encoder_attn.v_proj.bias', 'decoder.layers.1.encoder_attn.q_proj.weight', 'decoder.layers.1.encoder_attn.q_proj.bias', 'decoder.layers.1.encoder_attn.out_proj.weight', 'decoder.layers.1.encoder_attn.out_proj.bias', 'decoder.layers.1.encoder_attn_layer_norm.weight', 'decoder.layers.1.encoder_attn_layer_norm.bias', 'decoder.layers.1.fc1.weight', 'decoder.layers.1.fc1.bias', 'decoder.layers.1.fc2.weight', 'decoder.layers.1.fc2.bias', 'decoder.layers.1.final_layer_norm.weight', 'decoder.layers.1.final_layer_norm.bias', 'decoder.layers.2.self_attn.k_proj.weight', 'decoder.layers.2.self_attn.k_proj.bias', 'decoder.layers.2.self_attn.v_proj.weight', 'decoder.layers.2.self_attn.v_proj.bias', 'decoder.layers.2.self_attn.q_proj.weight', 'decoder.layers.2.self_attn.q_proj.bias', 'decoder.layers.2.self_attn.out_proj.weight', 'decoder.layers.2.self_attn.out_proj.bias', 'decoder.layers.2.self_attn_layer_norm.weight', 'decoder.layers.2.self_attn_layer_norm.bias', 'decoder.layers.2.encoder_attn.k_proj.weight', 'decoder.layers.2.encoder_attn.k_proj.bias', 'decoder.layers.2.encoder_attn.v_proj.weight', 'decoder.layers.2.encoder_attn.v_proj.bias', 'decoder.layers.2.encoder_attn.q_proj.weight', 'decoder.layers.2.encoder_attn.q_proj.bias', 
'decoder.layers.2.encoder_attn.out_proj.weight', 'decoder.layers.2.encoder_attn.out_proj.bias', 'decoder.layers.2.encoder_attn_layer_norm.weight', 'decoder.layers.2.encoder_attn_layer_norm.bias', 'decoder.layers.2.fc1.weight', 'decoder.layers.2.fc1.bias', 'decoder.layers.2.fc2.weight', 'decoder.layers.2.fc2.bias', 'decoder.layers.2.final_layer_norm.weight', 'decoder.layers.2.final_layer_norm.bias', 'decoder.layers.3.self_attn.k_proj.weight', 'decoder.layers.3.self_attn.k_proj.bias', 'decoder.layers.3.self_attn.v_proj.weight', 'decoder.layers.3.self_attn.v_proj.bias', 'decoder.layers.3.self_attn.q_proj.weight', 'decoder.layers.3.self_attn.q_proj.bias', 'decoder.layers.3.self_attn.out_proj.weight', 'decoder.layers.3.self_attn.out_proj.bias', 'decoder.layers.3.self_attn_layer_norm.weight', 'decoder.layers.3.self_attn_layer_norm.bias', 'decoder.layers.3.encoder_attn.k_proj.weight', 'decoder.layers.3.encoder_attn.k_proj.bias', 'decoder.layers.3.encoder_attn.v_proj.weight', 'decoder.layers.3.encoder_attn.v_proj.bias', 'decoder.layers.3.encoder_attn.q_proj.weight', 'decoder.layers.3.encoder_attn.q_proj.bias', 'decoder.layers.3.encoder_attn.out_proj.weight', 'decoder.layers.3.encoder_attn.out_proj.bias', 'decoder.layers.3.encoder_attn_layer_norm.weight', 'decoder.layers.3.encoder_attn_layer_norm.bias', 'decoder.layers.3.fc1.weight', 'decoder.layers.3.fc1.bias', 'decoder.layers.3.fc2.weight', 'decoder.layers.3.fc2.bias', 'decoder.layers.3.final_layer_norm.weight', 'decoder.layers.3.final_layer_norm.bias', 'decoder.layers.4.self_attn.k_proj.weight', 'decoder.layers.4.self_attn.k_proj.bias', 'decoder.layers.4.self_attn.v_proj.weight', 'decoder.layers.4.self_attn.v_proj.bias', 'decoder.layers.4.self_attn.q_proj.weight', 'decoder.layers.4.self_attn.q_proj.bias', 'decoder.layers.4.self_attn.out_proj.weight', 'decoder.layers.4.self_attn.out_proj.bias', 'decoder.layers.4.self_attn_layer_norm.weight', 'decoder.layers.4.self_attn_layer_norm.bias', 'decoder.layers.4.encoder_attn.k_proj.weight', 'decoder.layers.4.encoder_attn.k_proj.bias', 'decoder.layers.4.encoder_attn.v_proj.weight', 'decoder.layers.4.encoder_attn.v_proj.bias', 'decoder.layers.4.encoder_attn.q_proj.weight', 'decoder.layers.4.encoder_attn.q_proj.bias', 'decoder.layers.4.encoder_attn.out_proj.weight', 'decoder.layers.4.encoder_attn.out_proj.bias', 'decoder.layers.4.encoder_attn_layer_norm.weight', 'decoder.layers.4.encoder_attn_layer_norm.bias', 'decoder.layers.4.fc1.weight', 'decoder.layers.4.fc1.bias', 'decoder.layers.4.fc2.weight', 'decoder.layers.4.fc2.bias', 'decoder.layers.4.final_layer_norm.weight', 'decoder.layers.4.final_layer_norm.bias', 'decoder.layers.5.self_attn.k_proj.weight', 'decoder.layers.5.self_attn.k_proj.bias', 'decoder.layers.5.self_attn.v_proj.weight', 'decoder.layers.5.self_attn.v_proj.bias', 'decoder.layers.5.self_attn.q_proj.weight', 'decoder.layers.5.self_attn.q_proj.bias', 'decoder.layers.5.self_attn.out_proj.weight', 'decoder.layers.5.self_attn.out_proj.bias', 'decoder.layers.5.self_attn_layer_norm.weight', 'decoder.layers.5.self_attn_layer_norm.bias', 'decoder.layers.5.encoder_attn.k_proj.weight', 'decoder.layers.5.encoder_attn.k_proj.bias', 'decoder.layers.5.encoder_attn.v_proj.weight', 'decoder.layers.5.encoder_attn.v_proj.bias', 'decoder.layers.5.encoder_attn.q_proj.weight', 'decoder.layers.5.encoder_attn.q_proj.bias', 'decoder.layers.5.encoder_attn.out_proj.weight', 'decoder.layers.5.encoder_attn.out_proj.bias', 'decoder.layers.5.encoder_attn_layer_norm.weight', 
'decoder.layers.5.encoder_attn_layer_norm.bias', 'decoder.layers.5.fc1.weight', 'decoder.layers.5.fc1.bias', 'decoder.layers.5.fc2.weight', 'decoder.layers.5.fc2.bias', 'decoder.layers.5.final_layer_norm.weight', 'decoder.layers.5.final_layer_norm.bias', 'decoder.layernorm_embedding.weight', 'decoder.layernorm_embedding.bias'])"
|
2439 |
-
]
|
2440 |
-
},
|
2441 |
-
"execution_count": 9,
|
2442 |
-
"metadata": {},
|
2443 |
-
"output_type": "execute_result"
|
2444 |
-
}
|
2445 |
-
],
|
2446 |
-
"source": [
|
2447 |
-
"bart_state['model'].keys()"
|
2448 |
-
]
|
2449 |
-
},
|
2450 |
-
{
|
2451 |
-
"cell_type": "code",
|
2452 |
-
"execution_count": 12,
|
2453 |
-
"metadata": {},
|
2454 |
-
"outputs": [
|
2455 |
-
{
|
2456 |
-
"data": {
|
2457 |
-
"text/plain": [
|
2458 |
-
"tensor([[ 0.0125, 0.0014, -0.0096, ..., 0.0022, 0.1057, 0.0103],\n",
|
2459 |
-
" [-0.0114, -0.0169, -0.0184, ..., -0.0131, -0.0043, -0.0053],\n",
|
2460 |
-
" [ 0.0842, -0.0389, 0.0096, ..., 0.0583, 0.0082, 0.0357],\n",
|
2461 |
-
" ...,\n",
|
2462 |
-
" [-0.0032, -0.0313, -0.1026, ..., 0.0138, 0.0056, -0.0023],\n",
|
2463 |
-
" [ 0.0104, -0.0045, 0.0263, ..., 0.0158, 0.0324, -0.0111],\n",
|
2464 |
-
" [-0.0038, -0.0532, -0.0147, ..., 0.0067, 0.0256, 0.0009]])"
|
2465 |
-
]
|
2466 |
-
},
|
2467 |
-
"execution_count": 12,
|
2468 |
-
"metadata": {},
|
2469 |
-
"output_type": "execute_result"
|
2470 |
-
}
|
2471 |
-
],
|
2472 |
-
"source": [
|
2473 |
-
"ofa_state.keys()\n",
|
2474 |
-
"ofa_state['encoder.embed_tokens.weight']"
|
2475 |
-
]
|
2476 |
-
}
|
2477 |
-
],
|
2478 |
-
"metadata": {
|
2479 |
-
"kernelspec": {
|
2480 |
-
"display_name": "ofa",
|
2481 |
-
"language": "python",
|
2482 |
-
"name": "ofa"
|
2483 |
-
},
|
2484 |
-
"language_info": {
|
2485 |
-
"codemirror_mode": {
|
2486 |
-
"name": "ipython",
|
2487 |
-
"version": 3
|
2488 |
-
},
|
2489 |
-
"file_extension": ".py",
|
2490 |
-
"mimetype": "text/x-python",
|
2491 |
-
"name": "python",
|
2492 |
-
"nbconvert_exporter": "python",
|
2493 |
-
"pygments_lexer": "ipython3",
|
2494 |
-
"version": "3.7.4"
|
2495 |
-
}
|
2496 |
-
},
|
2497 |
-
"nbformat": 4,
|
2498 |
-
"nbformat_minor": 4
|
2499 |
-
}
|
prompt_tuning.md
DELETED
@@ -1,66 +0,0 @@
|
|
1 |
-
<!---
|
2 |
-
Copyright 2022 The OFA-Sys Team.
|
3 |
-
All rights reserved.
|
4 |
-
This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory.
|
5 |
-
-->
|
6 |
-
|
7 |
-
## Prompt Tuning for Generative Multimodal Pretrained Models
|
8 |
-
|
9 |
-
### Overview
|
10 |
-
This is the code for **"Prompt Tuning for Generative Multimodal Pretrained Models"** ([check our paper on arXiv](https://arxiv.org/abs/2208.02532)). The paper explores prompt tuning for generative multimodal pretrained models rather than for contrastive learning models. We specifically focus on the unified sequence-to-sequence learning framework and implement it on our OFA models.
|
11 |
-
<br>
|
12 |
-
|
13 |
-
### Requirements
|
14 |
-
* python 3.7.4
|
15 |
-
* pytorch 1.8.1
|
16 |
-
* torchvision 0.9.1
|
17 |
-
* JAVA 1.8 (for COCO evaluation)
|
18 |
-
<br></br>
|
19 |
-
|
20 |
-
### Installation
|
21 |
-
```bash
|
22 |
-
pip install -r requirements.txt
|
23 |
-
```
|
24 |
-
<br>
|
25 |
-
|
26 |
-
### Datasets and Checkpoints
|
27 |
-
See [datasets.md](datasets.md) and [checkpoints.md](checkpoints.md).
|
28 |
-
<br>
|
29 |
-
|
30 |
-
### Training
|
31 |
-
We provide a demo script (`run_scripts/refcoco/train_refcoco_prefix.sh`) that has all the required parts for training.
|
32 |
-
|
33 |
-
```sh
|
34 |
-
sh ./run_scripts/refcoco/train_refcoco_prefix.sh
|
35 |
-
```
|
36 |
-
A few options of note:
|
37 |
-
* `--encoder-prompt` :: whether to insert prompts to the encoder
|
38 |
-
* `--decoder-prompt` :: whether to insert prompts to the decoder
|
39 |
-
* `--encoder-prompt-length` :: encoder prompt length
|
40 |
-
* `--decoder-prompt-length` :: decoder prompt length
|
41 |
-
* `--bitfit` :: whether to use bitfit
|
42 |
-
* `--adapter` :: whether to use adapter
|
43 |
-
* `--adapter-dim` :: adapter projection dim
|
44 |
-
|
45 |
-
We recommend that your workspace directory should be organized like this:
|
46 |
-
```
|
47 |
-
OFA/
|
48 |
-
├── checkpoints/
|
49 |
-
│ ├── ofa_base.pt
|
50 |
-
│ ├── ofa_large.pt
|
51 |
-
│ └── ...
|
52 |
-
├── criterions/
|
53 |
-
├── data/
|
54 |
-
├── dataset/
|
55 |
-
│ ├── caption_data/
|
56 |
-
│ ├── refcoco_data/
|
57 |
-
│ └── ...
|
58 |
-
├── fairseq/
|
59 |
-
├── models/
|
60 |
-
├── run_scripts/
|
61 |
-
├── tasks/
|
62 |
-
├── train.py
|
63 |
-
├── trainer.py
|
64 |
-
└── utils/
|
65 |
-
```
|
66 |
-
<br>
|
|
|
|
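As a rough illustration of the Training options listed above, the fragment below shows how the prompt-tuning flags might be added to the `train.py` invocation inside a run script; the prompt lengths are arbitrary example values, and the data, task and model arguments that `run_scripts/refcoco/train_refcoco_prefix.sh` already supplies are omitted.

```sh
# illustrative fragment only; not a complete command
python3 train.py \
    --encoder-prompt --encoder-prompt-length=64 \
    --decoder-prompt --decoder-prompt-length=64
```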
|
spaces.md
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
# Spaces
|
2 |
-
To provide a better experience, we plan to build demos for our OFA models on Hugging Face Spaces. Below we provide links to the demos. Have fun!
|
3 |
-
|
4 |
-
* Generic Interface: [OFA-Generic_Interface](https://huggingface.co/spaces/OFA-Sys/OFA-Generic_Interface)
|
5 |
-
* Text-to-Image Generation: [OFA-Text2Image_Generation](https://huggingface.co/spaces/OFA-Sys/OFA-Text2Image_Generation)
|
6 |
-
* Image Captioning: [OFA-Image_Caption](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)
|
7 |
-
* Referring Expression Comprehension: [OFA-Visual_Grounding](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Grounding)
|
8 |
-
* Visual Question Answering: [OFA-Visual_Question_Answering](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Question_Answering)
|
test.py
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG
|
2 |
-
import soundfile as sf
|
3 |
-
import io
|
4 |
-
import torch
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
-
AUDIO_CFG = {
|
8 |
-
"audio_length": 1024,
|
9 |
-
"clip_samples": 480000,
|
10 |
-
"mel_bins": 64,
|
11 |
-
"sample_rate": 48000,
|
12 |
-
"window_size": 1024,
|
13 |
-
"hop_size": 480,
|
14 |
-
"fmin": 50,
|
15 |
-
"fmax": 14000,
|
16 |
-
"class_num": 527,
|
17 |
-
}
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
audio_cfg = AUDIO_CFG
|
22 |
-
max_len = 480000
|
23 |
-
data_path = '/work/NAT/gda2204/mshukor/data/audiocaps/train/--CHY2qO5zc.wav'
|
24 |
-
|
25 |
-
audio_data, orig_sr = sf.read(data_path)
|
26 |
-
# import librosa
|
27 |
-
# audio_data, orig_sr = librosa.load(data_path, sr=48000)
|
28 |
-
|
29 |
-
print(orig_sr)
|
30 |
-
if audio_data.ndim>1:
|
31 |
-
audio_data = np.mean(audio_data,axis=1)
|
32 |
-
|
33 |
-
|
34 |
-
print(audio_data.shape, audio_data)
|
35 |
-
|
36 |
-
audio_data = int16_to_float32(float32_to_int16(audio_data))
|
37 |
-
audio_data = torch.tensor(audio_data).float()
|
38 |
-
print(audio_data.dtype)
|
39 |
-
print(audio_data.shape, audio_data)
|
40 |
-
# the 'fusion' truncate mode can be changed to 'rand_trunc' when running in non-fusion mode
|
41 |
-
sample = {}
|
42 |
-
|
43 |
-
sample = get_audio_features(
|
44 |
-
sample, audio_data, max_len,
|
45 |
-
data_truncating='fusion',
|
46 |
-
data_filling='repeatpad',
|
47 |
-
audio_cfg=audio_cfg,
|
48 |
-
)
|
49 |
-
|
50 |
-
patch_audio = sample['waveform'] #.half()
|
51 |
-
print(patch_audio.shape, patch_audio.min(), patch_audio.max(), patch_audio)
|
52 |
-
|
53 |
-
patch_audio = torch.zeros(480000)
|
54 |
-
print(patch_audio.shape)
|
55 |
-
|
56 |
-
|
57 |
-
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
58 |
-
|
59 |
-
AUDIO_CFG = {
|
60 |
-
"sample_rate": 48000,
|
61 |
-
"audio_length": 1024,
|
62 |
-
"clip_samples": 480000,
|
63 |
-
"mel_bins": 64,
|
65 |
-
"window_size": 1024,
|
66 |
-
"hop_size": 480,
|
67 |
-
"fmin": 50,
|
68 |
-
"fmax": 14000,
|
69 |
-
"class_num": 527,
|
70 |
-
}
|
71 |
-
|
72 |
-
window = 'hann'
|
73 |
-
center = True
|
74 |
-
pad_mode = 'reflect'
|
75 |
-
ref = 1.0
|
76 |
-
amin = 1e-10
|
77 |
-
top_db = None
|
78 |
-
|
79 |
-
spectrogram_extractor = Spectrogram(n_fft=AUDIO_CFG['window_size'], hop_length=AUDIO_CFG['hop_size'],
|
80 |
-
win_length=AUDIO_CFG['window_size'], window=window, center=center, pad_mode=pad_mode,
|
81 |
-
freeze_parameters=True)
|
82 |
-
|
83 |
-
|
84 |
-
logmel_extractor = LogmelFilterBank(sr=AUDIO_CFG['sample_rate'], n_fft=AUDIO_CFG['window_size'],
|
85 |
-
n_mels=AUDIO_CFG['mel_bins'], fmin=AUDIO_CFG['fmin'], fmax=AUDIO_CFG['fmax'],
|
86 |
-
ref=ref, amin=amin, top_db=top_db,
|
87 |
-
freeze_parameters=True)#.half()
|
88 |
-
|
89 |
-
|
90 |
-
patch_audio = patch_audio[None, :]
|
91 |
-
print(patch_audio.shape)
|
92 |
-
spectro = spectrogram_extractor(patch_audio)
|
93 |
-
|
94 |
-
print(spectro.shape)
|
95 |
-
print(spectro)
|
96 |
-
|
97 |
-
|
98 |
-
mel = logmel_extractor(spectro)
|
99 |
-
|
100 |
-
print(mel.shape)
|
101 |
-
print(mel)
|
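For the settings used above (480000-sample clip, hop 480, window/n_fft 1024, 64 mel bins), the output shapes can be worked out in advance: with centered framing the STFT gives 480000/480 + 1 = 1001 frames and 1024/2 + 1 = 513 frequency bins, and the log-mel stage reduces the bins to 64. A quick check of that arithmetic, assuming the (batch, 1, frames, bins) layout that the script prints above:

```python
# expected shapes for the config in this script
clip_samples, hop_size, window_size, mel_bins = 480000, 480, 1024, 64

frames = clip_samples // hop_size + 1   # centered STFT -> 1001 frames
freq_bins = window_size // 2 + 1        # one-sided spectrum -> 513 bins

print("spectrogram:", (1, 1, frames, freq_bins))  # (1, 1, 1001, 513)
print("log-mel:    ", (1, 1, frames, mel_bins))   # (1, 1, 1001, 64)
```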
train.py
DELETED
@@ -1,729 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3 -u
|
2 |
-
# Copyright 2022 The OFA-Sys Team.
|
3 |
-
# All rights reserved.
|
4 |
-
# This source code is licensed under the Apache 2.0 license
|
5 |
-
# found in the LICENSE file in the root directory.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Train a new model on one or across multiple GPUs.
|
9 |
-
"""
|
10 |
-
|
11 |
-
import argparse
|
12 |
-
import logging
|
13 |
-
import math
|
14 |
-
import os
|
15 |
-
import sys
|
16 |
-
from typing import Dict, Optional, Any, List, Tuple, Callable
|
17 |
-
|
18 |
-
# We need to setup root logger before importing any fairseq libraries.
|
19 |
-
logging.basicConfig(
|
20 |
-
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
|
21 |
-
datefmt="%Y-%m-%d %H:%M:%S",
|
22 |
-
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
23 |
-
stream=sys.stdout,
|
24 |
-
)
|
25 |
-
logger = logging.getLogger("fairseq_cli.train")
|
26 |
-
|
27 |
-
import numpy as np
|
28 |
-
import torch
|
29 |
-
from fairseq import (
|
30 |
-
# checkpoint_utils,
|
31 |
-
options,
|
32 |
-
quantization_utils,
|
33 |
-
tasks,
|
34 |
-
utils,
|
35 |
-
)
|
36 |
-
from fairseq.data import iterators
|
37 |
-
from fairseq.data.plasma_utils import PlasmaStore
|
38 |
-
from fairseq.dataclass.configs import FairseqConfig
|
39 |
-
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
40 |
-
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
|
41 |
-
from fairseq.file_io import PathManager
|
42 |
-
from fairseq.logging import meters, metrics, progress_bar
|
43 |
-
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
|
44 |
-
# from fairseq.trainer import Trainer
|
45 |
-
from omegaconf import DictConfig, OmegaConf
|
46 |
-
|
47 |
-
from utils import checkpoint_utils
|
48 |
-
from trainer import Trainer
|
49 |
-
|
50 |
-
from utils.utils import print_trainable_params_percentage, setup_for_distributed
|
51 |
-
|
52 |
-
import psutil
|
53 |
-
|
54 |
-
def main(cfg: FairseqConfig) -> None:
|
55 |
-
print(distributed_utils.is_master(cfg.distributed_training))
|
56 |
-
print(cfg.distributed_training)
|
57 |
-
setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))
|
58 |
-
|
59 |
-
if isinstance(cfg, argparse.Namespace):
|
60 |
-
cfg = convert_namespace_to_omegaconf(cfg)
|
61 |
-
|
62 |
-
utils.import_user_module(cfg.common)
|
63 |
-
|
64 |
-
if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
|
65 |
-
# make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
|
66 |
-
logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
|
67 |
-
|
68 |
-
assert (
|
69 |
-
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
|
70 |
-
), "Must specify batch size either with --max-tokens or --batch-size"
|
71 |
-
metrics.reset()
|
72 |
-
|
73 |
-
if cfg.common.log_file is not None:
|
74 |
-
handler = logging.FileHandler(filename=cfg.common.log_file)
|
75 |
-
logger.addHandler(handler)
|
76 |
-
|
77 |
-
np.random.seed(cfg.common.seed)
|
78 |
-
utils.set_torch_seed(cfg.common.seed)
|
79 |
-
|
80 |
-
if distributed_utils.is_master(cfg.distributed_training):
|
81 |
-
print(cfg.checkpoint.save_dir)
|
82 |
-
checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
|
83 |
-
|
84 |
-
|
85 |
-
# Print args
|
86 |
-
logger.info(cfg)
|
87 |
-
|
88 |
-
|
89 |
-
if cfg.checkpoint.write_checkpoints_asynchronously:
|
90 |
-
try:
|
91 |
-
import iopath # noqa: F401
|
92 |
-
except ImportError:
|
93 |
-
logging.exception(
|
94 |
-
"Asynchronous checkpoint writing is specified but iopath is "
|
95 |
-
"not installed: `pip install iopath`"
|
96 |
-
)
|
97 |
-
return
|
98 |
-
|
99 |
-
# Setup task, e.g., translation, language modeling, etc.
|
100 |
-
task = tasks.setup_task(cfg.task)
|
101 |
-
|
102 |
-
assert cfg.criterion, "Please specify criterion to train a model"
|
103 |
-
|
104 |
-
# Build model and criterion
|
105 |
-
if cfg.distributed_training.ddp_backend == "fully_sharded":
|
106 |
-
with fsdp_enable_wrap(cfg.distributed_training):
|
107 |
-
model = fsdp_wrap(task.build_model(cfg.model))
|
108 |
-
else:
|
109 |
-
model = task.build_model(cfg.model)
|
110 |
-
|
111 |
-
# bitfit
|
112 |
-
if cfg.model.bitfit:
|
113 |
-
for name, param in model.named_parameters():
|
114 |
-
if ("layer_norm" in name and "bias" in name) or ("fc" in name and "bias" in name):
|
115 |
-
param.requires_grad = True
|
116 |
-
else:
|
117 |
-
param.requires_grad = False
|
118 |
-
|
119 |
-
criterion = task.build_criterion(cfg.criterion)
|
120 |
-
|
121 |
-
logger.info(model)
|
122 |
-
logger.info("task: {}".format(task.__class__.__name__))
|
123 |
-
logger.info("model: {}".format(model.__class__.__name__))
|
124 |
-
logger.info("criterion: {}".format(criterion.__class__.__name__))
|
125 |
-
logger.info(
|
126 |
-
"num. shared model params: {:,} (num. trained: {:,})".format(
|
127 |
-
sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
|
128 |
-
sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
|
129 |
-
)
|
130 |
-
)
|
131 |
-
|
132 |
-
logger.info(
|
133 |
-
"num. expert model params: {} (num. trained: {})".format(
|
134 |
-
sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
|
135 |
-
sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
|
136 |
-
)
|
137 |
-
)
|
138 |
-
|
139 |
-
# Load valid dataset (we load training data below, based on the latest checkpoint)
|
140 |
-
# We load the valid dataset AFTER building the model
|
141 |
-
# data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
|
142 |
-
if cfg.dataset.combine_valid_subsets:
|
143 |
-
task.load_dataset("valid", combine=True, epoch=1)
|
144 |
-
else:
|
145 |
-
for valid_sub_split in cfg.dataset.valid_subset.split(","):
|
146 |
-
task.load_dataset(valid_sub_split, combine=False, epoch=1)
|
147 |
-
|
148 |
-
# (optionally) Configure quantization
|
149 |
-
if cfg.common.quantization_config_path is not None:
|
150 |
-
quantizer = quantization_utils.Quantizer(
|
151 |
-
config_path=cfg.common.quantization_config_path,
|
152 |
-
max_epoch=cfg.optimization.max_epoch,
|
153 |
-
max_update=cfg.optimization.max_update,
|
154 |
-
)
|
155 |
-
else:
|
156 |
-
quantizer = None
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
# for n, p in model.named_parameters():
|
162 |
-
# if not p.requires_grad:
|
163 |
-
# print(n)
|
164 |
-
|
165 |
-
# Build trainer
|
166 |
-
if cfg.common.model_parallel_size == 1:
|
167 |
-
trainer = Trainer(cfg, task, model, criterion, quantizer)
|
168 |
-
else:
|
169 |
-
trainer = MegatronTrainer(cfg, task, model, criterion)
|
170 |
-
logger.info(
|
171 |
-
"training on {} devices (GPUs/TPUs)".format(
|
172 |
-
cfg.distributed_training.distributed_world_size
|
173 |
-
)
|
174 |
-
)
|
175 |
-
logger.info(
|
176 |
-
"max tokens per device = {} and max sentences per device = {}".format(
|
177 |
-
cfg.dataset.max_tokens,
|
178 |
-
cfg.dataset.batch_size,
|
179 |
-
)
|
180 |
-
)
|
181 |
-
|
182 |
-
|
183 |
-
# Load the latest checkpoint if one is available and restore the
|
184 |
-
# corresponding train iterator
|
185 |
-
strict = getattr(cfg.model, 'strict', True)
|
186 |
-
logger.info('load checkpoint, strict:{}'.format(strict))
|
187 |
-
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
|
188 |
-
cfg.checkpoint,
|
189 |
-
trainer,
|
190 |
-
strict=strict,
|
191 |
-
# don't cache epoch iterators for sharded datasets
|
192 |
-
disable_iterator_cache=True,
|
193 |
-
load_on_cuda=cfg.checkpoint.load_on_cuda,
|
194 |
-
)
|
195 |
-
if cfg.common.tpu:
|
196 |
-
import torch_xla.core.xla_model as xm
|
197 |
-
xm.rendezvous("load_checkpoint") # wait for all workers
|
198 |
-
|
199 |
-
max_epoch = cfg.optimization.max_epoch or math.inf
|
200 |
-
if max_epoch > 0 and max_epoch != math.inf:
|
201 |
-
total_num_updates = sum(
|
202 |
-
math.ceil(len(epoch_itr) / cfg.optimization.update_freq[i])
|
203 |
-
if i < len(cfg.optimization.update_freq) else
|
204 |
-
math.ceil(len(epoch_itr) / cfg.optimization.update_freq[-1])
|
205 |
-
for i in range(max_epoch)
|
206 |
-
)
|
207 |
-
trainer.lr_reinit(total_num_updates, trainer.get_num_updates())
|
208 |
-
|
209 |
-
# if getattr(cfg.model, "freeze_encoder", False):
|
210 |
-
# for idx, layer in enumerate(model.encoder.layers):
|
211 |
-
# layer.requires_grad_(False)
|
212 |
-
# if getattr(cfg.model, "freeze_decoder", False):
|
213 |
-
# for idx, layer in enumerate(model.decoder.layers):
|
214 |
-
# layer.requires_grad_(False)
|
215 |
-
|
216 |
-
# if hasattr(cfg.model, 'progressive') or getattr(cfg.model, "freeze_perception", False):
|
217 |
-
# custom_unfreeze(trainer, epoch_itr, cfg.model)
|
218 |
-
|
219 |
-
# if hasattr(cfg.model, 'only_linear_proj') and getattr(cfg.model, "only_linear_proj", False):
|
220 |
-
# model.requires_grad_(False)
|
221 |
-
# model.encoder.embed_tokens.weight.requires_grad = True
|
222 |
-
# model.decoder.embed_tokens.weight.requires_grad = True
|
223 |
-
|
224 |
-
# if getattr(cfg.model, "freeze_encoder_embedding", False) or getattr(
|
225 |
-
# cfg.model, "encoder_prompt", False) or getattr(cfg.model, "decoder_prompt", False) or getattr(cfg.model, "adapter", False):
|
226 |
-
# model.encoder.embed_tokens.weight.requires_grad = False
|
227 |
-
# if getattr(cfg.model, "freeze_decoder_embedding", False) or getattr(
|
228 |
-
# cfg.model, "encoder_prompt", False) or getattr(cfg.model, "decoder_prompt", False) or getattr(cfg.model, "adapter", False):
|
229 |
-
# model.decoder.embed_tokens.weight.requires_grad = False
|
230 |
-
|
231 |
-
|
232 |
-
# model.encoder.image_proj.requires_grad_(True)
|
233 |
-
# if getattr(cfg.model, "video_encoder_name", None):
|
234 |
-
# model.encoder.video_proj.requires_grad_(True)
|
235 |
-
# if getattr(cfg.model, "audio_encoder_name", None):
|
236 |
-
# model.encoder.audio_proj.requires_grad_(True)
|
237 |
-
|
238 |
-
|
239 |
-
print_trainable_params_percentage(model)
|
240 |
-
|
241 |
-
lr = trainer.get_lr()
|
242 |
-
|
243 |
-
train_meter = meters.StopwatchMeter()
|
244 |
-
train_meter.start()
|
245 |
-
while epoch_itr.next_epoch_idx <= max_epoch:
|
246 |
-
if lr <= cfg.optimization.stop_min_lr:
|
247 |
-
logger.info(
|
248 |
-
f"stopping training because current learning rate ({lr}) is smaller "
|
249 |
-
"than or equal to minimum learning rate "
|
250 |
-
f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
|
251 |
-
)
|
252 |
-
break
|
253 |
-
|
254 |
-
# train for one epoch
|
255 |
-
|
256 |
-
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
|
257 |
-
if should_stop:
|
258 |
-
break
|
259 |
-
|
260 |
-
# only use first validation loss to update the learning rate
|
261 |
-
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
|
262 |
-
|
263 |
-
epoch_itr = trainer.get_train_iterator(
|
264 |
-
epoch_itr.next_epoch_idx,
|
265 |
-
# sharded data: get train iterator for next epoch
|
266 |
-
load_dataset=True,
|
267 |
-
# don't cache epoch iterators for sharded datasets
|
268 |
-
disable_iterator_cache=True,
|
269 |
-
)
|
270 |
-
train_meter.stop()
|
271 |
-
logger.info("done training in {:.1f} seconds".format(train_meter.sum))
|
272 |
-
|
273 |
-
# ioPath implementation to wait for all asynchronous file writes to complete.
|
274 |
-
if cfg.checkpoint.write_checkpoints_asynchronously:
|
275 |
-
logger.info(
|
276 |
-
"ioPath PathManager waiting for all asynchronous checkpoint "
|
277 |
-
"writes to finish."
|
278 |
-
)
|
279 |
-
PathManager.async_close()
|
280 |
-
logger.info("ioPath PathManager finished waiting.")
|
281 |
-
|
282 |
-
|
283 |
-
def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
|
284 |
-
# skip check if no validation was done in the current epoch
|
285 |
-
if valid_loss is None:
|
286 |
-
return False
|
287 |
-
if cfg.checkpoint.patience <= 0:
|
288 |
-
return False
|
289 |
-
|
290 |
-
def is_better(a, b):
|
291 |
-
return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
|
292 |
-
|
293 |
-
prev_best = getattr(should_stop_early, "best", None)
|
294 |
-
if prev_best is None or is_better(valid_loss, prev_best):
|
295 |
-
should_stop_early.best = valid_loss
|
296 |
-
should_stop_early.num_runs = 0
|
297 |
-
return False
|
298 |
-
else:
|
299 |
-
should_stop_early.num_runs += 1
|
300 |
-
if should_stop_early.num_runs >= cfg.checkpoint.patience:
|
301 |
-
logger.info(
|
302 |
-
"early stop since valid performance hasn't improved for last {} runs".format(
|
303 |
-
cfg.checkpoint.patience
|
304 |
-
)
|
305 |
-
)
|
306 |
-
return True
|
307 |
-
else:
|
308 |
-
return False
|
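(For reference, a minimal standalone sketch of the patience-based early stopping implemented above; `patience` and `maximize` are hypothetical stand-ins for `cfg.checkpoint.patience` and `cfg.checkpoint.maximize_best_checkpoint_metric`.)

def should_stop_early_sketch(valid_loss, patience=5, maximize=False, _state={}):
    # _state mirrors the should_stop_early.best / .num_runs attributes used above
    if valid_loss is None or patience <= 0:
        return False
    better = (lambda a, b: a > b) if maximize else (lambda a, b: a < b)
    best = _state.get("best")
    if best is None or better(valid_loss, best):
        _state["best"], _state["num_runs"] = valid_loss, 0
        return False
    _state["num_runs"] += 1
    return _state["num_runs"] >= patience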
309 |
-
|
310 |
-
|
311 |
-
@metrics.aggregate("train")
|
312 |
-
def train(
|
313 |
-
cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
|
314 |
-
) -> Tuple[List[Optional[float]], bool]:
|
315 |
-
"""Train the model for one epoch and return validation losses."""
|
316 |
-
# Initialize data iterator
|
317 |
-
itr = epoch_itr.next_epoch_itr(
|
318 |
-
fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
|
319 |
-
shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
|
320 |
-
)
|
321 |
-
update_freq = (
|
322 |
-
cfg.optimization.update_freq[epoch_itr.epoch - 1]
|
323 |
-
if epoch_itr.epoch <= len(cfg.optimization.update_freq)
|
324 |
-
else cfg.optimization.update_freq[-1]
|
325 |
-
)
|
326 |
-
itr = iterators.GroupedIterator(itr, update_freq)
|
327 |
-
if cfg.common.tpu:
|
328 |
-
itr = utils.tpu_data_loader(itr)
|
329 |
-
progress = progress_bar.progress_bar(
|
330 |
-
itr,
|
331 |
-
log_format=cfg.common.log_format,
|
332 |
-
log_file=cfg.common.log_file,
|
333 |
-
log_interval=cfg.common.log_interval,
|
334 |
-
epoch=epoch_itr.epoch,
|
335 |
-
tensorboard_logdir=(
|
336 |
-
cfg.common.tensorboard_logdir
|
337 |
-
if distributed_utils.is_master(cfg.distributed_training)
|
338 |
-
else None
|
339 |
-
),
|
340 |
-
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
341 |
-
wandb_project=(
|
342 |
-
cfg.common.wandb_project
|
343 |
-
if distributed_utils.is_master(cfg.distributed_training)
|
344 |
-
else None
|
345 |
-
),
|
346 |
-
wandb_run_name=os.environ.get(
|
347 |
-
"WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
|
348 |
-
),
|
349 |
-
azureml_logging=(
|
350 |
-
cfg.common.azureml_logging
|
351 |
-
if distributed_utils.is_master(cfg.distributed_training)
|
352 |
-
else False
|
353 |
-
),
|
354 |
-
)
|
355 |
-
progress.update_config(_flatten_config(cfg))
|
356 |
-
|
357 |
-
trainer.begin_epoch(epoch_itr.epoch)
|
358 |
-
|
359 |
-
valid_subsets = cfg.dataset.valid_subset.split(",")
|
360 |
-
should_stop = False
|
361 |
-
num_updates = trainer.get_num_updates()
|
362 |
-
logger.info("Start iterating over samples")
|
363 |
-
for i, samples in enumerate(progress):
|
364 |
-
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
|
365 |
-
"train_step-%d" % i
|
366 |
-
):
|
367 |
-
log_output = trainer.train_step(samples)
|
368 |
-
|
369 |
-
if log_output is not None: # not OOM, overflow, ...
|
370 |
-
# log mid-epoch stats
|
371 |
-
num_updates = trainer.get_num_updates()
|
372 |
-
if num_updates % cfg.common.log_interval == 0:
|
373 |
-
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
|
374 |
-
progress.log(stats, tag="train_inner", step=num_updates)
|
375 |
-
|
376 |
-
# reset mid-epoch stats after each log interval
|
377 |
-
# the end-of-epoch stats will still be preserved
|
378 |
-
metrics.reset_meters("train_inner")
|
379 |
-
|
380 |
-
end_of_epoch = not itr.has_next()
|
381 |
-
valid_losses, should_stop = validate_and_save(
|
382 |
-
cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
|
383 |
-
)
|
384 |
-
|
385 |
-
# if (hasattr(cfg.model, 'progressive') or hasattr(cfg.model, 'only_linear_proj') or hasattr(cfg.model, 'freeze_perception')) and end_of_epoch:
|
386 |
-
# custom_unfreeze(trainer, epoch_itr, cfg.model)
|
387 |
-
# print_trainable_params_percentage(trainer.model)
|
388 |
-
|
389 |
-
if should_stop:
|
390 |
-
break
|
391 |
-
|
392 |
-
|
393 |
-
# print(i, len(progress))
|
394 |
-
# if i > 5:
|
395 |
-
# break
|
396 |
-
# log end-of-epoch stats
|
397 |
-
logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
|
398 |
-
stats = get_training_stats(metrics.get_smoothed_values("train"))
|
399 |
-
progress.print(stats, tag="train", step=num_updates)
|
400 |
-
print_trainable_params_percentage(trainer.model)
|
401 |
-
# reset epoch-level meters
|
402 |
-
metrics.reset_meters("train")
|
403 |
-
return valid_losses, should_stop
|
404 |
-
|
405 |
-
# progressive training
|
406 |
-
def custom_unfreeze(trainer, epoch_itr, cfg):
|
407 |
-
model = trainer.model
|
408 |
-
epoch = epoch_itr.epoch
|
409 |
-
print("Epoch, ", epoch)
|
410 |
-
## unfreeze epochs
|
411 |
-
unfreeze_epoch_encoder = cfg.unfreeze_epoch_encoder
|
412 |
-
unfreeze_epoch_decoder = cfg.unfreeze_epoch_decoder
|
413 |
-
|
414 |
-
unfreeze_epoch_image = cfg.unfreeze_epoch_image
|
415 |
-
unfreeze_epoch_video = cfg.unfreeze_epoch_video
|
416 |
-
unfreeze_epoch_audio = cfg.unfreeze_epoch_audio
|
417 |
-
|
418 |
-
|
419 |
-
if getattr(cfg, "only_linear_proj", False):
|
420 |
-
unfreeze_epoch = cfg.unfreeze_epoch
|
421 |
-
if epoch >= unfreeze_epoch:
|
422 |
-
model.requires_grad_(True)
|
423 |
-
|
424 |
-
if getattr(cfg, "freeze_encoder_embedding", False) or getattr(
|
425 |
-
cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
|
426 |
-
model.encoder.embed_tokens.weight.requires_grad = False
|
427 |
-
if getattr(cfg, "freeze_decoder_embedding", False) or getattr(
|
428 |
-
cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
|
429 |
-
model.decoder.embed_tokens.weight.requires_grad = False
|
430 |
-
|
431 |
-
if trainer._ema is not None:
|
432 |
-
trainer._ema.requires_grad_(False)
|
433 |
-
print_trainable_params_percentage(model)
|
434 |
-
return
|
435 |
-
|
436 |
-
if getattr(cfg, "freeze_perception", False):
|
437 |
-
|
438 |
-
if hasattr(model.encoder, 'embed_images'):
|
439 |
-
if epoch >= unfreeze_epoch_image:
|
440 |
-
grad = True
|
441 |
-
else:
|
442 |
-
grad = False
|
443 |
-
model.encoder.embed_images.requires_grad_(grad)
|
444 |
-
print('model.encoder.embed_images.requires_grad', grad)
|
445 |
-
if hasattr(model.encoder, 'embed_videos'):
|
446 |
-
if epoch >= unfreeze_epoch_video:
|
447 |
-
grad = True
|
448 |
-
else:
|
449 |
-
grad = False
|
450 |
-
model.encoder.embed_videos.requires_grad_(grad)
|
451 |
-
print('model.encoder.embed_videos.requires_grad', grad)
|
452 |
-
|
453 |
-
if hasattr(model.encoder, 'embed_audios'):
|
454 |
-
if epoch >= unfreeze_epoch_audio:
|
455 |
-
grad = True
|
456 |
-
else:
|
457 |
-
grad = False
|
458 |
-
model.encoder.embed_audios.requires_grad_(grad)
|
459 |
-
print('model.encoder.embed_audios.requires_grad', grad)
|
460 |
-
|
461 |
-
if trainer._ema is not None:
|
462 |
-
trainer._ema.requires_grad_(False)
|
463 |
-
return
|
464 |
-
|
465 |
-
if epoch >= unfreeze_epoch_encoder:
|
466 |
-
grad=True
|
467 |
-
else:
|
468 |
-
grad=False
|
469 |
-
for l in model.encoder.layers:
|
470 |
-
l.requires_grad_(grad)
|
471 |
-
print('model.encoder.layers.requires_grad', grad)
|
472 |
-
|
473 |
-
if epoch >= unfreeze_epoch_decoder:
|
474 |
-
grad=True
|
475 |
-
else:
|
476 |
-
grad=False
|
477 |
-
for l in model.decoder.layers:
|
478 |
-
l.requires_grad_(grad)
|
479 |
-
print('model.decoder.layers.requires_grad', grad)
|
480 |
-
|
481 |
-
if getattr(cfg, "freeze_encoder_embedding", False) or getattr(
|
482 |
-
cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
|
483 |
-
model.encoder.embed_tokens.weight.requires_grad = False
|
484 |
-
if getattr(cfg, "freeze_decoder_embedding", False) or getattr(
|
485 |
-
cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
|
486 |
-
model.decoder.embed_tokens.weight.requires_grad = False
|
487 |
-
|
488 |
-
if getattr(cfg, "encoder_prompt", False):
|
489 |
-
model.encoder.encoder_prompt_encoder.requires_grad_(True)
|
490 |
-
if getattr(cfg, "decoder_prompt", False):
|
491 |
-
model.decoder.decoder_prompt_encoder.requires_grad_(True)
|
492 |
-
if getattr(cfg, "adapter", False):
|
493 |
-
for idx, layer in enumerate(model.encoder.layers):
|
494 |
-
layer.adapter.requires_grad_(True)
|
495 |
-
for idx, layer in enumerate(model.decoder.layers):
|
496 |
-
layer.adapter.requires_grad_(True)
|
497 |
-
|
498 |
-
if hasattr(model.encoder, 'embed_images'):
|
499 |
-
if epoch >= unfreeze_epoch_image:
|
500 |
-
grad = True
|
501 |
-
else:
|
502 |
-
grad = False
|
503 |
-
model.encoder.embed_images.requires_grad_(grad)
|
504 |
-
print('model.encoder.embed_images.requires_grad', grad)
|
505 |
-
if hasattr(model.encoder, 'embed_videos'):
|
506 |
-
if epoch >= unfreeze_epoch_video:
|
507 |
-
grad = True
|
508 |
-
else:
|
509 |
-
grad = False
|
510 |
-
model.encoder.embed_videos.requires_grad_(grad)
|
511 |
-
print('model.encoder.embed_videos.requires_grad', grad)
|
512 |
-
|
513 |
-
if hasattr(model.encoder, 'embed_audios'):
|
514 |
-
if epoch >= unfreeze_epoch_audio:
|
515 |
-
grad = True
|
516 |
-
else:
|
517 |
-
grad = False
|
518 |
-
model.encoder.embed_audios.requires_grad_(grad)
|
519 |
-
print('model.encoder.embed_audios.requires_grad', grad)
|
520 |
-
|
521 |
-
if trainer._ema is not None:
|
522 |
-
trainer._ema.requires_grad_(False)
|
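(A compact sketch of the epoch-based unfreezing rule applied throughout `custom_unfreeze` above; the example threshold is hypothetical.)

def set_unfrozen(module, epoch, unfreeze_epoch):
    # A module becomes trainable once the current epoch reaches its unfreeze epoch.
    grad = epoch >= unfreeze_epoch
    module.requires_grad_(grad)
    return grad

# e.g. with unfreeze_epoch_image=2, the image encoder stays frozen during epoch 1
# and is trainable from epoch 2 onward.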
523 |
-
|
524 |
-
def _flatten_config(cfg: DictConfig):
|
525 |
-
config = OmegaConf.to_container(cfg)
|
526 |
-
# remove any legacy Namespaces and replace with a single "args"
|
527 |
-
namespace = None
|
528 |
-
for k, v in list(config.items()):
|
529 |
-
if isinstance(v, argparse.Namespace):
|
530 |
-
namespace = v
|
531 |
-
del config[k]
|
532 |
-
if namespace is not None:
|
533 |
-
config["args"] = vars(namespace)
|
534 |
-
return config
|
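(A small self-contained illustration of what `_flatten_config` does with a legacy `argparse.Namespace` entry; the `legacy` key and its contents are made up for the example.)

import argparse
from omegaconf import OmegaConf

config = OmegaConf.to_container(OmegaConf.create({"common": {"seed": 7}}))
config["legacy"] = argparse.Namespace(lr=0.1)  # legacy entry of the kind handled above
namespace = None
for k, v in list(config.items()):
    if isinstance(v, argparse.Namespace):
        namespace = v
        del config[k]
if namespace is not None:
    config["args"] = vars(namespace)
# config is now {"common": {"seed": 7}, "args": {"lr": 0.1}}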
535 |
-
|
536 |
-
|
537 |
-
def validate_and_save(
|
538 |
-
cfg: DictConfig,
|
539 |
-
trainer: Trainer,
|
540 |
-
task: tasks.FairseqTask,
|
541 |
-
epoch_itr,
|
542 |
-
valid_subsets: List[str],
|
543 |
-
end_of_epoch: bool,
|
544 |
-
) -> Tuple[List[Optional[float]], bool]:
|
545 |
-
num_updates = trainer.get_num_updates()
|
546 |
-
max_update = cfg.optimization.max_update or math.inf
|
547 |
-
|
548 |
-
# Stopping conditions (and an additional one based on validation loss later
|
549 |
-
# on)
|
550 |
-
should_stop = False
|
551 |
-
if num_updates >= max_update:
|
552 |
-
should_stop = True
|
553 |
-
logger.info(
|
554 |
-
f"Stopping training due to "
|
555 |
-
f"num_updates: {num_updates} >= max_update: {max_update}"
|
556 |
-
)
|
557 |
-
|
558 |
-
training_time_hours = trainer.cumulative_training_time() / (60 * 60)
|
559 |
-
if (
|
560 |
-
cfg.optimization.stop_time_hours > 0
|
561 |
-
and training_time_hours > cfg.optimization.stop_time_hours
|
562 |
-
):
|
563 |
-
should_stop = True
|
564 |
-
logger.info(
|
565 |
-
f"Stopping training due to "
|
566 |
-
f"cumulative_training_time: {training_time_hours} > "
|
567 |
-
f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
|
568 |
-
)
|
569 |
-
|
570 |
-
do_save = (
|
571 |
-
(end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
|
572 |
-
or should_stop
|
573 |
-
or (
|
574 |
-
cfg.checkpoint.save_interval_updates > 0
|
575 |
-
and num_updates > 0
|
576 |
-
and num_updates % cfg.checkpoint.save_interval_updates == 0
|
577 |
-
and num_updates >= cfg.dataset.validate_after_updates
|
578 |
-
)
|
579 |
-
)
|
580 |
-
do_validate = (
|
581 |
-
(not end_of_epoch and do_save) # validate during mid-epoch saves
|
582 |
-
or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
|
583 |
-
or should_stop
|
584 |
-
or (
|
585 |
-
cfg.dataset.validate_interval_updates > 0
|
586 |
-
and num_updates > 0
|
587 |
-
and num_updates % cfg.dataset.validate_interval_updates == 0
|
588 |
-
)
|
589 |
-
) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates
|
590 |
-
|
591 |
-
# Validate
|
592 |
-
valid_losses = [None]
|
593 |
-
if do_validate:
|
594 |
-
valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
|
595 |
-
|
596 |
-
should_stop |= should_stop_early(cfg, valid_losses[0])
|
597 |
-
|
598 |
-
# Save checkpoint
|
599 |
-
if do_save or should_stop:
|
600 |
-
checkpoint_utils.save_checkpoint(
|
601 |
-
cfg.checkpoint, trainer, epoch_itr, valid_losses[0], save_on_cuda=cfg.checkpoint.save_on_cuda,
|
602 |
-
)
|
603 |
-
|
604 |
-
return valid_losses, should_stop
|
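(A minimal sketch of the checkpointing cadence encoded in `do_save` above; the default values are hypothetical stand-ins for the corresponding `cfg.checkpoint` / `cfg.dataset` fields.)

def time_to_save(num_updates, end_of_epoch, epoch, should_stop=False,
                 save_interval=1, save_interval_updates=500, validate_after_updates=0):
    # Save at the end of a saving epoch, on a forced stop, or every
    # `save_interval_updates` updates once past `validate_after_updates`.
    return (
        (end_of_epoch and epoch % save_interval == 0)
        or should_stop
        or (
            save_interval_updates > 0
            and num_updates > 0
            and num_updates % save_interval_updates == 0
            and num_updates >= validate_after_updates
        )
    )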
605 |
-
|
606 |
-
|
607 |
-
def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
|
608 |
-
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
|
609 |
-
return stats
|
610 |
-
|
611 |
-
|
612 |
-
def validate(
|
613 |
-
cfg: DictConfig,
|
614 |
-
trainer: Trainer,
|
615 |
-
task: tasks.FairseqTask,
|
616 |
-
epoch_itr,
|
617 |
-
subsets: List[str],
|
618 |
-
) -> List[Optional[float]]:
|
619 |
-
"""Evaluate the model on the validation set(s) and return the losses."""
|
620 |
-
|
621 |
-
if cfg.dataset.fixed_validation_seed is not None:
|
622 |
-
# set fixed seed for every validation
|
623 |
-
utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
|
624 |
-
|
625 |
-
trainer.begin_valid_epoch(epoch_itr.epoch)
|
626 |
-
valid_losses = []
|
627 |
-
for subset in subsets:
|
628 |
-
logger.info('begin validation on "{}" subset'.format(subset))
|
629 |
-
|
630 |
-
# Initialize data iterator
|
631 |
-
itr = trainer.get_valid_iterator(subset).next_epoch_itr(
|
632 |
-
shuffle=False, set_dataset_epoch=False # use a fixed valid set
|
633 |
-
)
|
634 |
-
if cfg.common.tpu:
|
635 |
-
itr = utils.tpu_data_loader(itr)
|
636 |
-
progress = progress_bar.progress_bar(
|
637 |
-
itr,
|
638 |
-
log_format=cfg.common.log_format,
|
639 |
-
log_interval=cfg.common.log_interval,
|
640 |
-
epoch=epoch_itr.epoch,
|
641 |
-
prefix=f"valid on '{subset}' subset",
|
642 |
-
tensorboard_logdir=(
|
643 |
-
cfg.common.tensorboard_logdir
|
644 |
-
if distributed_utils.is_master(cfg.distributed_training)
|
645 |
-
else None
|
646 |
-
),
|
647 |
-
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
648 |
-
wandb_project=(
|
649 |
-
cfg.common.wandb_project
|
650 |
-
if distributed_utils.is_master(cfg.distributed_training)
|
651 |
-
else None
|
652 |
-
),
|
653 |
-
wandb_run_name=os.environ.get(
|
654 |
-
"WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
|
655 |
-
),
|
656 |
-
)
|
657 |
-
|
658 |
-
# create a new root metrics aggregator so validation metrics
|
659 |
-
# don't pollute other aggregators (e.g., train meters)
|
660 |
-
with metrics.aggregate(new_root=True) as agg:
|
661 |
-
for i, sample in enumerate(progress):
|
662 |
-
if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
|
663 |
-
break
|
664 |
-
try:
|
665 |
-
# print(sample)
|
666 |
-
trainer.valid_step(sample)
|
667 |
-
except IndexError:
|
668 |
-
# print(sample)
|
669 |
-
print('didnt pass')
|
670 |
-
trainer.valid_step(sample)
|
671 |
-
continue
|
672 |
-
|
673 |
-
# log validation stats
|
674 |
-
if hasattr(task, 'get_valid_stats'):
|
675 |
-
stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
|
676 |
-
else:
|
677 |
-
stats = agg.get_smoothed_values()
|
678 |
-
stats = get_valid_stats(cfg, trainer, stats)
|
679 |
-
|
680 |
-
if hasattr(task, "post_validate"):
|
681 |
-
task.post_validate(trainer.get_model(), stats, agg)
|
682 |
-
|
683 |
-
|
684 |
-
progress.print(stats, tag=subset, step=trainer.get_num_updates())
|
685 |
-
|
686 |
-
valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
|
687 |
-
return valid_losses
|
688 |
-
|
689 |
-
|
690 |
-
def get_valid_stats(
|
691 |
-
cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
|
692 |
-
) -> Dict[str, Any]:
|
693 |
-
stats["num_updates"] = trainer.get_num_updates()
|
694 |
-
if hasattr(checkpoint_utils.save_checkpoint, "best"):
|
695 |
-
key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
|
696 |
-
best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
|
697 |
-
stats[key] = best_function(
|
698 |
-
checkpoint_utils.save_checkpoint.best,
|
699 |
-
stats[cfg.checkpoint.best_checkpoint_metric],
|
700 |
-
)
|
701 |
-
return stats
|
702 |
-
|
703 |
-
|
704 |
-
def cli_main(
|
705 |
-
modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
|
706 |
-
) -> None:
|
707 |
-
parser = options.get_training_parser()
|
708 |
-
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
|
709 |
-
print(args)
|
710 |
-
cfg = convert_namespace_to_omegaconf(args)
|
711 |
-
|
712 |
-
if cfg.common.use_plasma_view:
|
713 |
-
server = PlasmaStore(path=cfg.common.plasma_path)
|
714 |
-
logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")
|
715 |
-
|
716 |
-
if args.profile:
|
717 |
-
with torch.cuda.profiler.profile():
|
718 |
-
with torch.autograd.profiler.emit_nvtx():
|
719 |
-
distributed_utils.call_main(cfg, main)
|
720 |
-
else:
|
721 |
-
distributed_utils.call_main(cfg, main)
|
722 |
-
|
723 |
-
# if cfg.common.use_plasma_view:
|
724 |
-
# server.server.kill()
|
725 |
-
|
726 |
-
|
727 |
-
if __name__ == "__main__":
|
728 |
-
cli_main()
|
729 |
-
|
|
trainer.py
DELETED
@@ -1,1569 +0,0 @@
|
|
1 |
-
# Copyright 2022 The OFA-Sys Team.
|
2 |
-
# All rights reserved.
|
3 |
-
# This source code is licensed under the Apache 2.0 license
|
4 |
-
# found in the LICENSE file in the root directory.
|
5 |
-
|
6 |
-
"""
|
7 |
-
Train a network across multiple GPUs.
|
8 |
-
"""
|
9 |
-
|
10 |
-
import contextlib
|
11 |
-
import logging
|
12 |
-
import sys
|
13 |
-
import time
|
14 |
-
from argparse import Namespace
|
15 |
-
from itertools import chain
|
16 |
-
from typing import Any, Dict, List
|
17 |
-
|
18 |
-
import torch
|
19 |
-
from fairseq import models, optim, utils
|
20 |
-
from fairseq.dataclass.configs import FairseqConfig
|
21 |
-
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
22 |
-
from fairseq.distributed import utils as distributed_utils
|
23 |
-
from fairseq.file_io import PathManager
|
24 |
-
from fairseq.logging import meters, metrics
|
25 |
-
from fairseq.models.ema import build_ema
|
26 |
-
from fairseq.nan_detector import NanDetector
|
27 |
-
from fairseq.optim import lr_scheduler
|
28 |
-
from omegaconf import OmegaConf
|
29 |
-
|
30 |
-
from utils import checkpoint_utils
|
31 |
-
import torch.nn as nn
|
32 |
-
|
33 |
-
logger = logging.getLogger(__name__)
|
34 |
-
|
35 |
-
|
36 |
-
class Trainer(object):
|
37 |
-
"""Main class for data parallel training.
|
38 |
-
|
39 |
-
This class supports synchronous distributed data parallel training,
|
40 |
-
where multiple workers each have a full model replica and gradients
|
41 |
-
are accumulated across workers before each update. We use
|
42 |
-
:class:`~torch.nn.parallel.DistributedDataParallel` to handle
|
43 |
-
communication of the gradients across workers.
|
44 |
-
"""
|
45 |
-
|
46 |
-
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
|
47 |
-
|
48 |
-
if isinstance(cfg, Namespace):
|
49 |
-
logger.warning(
|
50 |
-
"argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
|
51 |
-
)
|
52 |
-
cfg = convert_namespace_to_omegaconf(cfg)
|
53 |
-
|
54 |
-
self.cfg = cfg
|
55 |
-
self.task = task
|
56 |
-
|
57 |
-
# catalog shared parameters
|
58 |
-
shared_params = _catalog_shared_params(model)
|
59 |
-
self.tpu = cfg.common.tpu
|
60 |
-
self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
|
61 |
-
if self.cuda:
|
62 |
-
self.device = torch.device("cuda")
|
63 |
-
elif self.tpu:
|
64 |
-
self.device = utils.get_tpu_device()
|
65 |
-
else:
|
66 |
-
self.device = torch.device("cpu")
|
67 |
-
|
68 |
-
if self.is_fsdp:
|
69 |
-
import fairscale
|
70 |
-
if self.cfg.common.bf16:
|
71 |
-
raise ValueError(
|
72 |
-
"FullyShardedDataParallel is not compatible with --bf16 or "
|
73 |
-
"--memory-efficient-bf16"
|
74 |
-
)
|
75 |
-
if self.cfg.distributed_training.zero_sharding != "none":
|
76 |
-
raise ValueError(
|
77 |
-
"FullyShardedDataParallel is not compatible with --zero-sharding "
|
78 |
-
"option (it's already built in)"
|
79 |
-
)
|
80 |
-
if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0":
|
81 |
-
raise RuntimeError(
|
82 |
-
"Please update to fairscale 0.4.0 or newer when combining "
|
83 |
-
"--update-freq with FullyShardedDataParallel"
|
84 |
-
)
|
85 |
-
else:
|
86 |
-
if (
|
87 |
-
hasattr(self.cfg.distributed_training, "cpu_offload")
|
88 |
-
and self.cfg.distributed_training.cpu_offload
|
89 |
-
):
|
90 |
-
raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
|
91 |
-
|
92 |
-
# copy model and criterion to current device/dtype
|
93 |
-
self._criterion = criterion
|
94 |
-
self._model = model
|
95 |
-
if not self.is_fsdp:
|
96 |
-
if cfg.common.fp16:
|
97 |
-
assert not cfg.common.amp, "Cannot use fp16 and AMP together"
|
98 |
-
self._criterion = self._criterion.half()
|
99 |
-
self._model = self._model.half()
|
100 |
-
|
101 |
-
if hasattr(self._model.encoder, 'embed_audios'):
|
102 |
-
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
|
103 |
-
for layer in self._model.modules():
|
104 |
-
if isinstance(layer, LogmelFilterBank) or isinstance(layer, Spectrogram):
|
105 |
-
layer.float()
|
106 |
-
print(layer)
|
107 |
-
# for layer in self._model.modules():
|
108 |
-
# if isinstance(layer, nn.BatchNorm2d):
|
109 |
-
# layer.float()
|
110 |
-
# print(layer)
|
111 |
-
|
112 |
-
|
113 |
-
elif cfg.common.bf16:
|
114 |
-
self._criterion = self._criterion.to(dtype=torch.bfloat16)
|
115 |
-
self._model = self._model.to(dtype=torch.bfloat16)
|
116 |
-
elif cfg.common.amp:
|
117 |
-
self._amp_retries = 0
|
118 |
-
if (
|
119 |
-
not cfg.distributed_training.pipeline_model_parallel
|
120 |
-
# the DistributedFairseqModel wrapper will handle moving to device,
|
121 |
-
# so only handle cases which don't use the wrapper
|
122 |
-
and not self.use_distributed_wrapper
|
123 |
-
):
|
124 |
-
self._criterion = self._criterion.to(device=self.device)
|
125 |
-
self._model = self._model.to(device=self.device)
|
126 |
-
self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
|
127 |
-
self.last_device = None
|
128 |
-
if self.cuda and self.pipeline_model_parallel:
|
129 |
-
self.last_device = torch.device(
|
130 |
-
cfg.distributed_training.pipeline_devices[-1]
|
131 |
-
)
|
132 |
-
|
133 |
-
# check that shared parameters are preserved after device transfer
|
134 |
-
for shared_param in shared_params:
|
135 |
-
ref = _get_module_by_path(self._model, shared_param[0])
|
136 |
-
for path in shared_param[1:]:
|
137 |
-
logger.info(
|
138 |
-
"detected shared parameter: {} <- {}".format(shared_param[0], path)
|
139 |
-
)
|
140 |
-
_set_module_by_path(self._model, path, ref)
|
141 |
-
|
142 |
-
self._dummy_batch = None # indicates we don't have a dummy batch at first
|
143 |
-
self._lr_scheduler = None
|
144 |
-
self._num_updates = 0
|
145 |
-
self._num_xla_compiles = 0 # for TPUs
|
146 |
-
self._optim_history = None
|
147 |
-
self._optimizer = None
|
148 |
-
self._warn_once = set()
|
149 |
-
self._wrapped_criterion = None
|
150 |
-
self._wrapped_model = None
|
151 |
-
self._ema = None
|
152 |
-
|
153 |
-
# TODO(myleott): support tpu
|
154 |
-
if self.cuda and self.data_parallel_world_size > 1:
|
155 |
-
self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
|
156 |
-
else:
|
157 |
-
self._grad_norm_buf = None
|
158 |
-
|
159 |
-
self.quantizer = quantizer
|
160 |
-
if self.quantizer is not None:
|
161 |
-
self.quantizer.set_trainer(self)
|
162 |
-
|
163 |
-
# get detailed cuda environment
|
164 |
-
if self.cuda:
|
165 |
-
self.cuda_env = utils.CudaEnvironment()
|
166 |
-
if self.data_parallel_world_size > 1:
|
167 |
-
self.cuda_env_arr = distributed_utils.all_gather_list(
|
168 |
-
self.cuda_env, group=distributed_utils.get_global_group()
|
169 |
-
)
|
170 |
-
else:
|
171 |
-
self.cuda_env_arr = [self.cuda_env]
|
172 |
-
if self.data_parallel_rank == 0:
|
173 |
-
utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
|
174 |
-
else:
|
175 |
-
self.cuda_env = None
|
176 |
-
self.cuda_env_arr = None
|
177 |
-
|
178 |
-
metrics.log_start_time("wall", priority=790, round=0)
|
179 |
-
|
180 |
-
self._start_time = time.time()
|
181 |
-
self._previous_training_time = 0
|
182 |
-
self._cumulative_training_time = None
|
183 |
-
|
184 |
-
def reinitialize(self):
|
185 |
-
"""Reinitialize the Trainer, typically after model params change."""
|
186 |
-
self._lr_scheduler = None
|
187 |
-
self._optimizer = None
|
188 |
-
self._wrapped_criterion = None
|
189 |
-
self._wrapped_model = None
|
190 |
-
|
191 |
-
@property
|
192 |
-
def data_parallel_world_size(self):
|
193 |
-
if self.cfg.distributed_training.distributed_world_size == 1:
|
194 |
-
return 1
|
195 |
-
return distributed_utils.get_data_parallel_world_size()
|
196 |
-
|
197 |
-
@property
|
198 |
-
def data_parallel_process_group(self):
|
199 |
-
return distributed_utils.get_data_parallel_group()
|
200 |
-
|
201 |
-
@property
|
202 |
-
def data_parallel_rank(self):
|
203 |
-
if self.cfg.distributed_training.distributed_world_size == 1:
|
204 |
-
return 0
|
205 |
-
return distributed_utils.get_data_parallel_rank()
|
206 |
-
|
207 |
-
@property
|
208 |
-
def is_data_parallel_master(self):
|
209 |
-
# NOTE: this returns true for all model parallel replicas with data
|
210 |
-
# parallel rank 0
|
211 |
-
return self.data_parallel_rank == 0
|
212 |
-
|
213 |
-
@property
|
214 |
-
def use_distributed_wrapper(self) -> bool:
|
215 |
-
return (
|
216 |
-
self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf
|
217 |
-
) or (
|
218 |
-
self.is_fsdp and self.cfg.distributed_training.cpu_offload
|
219 |
-
)
|
220 |
-
|
221 |
-
@property
|
222 |
-
def should_save_checkpoint_on_current_rank(self) -> bool:
|
223 |
-
"""Indicates whether to save checkpoints on the current DDP rank."""
|
224 |
-
if (
|
225 |
-
self.is_fsdp and self.cfg.distributed_training.use_sharded_state
|
226 |
-
) or getattr(self.cfg.model, "base_layers", 0) > 0:
|
227 |
-
return True
|
228 |
-
else:
|
229 |
-
return self.is_data_parallel_master
|
230 |
-
|
231 |
-
@property
|
232 |
-
def always_call_state_dict_during_save_checkpoint(self) -> bool:
|
233 |
-
if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state:
|
234 |
-
# FSDP calls communication collective when consolidating checkpoints
|
235 |
-
return True
|
236 |
-
else:
|
237 |
-
return False
|
238 |
-
|
239 |
-
@property
|
240 |
-
def checkpoint_suffix(self) -> str:
|
241 |
-
"""Suffix to add to the checkpoint file name."""
|
242 |
-
if self.is_fsdp and self.cfg.distributed_training.use_sharded_state:
|
243 |
-
return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(
|
244 |
-
self.data_parallel_rank
|
245 |
-
)
|
246 |
-
else:
|
247 |
-
return self.cfg.checkpoint.checkpoint_suffix or ""
|
248 |
-
|
249 |
-
@property
|
250 |
-
def criterion(self):
|
251 |
-
if self._wrapped_criterion is None:
|
252 |
-
if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
|
253 |
-
self._wrapped_criterion = models.DistributedFairseqModel(
|
254 |
-
self.cfg.distributed_training,
|
255 |
-
self._criterion,
|
256 |
-
process_group=self.data_parallel_process_group,
|
257 |
-
device=self.device,
|
258 |
-
)
|
259 |
-
else:
|
260 |
-
self._wrapped_criterion = self._criterion
|
261 |
-
return self._wrapped_criterion
|
262 |
-
|
263 |
-
@property
|
264 |
-
def model(self):
|
265 |
-
if self._wrapped_model is None:
|
266 |
-
if self.use_distributed_wrapper:
|
267 |
-
self._wrapped_model = models.DistributedFairseqModel(
|
268 |
-
self.cfg.distributed_training,
|
269 |
-
self._model,
|
270 |
-
process_group=self.data_parallel_process_group,
|
271 |
-
device=self.device,
|
272 |
-
)
|
273 |
-
else:
|
274 |
-
self._wrapped_model = self._model
|
275 |
-
return self._wrapped_model
|
276 |
-
|
277 |
-
@property
|
278 |
-
def ema(self):
|
279 |
-
if self._ema is None:
|
280 |
-
self._build_ema()
|
281 |
-
return self._ema
|
282 |
-
|
283 |
-
def _build_ema(self):
|
284 |
-
if self.cfg.ema.store_ema:
|
285 |
-
self._ema = build_ema(self._model, self.cfg.ema, self.device)
|
286 |
-
logger.info(
|
287 |
-
"Exponential Moving Average Shadow Model is initialized."
|
288 |
-
)
|
289 |
-
|
290 |
-
@property
|
291 |
-
def optimizer(self):
|
292 |
-
if self._optimizer is None:
|
293 |
-
self._build_optimizer()
|
294 |
-
return self._optimizer
|
295 |
-
|
296 |
-
@property
|
297 |
-
def lr_scheduler(self):
|
298 |
-
if self._lr_scheduler is None:
|
299 |
-
self._build_optimizer() # this will initialize self._lr_scheduler
|
300 |
-
return self._lr_scheduler
|
301 |
-
|
302 |
-
def _build_optimizer(self):
|
303 |
-
# params = list(self.model.parameters())
|
304 |
-
# print("len of model param:", len(params))
|
305 |
-
# params += list(
|
306 |
-
# filter(
|
307 |
-
# lambda p: p.requires_grad,
|
308 |
-
# chain(self.criterion.parameters()),
|
309 |
-
# )
|
310 |
-
# )
|
311 |
-
|
312 |
-
|
313 |
-
params = list(
|
314 |
-
filter(
|
315 |
-
lambda p: p.requires_grad,
|
316 |
-
chain(self.model.parameters(), self.criterion.parameters()),
|
317 |
-
)
|
318 |
-
)
|
319 |
-
print("len of optim param:", len(params))
|
320 |
-
|
321 |
-
if self.is_fsdp and self.cfg.common.fp16:
|
322 |
-
# FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
|
323 |
-
# mostly for the grad scaling. But if we don't have the
|
324 |
-
# --memory-efficient-fp16 flag set, then we're effectively doing
|
325 |
-
# regular --fp16 and can allow the use of optimizers that would
|
326 |
-
# otherwise be unsupported by MemoryEfficientFP16Optimizer.
|
327 |
-
allow_unsupported = not self.cfg.common.memory_efficient_fp16
|
328 |
-
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
|
329 |
-
self.cfg, params, allow_unsupported=allow_unsupported
|
330 |
-
)
|
331 |
-
elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
|
332 |
-
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
|
333 |
-
logger.info(
|
334 |
-
"NOTE: your device does NOT support faster training with --fp16 or --amp, "
|
335 |
-
"please switch to FP32 which is likely to be faster"
|
336 |
-
)
|
337 |
-
if (
|
338 |
-
self.cfg.common.memory_efficient_fp16
|
339 |
-
or self.cfg.common.memory_efficient_bf16
|
340 |
-
):
|
341 |
-
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
|
342 |
-
self.cfg, params
|
343 |
-
)
|
344 |
-
elif self.cfg.common.amp:
|
345 |
-
self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
|
346 |
-
else:
|
347 |
-
self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
|
348 |
-
else:
|
349 |
-
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
|
350 |
-
logger.info("NOTE: your device may support faster training with --fp16 or --amp")
|
351 |
-
self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
|
352 |
-
|
353 |
-
if self.is_fsdp:
|
354 |
-
assert (
|
355 |
-
not self.cfg.optimization.use_bmuf
|
356 |
-
), "--ddp-backend=fully_sharded is not compatible with BMUF"
|
357 |
-
assert self._optimizer.supports_flat_params, (
|
358 |
-
"--ddp-backend=fully_sharded is only compatible with pointwise "
|
359 |
-
"optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
|
360 |
-
"However, the sharding will result in slightly different results when "
|
361 |
-
"using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
|
362 |
-
)
|
363 |
-
|
364 |
-
if self.cfg.optimization.use_bmuf:
|
365 |
-
self._optimizer = optim.FairseqBMUF(
|
366 |
-
self.cfg.bmuf,
|
367 |
-
self._optimizer,
|
368 |
-
)
|
369 |
-
|
370 |
-
if self.cfg.distributed_training.zero_sharding == "os":
|
371 |
-
if (
|
372 |
-
self.cfg.common.fp16
|
373 |
-
and not self.cfg.common.memory_efficient_fp16
|
374 |
-
and not self.cfg.common.memory_efficient_bf16
|
375 |
-
) and not self.cfg.common.fp16_no_flatten_grads:
|
376 |
-
raise ValueError(
|
377 |
-
"ZeRO is incomptabile with fp16 and flattened grads. "
|
378 |
-
"Please use --fp16-no-flatten-grads"
|
379 |
-
)
|
380 |
-
else:
|
381 |
-
optim.shard_(self._optimizer, self.data_parallel_process_group)
|
382 |
-
|
383 |
-
# We should initialize the learning rate scheduler immediately after
|
384 |
-
# building the optimizer, so that the initial learning rate is set.
|
385 |
-
self._lr_scheduler = lr_scheduler.build_lr_scheduler(
|
386 |
-
self.cfg.lr_scheduler,
|
387 |
-
self.optimizer,
|
388 |
-
)
|
389 |
-
self._lr_scheduler.step_update(0)
|
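(A condensed restatement of the optimizer selection above, reusing the same fairseq builders; the FSDP and BMUF branches are omitted here for brevity.)

from fairseq import optim

def pick_optimizer(cfg, params):
    # fp16 / bf16 / amp go through a wrapped optimizer; plain fp32 otherwise.
    if cfg.common.fp16 or cfg.common.bf16 or cfg.common.amp:
        if cfg.common.memory_efficient_fp16 or cfg.common.memory_efficient_bf16:
            return optim.MemoryEfficientFP16Optimizer.build_optimizer(cfg, params)
        if cfg.common.amp:
            return optim.AMPOptimizer.build_optimizer(cfg, params)
        return optim.FP16Optimizer.build_optimizer(cfg, params)
    return optim.build_optimizer(cfg.optimizer, params)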
390 |
-
|
391 |
-
@property
|
392 |
-
def is_fsdp(self):
|
393 |
-
return self.cfg.distributed_training.ddp_backend == "fully_sharded"
|
394 |
-
|
395 |
-
def consolidate_optimizer(self):
|
396 |
-
"""For OSS, we need to consolidate the state dict."""
|
397 |
-
if self.cfg.checkpoint.no_save_optimizer_state:
|
398 |
-
return
|
399 |
-
self._gathered_optim_state = None
|
400 |
-
if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
|
401 |
-
self.optimizer.optimizer.consolidate_state_dict()
|
402 |
-
elif self.is_fsdp and not self.model.use_sharded_state:
|
403 |
-
st = self.model.gather_full_optim_state_dict(
|
404 |
-
self.optimizer
|
405 |
-
) # only returns on rank 0
|
406 |
-
self._gathered_optim_state = st
|
407 |
-
|
408 |
-
def state_dict(self):
|
409 |
-
state_dict = {
|
410 |
-
"args": None, # legacy
|
411 |
-
"cfg": (
|
412 |
-
OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
|
413 |
-
if OmegaConf.is_config(self.cfg)
|
414 |
-
else self.cfg
|
415 |
-
),
|
416 |
-
"model": self.model.state_dict(),
|
417 |
-
"criterion": (
|
418 |
-
self.criterion.state_dict()
|
419 |
-
if utils.has_parameters(self.criterion)
|
420 |
-
else None
|
421 |
-
),
|
422 |
-
"optimizer_history": (self._optim_history or [])
|
423 |
-
+ [
|
424 |
-
{
|
425 |
-
"criterion_name": self.get_criterion().__class__.__name__,
|
426 |
-
"optimizer_name": self.optimizer.__class__.__name__,
|
427 |
-
"lr_scheduler_state": self.lr_scheduler.state_dict(),
|
428 |
-
"num_updates": self.get_num_updates(),
|
429 |
-
}
|
430 |
-
],
|
431 |
-
"task_state": self.task.state_dict() if self.task is not None else {},
|
432 |
-
"extra_state": {
|
433 |
-
"metrics": metrics.state_dict(),
|
434 |
-
"previous_training_time": self.cumulative_training_time(),
|
435 |
-
},
|
436 |
-
}
|
437 |
-
if self.cfg.ema.store_ema:
|
438 |
-
# Save EMA model state as extra state
|
439 |
-
state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
|
440 |
-
if self.cfg.ema.ema_fp32:
|
441 |
-
# Save EMA params in fp32
|
442 |
-
state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
|
443 |
-
if not self.cfg.checkpoint.no_save_optimizer_state:
|
444 |
-
if self._gathered_optim_state is not None:
|
445 |
-
state_dict["last_optimizer_state"] = self._gathered_optim_state
|
446 |
-
self._gathered_optim_state = None
|
447 |
-
else:
|
448 |
-
state_dict["last_optimizer_state"] = self.optimizer.state_dict()
|
449 |
-
if self.is_fsdp:
|
450 |
-
# save meta data for recombining checkpoint upon loading
|
451 |
-
state_dict["fsdp_metadata"] = self.model.local_metadata_dict()
|
452 |
-
return state_dict
|
453 |
-
|
454 |
-
def save_checkpoint(self, filename, extra_state, save_on_cuda=False):
|
455 |
-
"""Save all training state in a checkpoint file."""
|
456 |
-
logger.info(f"Saving checkpoint to {filename}")
|
457 |
-
# call state_dict on all ranks in case it needs internal communication
|
458 |
-
if not save_on_cuda:
|
459 |
-
state_dict = utils.move_to_cpu(self.state_dict())
|
460 |
-
else:
|
461 |
-
print("Save on cuda")
|
462 |
-
state_dict = self.state_dict()
|
463 |
-
state_dict["extra_state"].update(extra_state)
|
464 |
-
if self.should_save_checkpoint_on_current_rank:
|
465 |
-
checkpoint_utils.torch_persistent_save(
|
466 |
-
state_dict,
|
467 |
-
filename,
|
468 |
-
async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
|
469 |
-
)
|
470 |
-
logger.info(f"Finished saving checkpoint to {filename}")
|
471 |
-
|
472 |
-
def load_checkpoint(
|
473 |
-
self,
|
474 |
-
filename,
|
475 |
-
reset_optimizer=False,
|
476 |
-
reset_lr_scheduler=False,
|
477 |
-
optimizer_overrides=None,
|
478 |
-
reset_meters=False,
|
479 |
-
strict=True,
|
480 |
-
load_on_cuda=False,
|
481 |
-
):
|
482 |
-
"""
|
483 |
-
Load all training state from a checkpoint file.
|
484 |
-
rank = 0 will load the checkpoint, and then broadcast it to all
|
485 |
-
other ranks.
|
486 |
-
"""
|
487 |
-
extra_state, self._optim_history, last_optim_state = None, [], None
|
488 |
-
|
489 |
-
logger.info(f"Preparing to load checkpoint {filename}")
|
490 |
-
is_distributed = self.data_parallel_world_size > 1
|
491 |
-
bexists = PathManager.isfile(filename)
|
492 |
-
if bexists:
|
493 |
-
load_on_all_ranks = (
|
494 |
-
self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
|
495 |
-
# TPUs don't support broadcast yet, so load checkpoints
|
496 |
-
# on every worker for now
|
497 |
-
or self.tpu
|
498 |
-
# FSDP requires loading checkpoint shards on all ranks
|
499 |
-
or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state)
|
500 |
-
or getattr(self.cfg.model, "base_layers", 0) > 0
|
501 |
-
)
|
502 |
-
|
503 |
-
if load_on_all_ranks or self.data_parallel_rank == 0:
|
504 |
-
state = checkpoint_utils.load_checkpoint_to_cpu(
|
505 |
-
filename, load_on_all_ranks=load_on_all_ranks, strict=strict,load_on_cuda=load_on_cuda,
|
506 |
-
)
|
507 |
-
last_optim_state = state.get("last_optimizer_state", None)
|
508 |
-
|
509 |
-
# If doing zero_sharding, do not broadcast global optimizer
|
510 |
-
# state. Later we will broadcast sharded states to each rank
|
511 |
-
# to avoid memory from exploding.
|
512 |
-
if (
|
513 |
-
not load_on_all_ranks
|
514 |
-
and self.cfg.distributed_training.zero_sharding == "os"
|
515 |
-
and "last_optimizer_state" in state
|
516 |
-
and is_distributed
|
517 |
-
):
|
518 |
-
state["last_optimizer_state"] = "SHARDED"
|
519 |
-
else:
|
520 |
-
last_optim_state = None
|
521 |
-
state = None
|
522 |
-
|
523 |
-
if is_distributed and not load_on_all_ranks: # .contiguous()
|
524 |
-
state = distributed_utils.broadcast_object(
|
525 |
-
state,
|
526 |
-
src_rank=0,
|
527 |
-
group=self.data_parallel_process_group,
|
528 |
-
dist_device=self.device,
|
529 |
-
)
|
530 |
-
if self.data_parallel_rank > 0:
|
531 |
-
last_optim_state = state.get("last_optimizer_state", None)
|
532 |
-
|
533 |
-
# load model parameters
|
534 |
-
try:
|
535 |
-
if self.cfg.checkpoint.use_ema_weights_to_init_param and "extra_state" in state and "ema" in state["extra_state"]:
|
536 |
-
logger.info("use_ema_weights_to_init_param = True, will use EMA weights in the ckpt to init the model param...")
|
537 |
-
ema_state_dict = state["extra_state"]["ema_fp32_params"] if "ema_fp32_params" in state["extra_state"] else state["extra_state"]["ema"]
|
538 |
-
msg = self.model.load_state_dict(
|
539 |
-
ema_state_dict, strict=strict, model_cfg=self.cfg.model
|
540 |
-
)
|
541 |
-
else:
|
542 |
-
msg = self.model.load_state_dict(
|
543 |
-
state["model"], strict=strict, model_cfg=self.cfg.model
|
544 |
-
)
|
545 |
-
logger.info(msg)
|
546 |
-
|
547 |
-
# save memory for later steps
|
548 |
-
if not (self.cfg.ema.store_ema and (self.cfg.checkpoint.use_latest_weights_to_init_ema or not ("extra_state" in state and "ema" in state["extra_state"]))):
|
549 |
-
del state["model"]
|
550 |
-
if utils.has_parameters(self.get_criterion()) and 'criterion' in state:
|
551 |
-
self.get_criterion().load_state_dict(
|
552 |
-
state["criterion"], strict=strict
|
553 |
-
)
|
554 |
-
del state["criterion"]
|
555 |
-
|
556 |
-
except Exception:
|
557 |
-
raise Exception(
|
558 |
-
"Cannot load model parameters from checkpoint {}; "
|
559 |
-
"please ensure that the architectures match.".format(filename)
|
560 |
-
)
|
561 |
-
extra_state = state.get("extra_state", None)
|
562 |
-
self._optim_history = state.get("optimizer_history", None)
|
563 |
-
|
564 |
-
if last_optim_state is not None and not reset_optimizer:
|
565 |
-
# rebuild optimizer after loading model, since params may have changed
|
566 |
-
self._build_optimizer()
|
567 |
-
|
568 |
-
# only reload optimizer and lr_scheduler if they match
|
569 |
-
last_optim = self._optim_history[-1]
|
570 |
-
assert (
|
571 |
-
last_optim["criterion_name"] == self.get_criterion().__class__.__name__
|
572 |
-
), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
|
573 |
-
assert (
|
574 |
-
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
|
575 |
-
), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
|
576 |
-
|
577 |
-
if not reset_lr_scheduler:
|
578 |
-
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
|
579 |
-
|
580 |
-
if self.is_fsdp and not self.model.use_sharded_state:
|
581 |
-
# if use_sharded_state, the last_optim_state is already sharded, skip this
|
582 |
-
last_optim_state = self.model.get_shard_from_optim_state_dict(
|
583 |
-
last_optim_state
|
584 |
-
)
|
585 |
-
elif not load_on_all_ranks and is_distributed:
|
586 |
-
last_optim_state = self.optimizer.broadcast_global_state_dict(
|
587 |
-
last_optim_state
|
588 |
-
)
|
589 |
-
|
590 |
-
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
|
591 |
-
|
592 |
-
self.set_num_updates(last_optim["num_updates"])
|
593 |
-
|
594 |
-
if extra_state is not None:
|
595 |
-
itr_state = extra_state["train_iterator"]
|
596 |
-
epoch = itr_state["epoch"]
|
597 |
-
|
598 |
-
if "previous_training_time" in extra_state:
|
599 |
-
self._previous_training_time = extra_state["previous_training_time"]
|
600 |
-
self._start_time = time.time()
|
601 |
-
|
602 |
-
self.lr_step(epoch)
|
603 |
-
|
604 |
-
if (
|
605 |
-
itr_state.get("version", 1) >= 2
|
606 |
-
and itr_state["iterations_in_epoch"] == 0
|
607 |
-
):
|
608 |
-
# reset meters at start of epoch
|
609 |
-
reset_meters = True
|
610 |
-
|
611 |
-
if "metrics" in extra_state and not reset_meters:
|
612 |
-
metrics.load_state_dict(extra_state["metrics"])
|
613 |
-
|
614 |
-
# reset TimeMeters, since their start times don't make sense anymore
|
615 |
-
for meter in metrics.get_meters("default"):
|
616 |
-
if isinstance(meter, meters.TimeMeter):
|
617 |
-
meter.reset()
|
618 |
-
|
619 |
-
if self.cfg.ema.store_ema:
|
620 |
-
if self.cfg.checkpoint.use_latest_weights_to_init_ema or "ema" not in extra_state:
|
621 |
-
if "ema" not in extra_state:
|
622 |
-
logger.warn(
|
623 |
-
"EMA not found in checkpoint. But store_ema is True. "
|
624 |
-
"EMA is re-initialized from checkpoint."
|
625 |
-
)
|
626 |
-
elif self.cfg.checkpoint.use_latest_weights_to_init_ema:
|
627 |
-
logger.info(
|
628 |
-
"use_latest_weights_to_init_ema = True. EMA is re-initialized from checkpoint."
|
629 |
-
)
|
630 |
-
self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
|
631 |
-
del state["model"]
|
632 |
-
else:
|
633 |
-
logger.info(
|
634 |
-
"Loading EMA from checkpoint"
|
635 |
-
)
|
636 |
-
self.ema.restore(extra_state["ema"], build_fp32_params=False)
|
637 |
-
|
638 |
-
if self.cfg.ema.ema_fp32:
|
639 |
-
if "ema_fp32_params" in extra_state:
|
640 |
-
logger.info(
|
641 |
-
"Loading EMA fp32 params from checkpoint"
|
642 |
-
)
|
643 |
-
self.ema.build_fp32_params(extra_state["ema_fp32_params"])
|
644 |
-
else:
|
645 |
-
logger.info(
|
646 |
-
"Building EMA fp32 params from EMA model in checkpoint"
|
647 |
-
)
|
648 |
-
self.ema.build_fp32_params()
|
649 |
-
|
650 |
-
logger.info(
|
651 |
-
"Loaded checkpoint {} (epoch {} @ {} updates)".format(
|
652 |
-
filename, epoch, self.get_num_updates()
|
653 |
-
)
|
654 |
-
)
|
655 |
-
|
656 |
-
else:
|
657 |
-
logger.info("No existing checkpoint found {}".format(filename))
|
658 |
-
|
659 |
-
# print("delete state ...")
|
660 |
-
# del state # dereference seems crucial
|
661 |
-
# torch.cuda.empty_cache()
|
662 |
-
|
663 |
-
return extra_state
|
664 |
-
|
665 |
-
def get_train_iterator(
|
666 |
-
self,
|
667 |
-
epoch,
|
668 |
-
combine=True,
|
669 |
-
load_dataset=True,
|
670 |
-
data_selector=None,
|
671 |
-
shard_batch_itr=True,
|
672 |
-
disable_iterator_cache=False,
|
673 |
-
):
|
674 |
-
"""Return an EpochBatchIterator over the training set for a given epoch."""
|
675 |
-
if load_dataset:
|
676 |
-
logger.info("loading train data for epoch {}".format(epoch))
|
677 |
-
self.task.load_dataset(
|
678 |
-
self.cfg.dataset.train_subset,
|
679 |
-
epoch=epoch,
|
680 |
-
combine=combine,
|
681 |
-
data_selector=data_selector,
|
682 |
-
tpu=self.tpu,
|
683 |
-
)
|
684 |
-
batch_iterator = self.task.get_batch_iterator(
|
685 |
-
dataset=self.task.dataset(self.cfg.dataset.train_subset),
|
686 |
-
max_tokens=self.cfg.dataset.max_tokens,
|
687 |
-
max_sentences=self.cfg.dataset.batch_size,
|
688 |
-
max_positions=utils.resolve_max_positions(
|
689 |
-
self.task.max_positions(),
|
690 |
-
self.model.max_positions(),
|
691 |
-
self.cfg.dataset.max_tokens,
|
692 |
-
),
|
693 |
-
ignore_invalid_inputs=True,
|
694 |
-
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
695 |
-
seed=self.cfg.common.seed,
|
696 |
-
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
|
697 |
-
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
|
698 |
-
num_workers=self.cfg.dataset.num_workers,
|
699 |
-
epoch=epoch,
|
700 |
-
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
701 |
-
disable_iterator_cache=disable_iterator_cache,
|
702 |
-
)
|
703 |
-
self.reset_dummy_batch(batch_iterator.first_batch)
|
704 |
-
batch_iterator.dataset.dataset._seek()
|
705 |
-
return batch_iterator
|
706 |
-
|
707 |
-
def get_valid_iterator(
|
708 |
-
self,
|
709 |
-
subset,
|
710 |
-
disable_iterator_cache=False,
|
711 |
-
):
|
712 |
-
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
|
713 |
-
self.task.dataset(subset).dataset._seek()
|
714 |
-
batch_iterator = self.task.get_batch_iterator(
|
715 |
-
dataset=self.task.dataset(subset),
|
716 |
-
max_tokens=self.cfg.dataset.max_tokens_valid,
|
717 |
-
max_sentences=self.cfg.dataset.batch_size_valid,
|
718 |
-
max_positions=utils.resolve_max_positions(
|
719 |
-
self.task.max_positions(),
|
720 |
-
self.model.max_positions(),
|
721 |
-
),
|
722 |
-
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
|
723 |
-
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
724 |
-
seed=self.cfg.common.seed,
|
725 |
-
num_shards=self.data_parallel_world_size,
|
726 |
-
shard_id=self.data_parallel_rank,
|
727 |
-
num_workers=self.cfg.dataset.num_workers,
|
728 |
-
# always pass a fixed "epoch" to keep validation data consistent
|
729 |
-
# across training epochs
|
730 |
-
epoch=1,
|
731 |
-
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
732 |
-
disable_iterator_cache=disable_iterator_cache,
|
733 |
-
)
|
734 |
-
self.reset_dummy_batch(batch_iterator.first_batch)
|
735 |
-
batch_iterator.dataset.dataset._seek()
|
736 |
-
return batch_iterator
|
737 |
-
|
738 |
-
def begin_epoch(self, epoch):
|
739 |
-
"""Called at the beginning of each epoch."""
|
740 |
-
logger.info("begin training epoch {}".format(epoch))
|
741 |
-
|
742 |
-
self.lr_step_begin_epoch(epoch)
|
743 |
-
|
744 |
-
if self.quantizer is not None:
|
745 |
-
self.quantizer.begin_epoch(epoch)
|
746 |
-
|
747 |
-
# task specific setup per epoch
|
748 |
-
self.task.begin_epoch(epoch, self.get_model())
|
749 |
-
|
750 |
-
if self.tpu:
|
751 |
-
import torch_xla.core.xla_model as xm
|
752 |
-
|
753 |
-
xm.rendezvous("begin_epoch") # wait for all workers
|
754 |
-
xm.mark_step()
|
755 |
-
|
756 |
-
def begin_valid_epoch(self, epoch):
|
757 |
-
"""Called at the beginning of each validation epoch."""
|
758 |
-
|
759 |
-
# task specific setup per validation epoch
|
760 |
-
self.task.begin_valid_epoch(epoch, self.get_model())
|
761 |
-
|
762 |
-
def reset_dummy_batch(self, batch):
|
763 |
-
self._dummy_batch = batch
|
764 |
-
|
765 |
-
@metrics.aggregate("train")
|
766 |
-
def train_step(self, samples, raise_oom=False):
|
767 |
-
"""Do forward, backward and parameter update."""
|
768 |
-
self._set_seed()
|
769 |
-
self.model.train()
|
770 |
-
self.criterion.train()
|
771 |
-
self.zero_grad()
|
772 |
-
|
773 |
-
metrics.log_start_time("train_wall", priority=800, round=0)
|
774 |
-
|
775 |
-
# If EMA is enabled through store_ema=True
|
776 |
-
# and task.uses_ema is True, pass the EMA model as a keyword
|
777 |
-
# argument to the task.
|
778 |
-
extra_kwargs = {}
|
779 |
-
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
|
780 |
-
extra_kwargs["ema_model"] = self.ema.get_model()
|
781 |
-
|
782 |
-
# forward and backward pass
|
783 |
-
logging_outputs, sample_size, ooms = [], 0, 0
|
784 |
-
for i, sample in enumerate(samples): # delayed update loop
|
785 |
-
sample, is_dummy_batch = self._prepare_sample(sample)
|
786 |
-
|
787 |
-
def maybe_no_sync():
|
788 |
-
"""
|
789 |
-
Whenever *samples* contains more than one mini-batch, we
|
790 |
-
want to accumulate gradients locally and only call
|
791 |
-
all-reduce in the last backwards pass.
|
792 |
-
"""
|
793 |
-
if (
|
794 |
-
self.data_parallel_world_size > 1
|
795 |
-
and hasattr(self.model, "no_sync")
|
796 |
-
and i < len(samples) - 1
|
797 |
-
# The no_sync context manager results in increased memory
|
798 |
-
# usage with FSDP, since full-size gradients will be
|
799 |
-
# accumulated on each GPU. It's typically a better tradeoff
|
800 |
-
# to do the extra communication with FSDP.
|
801 |
-
and not self.is_fsdp
|
802 |
-
):
|
803 |
-
return self.model.no_sync()
|
804 |
-
else:
|
805 |
-
return contextlib.ExitStack() # dummy contextmanager
|
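# (Illustrative sketch, not part of the original file.) The accumulate-then-sync
# pattern used here, in isolation: skip gradient all-reduce for every micro-batch
# except the last, assuming `model` is a torch DistributedDataParallel instance
# and `compute_loss` is a hypothetical loss function.
#
#   import contextlib
#   for i, batch in enumerate(micro_batches):
#       ctx = model.no_sync() if i < len(micro_batches) - 1 else contextlib.ExitStack()
#       with ctx:
#           compute_loss(batch).backward()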
806 |
-
|
807 |
-
try:
|
808 |
-
with maybe_no_sync():
|
809 |
-
# forward and backward
|
810 |
-
loss, sample_size_i, logging_output = self.task.train_step(
|
811 |
-
sample=sample,
|
812 |
-
model=self.model,
|
813 |
-
criterion=self.criterion,
|
814 |
-
optimizer=self.optimizer,
|
815 |
-
update_num=self.get_num_updates(),
|
816 |
-
ignore_grad=is_dummy_batch,
|
817 |
-
**extra_kwargs,
|
818 |
-
)
|
819 |
-
del loss
|
820 |
-
|
821 |
-
logging_outputs.append(logging_output)
|
822 |
-
sample_size += sample_size_i
|
823 |
-
|
824 |
-
# emptying the CUDA cache after the first step can
|
825 |
-
# reduce the chance of OOM
|
826 |
-
if self.cuda and self.get_num_updates() == 0:
|
827 |
-
torch.cuda.empty_cache()
|
828 |
-
except RuntimeError as e:
|
829 |
-
if "out of memory" in str(e):
|
830 |
-
self._log_oom(e)
|
831 |
-
if raise_oom:
|
832 |
-
raise e
|
833 |
-
logger.warning(
|
834 |
-
"attempting to recover from OOM in forward/backward pass"
|
835 |
-
)
|
836 |
-
ooms += 1
|
837 |
-
self.zero_grad()
|
838 |
-
if self.cuda:
|
839 |
-
torch.cuda.empty_cache()
|
840 |
-
if self.cfg.distributed_training.distributed_world_size == 1:
|
841 |
-
return None
|
842 |
-
else:
|
843 |
-
raise e
|
844 |
-
|
845 |
-
if self.tpu and i < len(samples) - 1:
|
846 |
-
# tpu-comment: every XLA operation before marking step is
|
847 |
-
# appended to the IR graph, and processing too many batches
|
848 |
-
# before marking step can lead to OOM errors.
|
849 |
-
# To handle gradient accumulation use case, we explicitly
|
850 |
-
# mark step here for every forward pass without a backward pass
|
851 |
-
self._xla_markstep_and_send_to_cpu()
|
852 |
-
|
853 |
-
if is_dummy_batch:
|
854 |
-
if torch.is_tensor(sample_size):
|
855 |
-
sample_size.zero_()
|
856 |
-
else:
|
857 |
-
sample_size *= 0.0
|
858 |
-
|
859 |
-
if torch.is_tensor(sample_size):
|
860 |
-
sample_size = sample_size.float()
|
861 |
-
else:
|
862 |
-
sample_size = float(sample_size)
|
863 |
-
|
864 |
-
# gather logging outputs from all replicas
|
865 |
-
if self._sync_stats():
|
866 |
-
train_time = self._local_cumulative_training_time()
|
867 |
-
logging_outputs, (
|
868 |
-
sample_size,
|
869 |
-
ooms,
|
870 |
-
total_train_time,
|
871 |
-
) = self._aggregate_logging_outputs(
|
872 |
-
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
|
873 |
-
)
|
874 |
-
self._cumulative_training_time = (
|
875 |
-
total_train_time / self.data_parallel_world_size
|
876 |
-
)
|
877 |
-
|
878 |
-
overflow = False
|
879 |
-
try:
|
880 |
-
with torch.autograd.profiler.record_function("reduce-grads"):
|
881 |
-
# reduce gradients across workers
|
882 |
-
self.optimizer.all_reduce_grads(self.model)
|
883 |
-
if utils.has_parameters(self.criterion):
|
884 |
-
self.optimizer.all_reduce_grads(self.criterion)
|
885 |
-
|
886 |
-
with torch.autograd.profiler.record_function("multiply-grads"):
|
887 |
-
# multiply gradients by (data_parallel_size / sample_size) since
|
888 |
-
# DDP normalizes by the number of data parallel workers for
|
889 |
-
# improved fp16 precision.
|
890 |
-
# Thus we get (sum_of_gradients / sample_size) at the end.
|
891 |
-
# In case of fp16, this step also undoes loss scaling.
|
892 |
-
# (Debugging note: Some optimizers perform this scaling on the
|
893 |
-
# fly, so inspecting model.parameters() or optimizer.params may
|
894 |
-
# still show the original, unscaled gradients.)
|
895 |
-
numer = (
|
896 |
-
self.data_parallel_world_size
|
897 |
-
if not self.cfg.optimization.use_bmuf or self._sync_stats()
|
898 |
-
else 1
|
899 |
-
)
|
900 |
-
self.optimizer.multiply_grads(numer / (sample_size or 1.0))
|
901 |
-
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a
|
902 |
-
# way that avoids CPU/device transfers in case sample_size is a GPU or
|
903 |
-
# TPU object. The assumption is that the gradient itself is also 0.
|
904 |
-
|
905 |
-
with torch.autograd.profiler.record_function("clip-grads"):
|
906 |
-
# clip grads
|
907 |
-
grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
|
908 |
-
|
909 |
-
# check that grad norms are consistent across workers
|
910 |
-
# on tpu check tensor is slow
|
911 |
-
if not self.tpu:
|
912 |
-
if (
|
913 |
-
not self.cfg.optimization.use_bmuf
|
914 |
-
and self.cfg.distributed_training.ddp_backend != "slow_mo"
|
915 |
-
):
|
916 |
-
self._check_grad_norms(grad_norm)
|
917 |
-
if not torch.isfinite(grad_norm).all():
|
918 |
-
# in case of AMP, if gradients are Nan/Inf then
|
919 |
-
# optimizer step is still required
|
920 |
-
if self.cfg.common.amp:
|
921 |
-
overflow = True
|
922 |
-
else:
|
923 |
-
# check local gradnorm single GPU case, trigger NanDetector
|
924 |
-
raise FloatingPointError("gradients are Nan/Inf")
|
925 |
-
|
926 |
-
with torch.autograd.profiler.record_function("optimizer"):
|
927 |
-
# take an optimization step
|
928 |
-
self.task.optimizer_step(
|
929 |
-
self.optimizer, model=self.model, update_num=self.get_num_updates()
|
930 |
-
)
|
931 |
-
if self.cfg.common.amp and overflow:
|
932 |
-
if self._amp_retries == self.cfg.common.amp_batch_retries:
|
933 |
-
logger.info("AMP: skipping this batch.")
|
934 |
-
self._amp_retries = 0
|
935 |
-
else:
|
936 |
-
self._amp_retries += 1
|
937 |
-
return self.train_step(samples, raise_oom) # recursion to feed in same batch
|
938 |
-
|
939 |
-
except FloatingPointError:
|
940 |
-
# re-run the forward and backward pass with hooks attached to print
|
941 |
-
# out where it fails
|
942 |
-
self.zero_grad()
|
943 |
-
with NanDetector(self.get_model()):
|
944 |
-
for _, sample in enumerate(samples):
|
945 |
-
sample, _ = self._prepare_sample(sample)
|
946 |
-
self.task.train_step(
|
947 |
-
sample,
|
948 |
-
self.model,
|
949 |
-
self.criterion,
|
950 |
-
self.optimizer,
|
951 |
-
self.get_num_updates(),
|
952 |
-
ignore_grad=False,
|
953 |
-
**extra_kwargs,
|
954 |
-
)
|
955 |
-
raise
|
956 |
-
except OverflowError as e:
|
957 |
-
overflow = True
|
958 |
-
logger.info(
|
959 |
-
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
|
960 |
-
)
|
961 |
-
grad_norm = torch.tensor(0.0).cuda()
|
962 |
-
self.zero_grad()
|
963 |
-
except RuntimeError as e:
|
964 |
-
if "out of memory" in str(e):
|
965 |
-
self._log_oom(e)
|
966 |
-
logger.error("OOM during optimization, irrecoverable")
|
967 |
-
raise e
|
968 |
-
|
969 |
-
# Some distributed wrappers (e.g., SlowMo) need access to the optimizer
|
970 |
-
# after the step
|
971 |
-
if hasattr(self.model, "perform_additional_optimizer_actions"):
|
972 |
-
if hasattr(self.optimizer, "fp32_params"):
|
973 |
-
self.model.perform_additional_optimizer_actions(
|
974 |
-
self.optimizer.optimizer, self.optimizer.fp32_params
|
975 |
-
)
|
976 |
-
else:
|
977 |
-
self.model.perform_additional_optimizer_actions(
|
978 |
-
self.optimizer.optimizer
|
979 |
-
)
|
980 |
-
|
981 |
-
logging_output = None
|
982 |
-
if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
|
983 |
-
self.set_num_updates(self.get_num_updates() + 1)
|
984 |
-
|
985 |
-
if self.cfg.ema.store_ema:
|
986 |
-
# Step EMA forward with new model.
|
987 |
-
self.ema.step(
|
988 |
-
self.get_model(),
|
989 |
-
self.get_num_updates(),
|
990 |
-
)
|
991 |
-
metrics.log_scalar(
|
992 |
-
"ema_decay",
|
993 |
-
self.ema.get_decay(),
|
994 |
-
priority=10000,
|
995 |
-
round=5,
|
996 |
-
weight=0,
|
997 |
-
)
|
998 |
-
|
999 |
-
if self.tpu:
|
1000 |
-
import torch_xla.core.xla_model as xm
|
1001 |
-
|
1002 |
-
# mark step on TPUs
|
1003 |
-
self._xla_markstep_and_send_to_cpu()
|
1004 |
-
|
1005 |
-
# only log stats every log_interval steps
|
1006 |
-
# this causes wps to be misreported when log_interval > 1
|
1007 |
-
logging_output = {}
|
1008 |
-
if self.get_num_updates() % self.cfg.common.log_interval == 0:
|
1009 |
-
# log memory usage
|
1010 |
-
mem_info = xm.get_memory_info(self.device)
|
1011 |
-
gb_free = mem_info["kb_free"] / 1024 / 1024
|
1012 |
-
gb_total = mem_info["kb_total"] / 1024 / 1024
|
1013 |
-
metrics.log_scalar(
|
1014 |
-
"gb_free", gb_free, priority=1500, round=1, weight=0
|
1015 |
-
)
|
1016 |
-
metrics.log_scalar(
|
1017 |
-
"gb_total", gb_total, priority=1600, round=1, weight=0
|
1018 |
-
)
|
1019 |
-
logging_outputs = self._xla_markstep_and_send_to_cpu(
|
1020 |
-
logging_outputs
|
1021 |
-
)
|
1022 |
-
logging_output = self._reduce_and_log_stats(
|
1023 |
-
logging_outputs, sample_size, grad_norm
|
1024 |
-
)
|
1025 |
-
|
1026 |
-
# log whenever there's an XLA compilation, since these
|
1027 |
-
# slow down training and may indicate opportunities for
|
1028 |
-
# optimization
|
1029 |
-
self._check_xla_compilation()
|
1030 |
-
else:
|
1031 |
-
if self.cuda and self.cuda_env is not None:
|
1032 |
-
# log minimum free memory over the iteration
|
1033 |
-
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
|
1034 |
-
torch.cuda.reset_peak_memory_stats()
|
1035 |
-
gb_free = self.cuda_env.total_memory_in_GB - gb_used
|
1036 |
-
metrics.log_scalar(
|
1037 |
-
"gb_free", gb_free, priority=1500, round=1, weight=0
|
1038 |
-
)
|
1039 |
-
|
1040 |
-
# log stats
|
1041 |
-
logging_output = self._reduce_and_log_stats(
|
1042 |
-
logging_outputs, sample_size, grad_norm
|
1043 |
-
)
|
1044 |
-
|
1045 |
-
# clear CUDA cache to reduce memory fragmentation
|
1046 |
-
if (
|
1047 |
-
self.cuda
|
1048 |
-
and self.cfg.common.empty_cache_freq > 0
|
1049 |
-
and (
|
1050 |
-
(self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
|
1051 |
-
% self.cfg.common.empty_cache_freq
|
1052 |
-
)
|
1053 |
-
== 0
|
1054 |
-
):
|
1055 |
-
torch.cuda.empty_cache()
|
1056 |
-
|
1057 |
-
if self.cfg.common.fp16 or self.cfg.common.amp:
|
1058 |
-
metrics.log_scalar(
|
1059 |
-
"loss_scale",
|
1060 |
-
(
|
1061 |
-
self.optimizer.scaler.loss_scale
|
1062 |
-
if self.cfg.common.fp16
|
1063 |
-
else self.optimizer.scaler.get_scale()
|
1064 |
-
),
|
1065 |
-
priority=700,
|
1066 |
-
round=4,
|
1067 |
-
weight=0,
|
1068 |
-
)
|
1069 |
-
|
1070 |
-
metrics.log_stop_time("train_wall")
|
1071 |
-
return logging_output
|
1072 |
-
|
1073 |
-
    @metrics.aggregate("valid")
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        if self.tpu:
            import torch_xla.core.xla_model as xm

            xm.rendezvous("valid_step")  # wait for all workers

        # If EMA is enabled through store_ema=True
        # and task.uses_ema is True, pass the EMA model as a keyword
        # argument to the task.
        extra_kwargs = {}
        if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
            extra_kwargs["ema_model"] = self.ema.get_model()

        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample, is_dummy_batch = self._prepare_sample(sample)

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion, **extra_kwargs
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        logger.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                if torch.is_tensor(sample_size):
                    sample_size.zero_()
                else:
                    sample_size *= 0.0

        # gather logging outputs from all replicas
        if self.data_parallel_world_size > 1:
            logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
            )

        # log validation stats
        if self.tpu:
            logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
        logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

        return logging_output

    def zero_grad(self):
        self.optimizer.zero_grad()

    def lr_step_begin_epoch(self, epoch):
        """Adjust the learning rate at the beginning of the epoch."""
        self.lr_scheduler.step_begin_epoch(epoch)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_reinit(self, total_updates, num_updates):
        self.lr_scheduler.reinit(total_updates, num_updates)

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate at the end of the epoch."""
        self.lr_scheduler.step(epoch, val_loss)
        # prefer updating the LR based on the number of steps
        return self.lr_step_update()

    def lr_step_update(self):
        """Update the learning rate after each update."""
        new_lr = self.lr_scheduler.step_update(self.get_num_updates())
        if isinstance(new_lr, dict):
            for k, v in new_lr.items():
                metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
            new_lr = new_lr.get("default", next(iter(new_lr.values())))
        else:
            metrics.log_scalar("lr", new_lr, weight=0, priority=300)
        return new_lr

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_model(self):
        """Get the (non-wrapped) model instance."""
        return self._model

    def get_criterion(self):
        """Get the (non-wrapped) criterion instance."""
        return self._criterion

    def get_meter(self, name):
        """[deprecated] Get a specific meter by name."""
        from fairseq import meters

        if "get_meter" not in self._warn_once:
            self._warn_once.add("get_meter")
            utils.deprecation_warning(
                "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
            )

        train_meters = metrics.get_meters("train")
        if train_meters is None:
            train_meters = {}

        if name == "train_loss" and "loss" in train_meters:
            return train_meters["loss"]
        elif name == "train_nll_loss":
            # support for legacy train.py, which assumed this meter is
            # always initialized
            m = train_meters.get("nll_loss", None)
            return m or meters.AverageMeter()
        elif name == "wall":
            # support for legacy train.py, which assumed this meter is
            # always initialized
            m = metrics.get_meter("default", "wall")
            return m or meters.TimeMeter()
        elif name == "wps":
            m = metrics.get_meter("train", "wps")
            return m or meters.TimeMeter()
        elif name in {"valid_loss", "valid_nll_loss"}:
            # support for legacy train.py, which assumed these meters
            # are always initialized
            k = name[len("valid_") :]
            m = metrics.get_meter("valid", k)
            return m or meters.AverageMeter()
        elif name == "oom":
            return meters.AverageMeter()
        elif name in train_meters:
            return train_meters[name]
        return None

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates
        self.lr_step_update()
        if self.quantizer:
            self.quantizer.step_update(self._num_updates)
        metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)

    def clip_grad_norm(self, clip_norm):
        def agg_norm_fn(total_norm):
            total_norm = total_norm.cuda().float() ** 2
            total_norm = distributed_utils.all_reduce(
                total_norm, group=self.data_parallel_process_group
            )
            return total_norm ** 0.5

        should_agg_norm = (
            self.is_fsdp
            and (
                self.data_parallel_process_group is not None
                or torch.distributed.is_initialized()
            )
        )
        return self.optimizer.clip_grad_norm(
            clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None
        )

    def cumulative_training_time(self):
        if self._cumulative_training_time is None:
            # single GPU
            return self._local_cumulative_training_time()
        else:
            return self._cumulative_training_time

    def _local_cumulative_training_time(self):
        """Aggregate training time in seconds."""
        return time.time() - self._start_time + self._previous_training_time

    def _fp_convert_sample(self, sample):
        def apply_half(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        def apply_bfloat16(t):
            if t.dtype is torch.float32:
                return t.to(dtype=torch.bfloat16)
            return t

        if self.cfg.common.fp16:
            sample = utils.apply_to_sample(apply_half, sample)

        if self.cfg.common.bf16:
            sample = utils.apply_to_sample(apply_bfloat16, sample)

        return sample

    def _prepare_sample(self, sample, is_dummy=False):
        if sample == "DUMMY":
            raise Exception(
                "Trying to use an uninitialized 'dummy' batch. This usually indicates "
                "that the total number of batches is smaller than the number of "
                "participating GPUs. Try reducing the batch size or using fewer GPUs."
            )

        if sample is None or len(sample) == 0:
            assert (
                self._dummy_batch is not None and len(self._dummy_batch) > 0
            ), "Invalid dummy batch: {}".format(self._dummy_batch)
            sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
            return sample, True

        # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth
        # it makes sense to do the format conversion on the CPU and then transfer
        # a smaller buffer to the device. This also saves GPU memory capacity.

        if self.cfg.common.on_cpu_convert_precision:
            sample = self._fp_convert_sample(sample)

        if self.cuda:
            if self.pipeline_model_parallel:
                if 'target' in sample:
                    sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
            else:
                sample = utils.move_to_cuda(sample)
        elif self.tpu and is_dummy:
            # the dummy batch may not be on the appropriate device
            sample = utils.move_to_cuda(sample, device=self.device)

        if not self.cfg.common.on_cpu_convert_precision:
            sample = self._fp_convert_sample(sample)

        if self._dummy_batch == "DUMMY":
            self._dummy_batch = sample

        return sample, False

    def _set_seed(self):
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.cfg.common.seed + self.get_num_updates()
        utils.set_torch_seed(seed)

    def _sync_stats(self):
        # Return True if it's using multiple GPUs and DDP or multiple GPUs with
        # BMUF and it's a bmuf sync with warmup iterations completed before.
        if self.data_parallel_world_size == 1:
            return False
        elif self.cfg.optimization.use_bmuf:
            return (
                self.get_num_updates() + 1
            ) % self.cfg.bmuf.global_sync_iter == 0 and (
                self.get_num_updates() + 1
            ) > self.cfg.bmuf.warmup_iterations
        else:
            return True

    def _log_oom(self, exc):
        msg = "OOM: Ran out of memory with exception: {}".format(exc)
        logger.warning(msg)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                logger.warning(torch.cuda.memory_summary(device=device_idx))
        sys.stderr.flush()

    def _aggregate_logging_outputs(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
            return self._fast_stat_sync_sum(
                logging_outputs, *extra_stats_to_sum, ignore=ignore
            )
        else:
            return self._all_gather_list_sync(
                logging_outputs, *extra_stats_to_sum, ignore=ignore
            )

    def _all_gather_list_sync(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        """
        Sync logging outputs across workers. all_gather_list_sync is
        suitable when logging outputs are complex types.
        """
        if self.tpu:
            raise NotImplementedError
        if ignore:
            logging_outputs = []
        results = list(
            zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs] + list(extra_stats_to_sum),
                    max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
                    group=self.data_parallel_process_group,
                )
            )
        )
        logging_outputs, extra_stats_to_sum = results[0], results[1:]
        logging_outputs = list(chain.from_iterable(logging_outputs))
        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
        return logging_outputs, extra_stats_to_sum

    def _fast_stat_sync_sum(
        self,
        logging_outputs: List[Dict[str, Any]],
        *extra_stats_to_sum,
        ignore=False,
    ):
        """
        Sync logging outputs across workers. fast_stat_sync_sum is
        faster than all_gather_list_sync, but is only suitable when
        logging outputs are scalars and can be summed. Note that
        *logging_outputs* cannot contain any nested dicts/lists.
        """
        data = {}
        for i, stat in enumerate(extra_stats_to_sum):
            data["extra_stats_" + str(i)] = stat
        if len(logging_outputs) > 0:
            log_keys = list(logging_outputs[0].keys())
            for k in log_keys:
                if not ignore:
                    v = sum(log[k] for log in logging_outputs if k in log)
                else:
                    v = logging_outputs[0][k]
                    v = torch.zeros_like(v) if torch.is_tensor(v) else 0
                data["logging_outputs_" + k] = v
        else:
            log_keys = None

        data = distributed_utils.all_reduce_dict(
            data, device=self.device, group=self.data_parallel_process_group
        )

        extra_stats_to_sum = [
            data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
        ]
        if log_keys is not None:
            logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
        else:
            logging_outputs = []
        return logging_outputs, extra_stats_to_sum

    def _check_grad_norms(self, grad_norm):
        """Check that grad norms are consistent across workers."""
        if self._grad_norm_buf is not None:
            self._grad_norm_buf.zero_()
            self._grad_norm_buf[self.data_parallel_rank] = grad_norm
            distributed_utils.all_reduce(
                self._grad_norm_buf, group=self.data_parallel_process_group
            )

            def is_consistent(tensor):
                max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
                return (
                    (torch.isfinite(tensor).all()
                     and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
                    or
                    (self.cfg.common.amp and not torch.isfinite(tensor).all())
                    # in case of amp non-finite grads are fine
                )

            if not is_consistent(self._grad_norm_buf):
                pretty_detail = "\n".join(
                    "rank {:3d} = {:.8f}".format(r, n)
                    for r, n in enumerate(self._grad_norm_buf.tolist())
                )
                error_detail = "grad_norm across the workers:\n{}\n".format(
                    pretty_detail
                )
                # use FloatingPointError to trigger NanDetector
                raise FloatingPointError(
                    "Fatal error: gradients are inconsistent between workers. "
                    "Try --ddp-backend=legacy_ddp. "
                    "Or are you mixing up different generation of GPUs in training?"
                    + "\n"
                    + "-" * 80
                    + "\n{}\n".format(error_detail)
                    + "-" * 80
                )

    def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
        if grad_norm is not None and (
            not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
        ):
            metrics.log_speed("ups", 1.0, priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.cfg.optimization.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.cfg.optimization.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())
                del logging_outputs

            # extra warning for criterions that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Criterion.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality"
                    )
                metrics.log_scalar("loss", -1)

        # support legacy interface
        if self.tpu:
            logging_output = {}
        else:
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
        return logging_output

    def _check_xla_compilation(self):
        import torch_xla.debug.metrics as met

        compile_stats = met.metric_data("CompileTime")
        if compile_stats is None:
            return
        num_xla_compiles = compile_stats[0]
        if num_xla_compiles > self._num_xla_compiles:
            logger.warning(
                "XLA compilation detected on device #{}; too many of these can lead "
                "to slow training, but we expect a few in the beginning".format(
                    self.cfg.distributed_training.distributed_rank
                )
            )
        self._num_xla_compiles = num_xla_compiles

    def _xla_markstep_and_send_to_cpu(self, data=None):
        import torch_xla.core.xla_model as xm

        xm.mark_step()
        if data is not None:
            from fairseq.utils import xla_device_to_cpu

            return xla_device_to_cpu(data)


def _catalog_shared_params(module, memo=None, prefix=""):
    if memo is None:
        first_call = True
        memo = {}
    else:
        first_call = False
    for name, param in module._parameters.items():
        param_prefix = prefix + ("." if prefix else "") + name
        if param not in memo:
            memo[param] = []
        memo[param].append(param_prefix)
    for name, m in module._modules.items():
        if m is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        _catalog_shared_params(m, memo, submodule_prefix)
    if first_call:
        return [x for x in memo.values() if len(x) > 1]


def _get_module_by_path(module, path):
    path = path.split(".")
    for name in path:
        module = getattr(module, name)
    return module


def _set_module_by_path(module, path, value):
    path = path.split(".")
    for name in path[:-1]:
        module = getattr(module, name)
    setattr(module, path[-1], value)
transformers.md
DELETED
@@ -1,69 +0,0 @@
# Use in huggingface transformers (Beta)

[**Colab Notebook**](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)



We now support inference of OFA with Hugging Face `transformers`. Code for training will be provided in the near future.

Model checkpoints are stored in our [huggingface models](https://huggingface.co/OFA-Sys). Specifically, 5 versions of the pretrained OFA models, namely OFA-tiny, OFA-medium, OFA-base, OFA-large, and OFA-huge, have already been uploaded. For more information about the models, please refer to the Model Card in our [README](https://github.com/OFA-Sys/OFA).
Note that each directory includes 4 files: `config.json` with the model configuration, `vocab.json` and `merge.txt` for the OFA tokenizer, and `pytorch_model.bin` with the model weights. There is no need to worry about a mismatch between Fairseq and transformers, since we have already addressed the issue.

To use OFA in transformers, first refer to this notebook ([link](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)). For more information, you can find the code in the branch https://github.com/OFA-Sys/OFA/tree/feature/add_transformers.

In the following, we go through the details of the provided notebook and illustrate how to use OFA in transformers.

First, install transformers and download the models (taking OFA-tiny as an example) as shown below.

```
git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git
pip install OFA/transformers/
git clone https://huggingface.co/OFA-Sys/OFA-tiny
```

Next, assign the path to OFA-tiny to `ckpt_dir`, and prepare an image for the testing example below. Also, make sure that pillow and torchvision are available in your environment. Check that the directory `generate` exists in your model directory `transformers/src/transformers/models/ofa`, so that you can use the sequence generator that we provide.

```
>>> import torch
>>> from PIL import Image
>>> from torchvision import transforms
>>> from transformers import OFATokenizer, OFAModel
>>> from transformers.models.ofa.generate import sequence_generator

>>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
>>> resolution = 256
>>> patch_resize_transform = transforms.Compose([
        lambda image: image.convert("RGB"),
        transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])


>>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir)

>>> txt = " what does the image describe?"
>>> inputs = tokenizer([txt], return_tensors="pt").input_ids
>>> img = Image.open(path_to_image)
>>> patch_img = patch_resize_transform(img).unsqueeze(0)


>>> # using the generator of the fairseq version
>>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
>>> generator = sequence_generator.SequenceGenerator(
        tokenizer=tokenizer,
        beam_size=5,
        max_len_b=16,
        min_len=0,
        no_repeat_ngram_size=3,
    )
>>> data = {}
>>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks': torch.tensor([True])}
>>> gen_output = generator.generate([model], data)
>>> gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))]

>>> # using the generator of the huggingface version
>>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
>>> gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)

>>> print(tokenizer.batch_decode(gen, skip_special_tokens=True))
```
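
Since OFA is prompt-based, the same calls are not limited to captioning. As a rough sketch (reusing `tokenizer`, `model`, and `patch_img` from the example above; the exact prompt wording below is an assumption, not taken from this document), a VQA-style question can be fed through the identical interface:

```
>>> # hypothetical VQA-style prompt, reusing the objects defined above
>>> txt = " what color is the car in the image?"
>>> inputs = tokenizer([txt], return_tensors="pt").input_ids
>>> gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
>>> print(tokenizer.batch_decode(gen, skip_special_tokens=True))
```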