mshukor commited on
Commit
80b588a
·
1 Parent(s): 1aa8f27
Audio_Captioning.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
Captioning.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
Image_gen.ipynb DELETED
@@ -1,301 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "399f2fcf-9241-4910-a30d-6ca19880d0ad",
6
- "metadata": {},
7
- "source": [
8
- "## Import"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 15,
14
- "id": "97e68340-0096-475e-8ed8-22f5d627e3ad",
15
- "metadata": {},
16
- "outputs": [],
17
- "source": [
18
- "import torch\n",
19
- "import numpy as np\n",
20
- "from fairseq import utils, tasks\n",
21
- "from fairseq import checkpoint_utils\n",
22
- "from utils.eval_utils import eval_step\n",
23
- "from tasks.mm_tasks import ImageGenTask\n",
24
- "from models.unival import UnIVALModel\n",
25
- "from PIL import Image\n",
26
- "from torchvision import transforms\n",
27
- "import time\n",
28
- "\n",
29
- "\n",
30
- "# turn on cuda if GPU is available\n",
31
- "use_cuda = torch.cuda.is_available()\n",
32
- "# use fp16 only when GPU is available\n",
33
- "use_fp16 = True if use_cuda else False"
34
- ]
35
- },
36
- {
37
- "cell_type": "code",
38
- "execution_count": 16,
39
- "id": "719cef65-c00c-4c9c-90b2-e660b386c3d5",
40
- "metadata": {},
41
- "outputs": [
42
- {
43
- "data": {
44
- "text/plain": [
45
- "<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>"
46
- ]
47
- },
48
- "execution_count": 16,
49
- "metadata": {},
50
- "output_type": "execute_result"
51
- }
52
- ],
53
- "source": [
54
- "# Register caption task\n",
55
- "tasks.register_task('image_gen', ImageGenTask)\n"
56
- ]
57
- },
58
- {
59
- "cell_type": "markdown",
60
- "id": "cc9c1d7b-898b-4ac4-adf3-832891d9e4be",
61
- "metadata": {},
62
- "source": [
63
- "### Load model "
64
- ]
65
- },
66
- {
67
- "cell_type": "code",
68
- "execution_count": 12,
69
- "id": "568bb6ea-eef9-4024-98e6-35e74b5ffeec",
70
- "metadata": {},
71
- "outputs": [
72
- {
73
- "name": "stdout",
74
- "output_type": "stream",
75
- "text": [
76
- "self.sample_patch_num 784\n",
77
- "self.sample_audio_patch_num None\n",
78
- "self.sample_video_patch_num None\n",
79
- "self.with_cls False\n",
80
- "Frozen image bn <class 'models.ofa.frozen_bn.FrozenBatchNorm2d'>\n",
81
- "Loading: all_resnext101\n",
82
- "use bn: <class 'torch.nn.modules.batchnorm.BatchNorm3d'>\n",
83
- "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n",
84
- "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n",
85
- "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n",
86
- "<All keys matched successfully>\n",
87
- "RAM memory % used: 10.5\n",
88
- "RAM Used (GB): 19.574349824\n",
89
- "encoder\n",
90
- "RAM memory % used: 10.5\n",
91
- "decoder\n",
92
- "RAM memory % used: 10.5\n",
93
- "ofa\n",
94
- "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n"
95
- ]
96
- }
97
- ],
98
- "source": [
99
- "# Load pretrained ckpt & config\n",
100
- "clip_model_path='/data/mshukor/data/ofa/clip/ViT-B-16.pt'\n",
101
- "vqgan_model_path='/data/mshukor/data/ofa/vqgan/last.ckpt'\n",
102
- "vqgan_config_path='/data/mshukor/data/ofa/vqgan/model.yaml'\n",
103
- "\n",
104
- "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
105
- "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt'\n",
106
- "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_base_best.pt'\n",
107
- "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_large_best.pt'\n",
108
- "\n",
109
- "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
110
- "checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'\n",
111
- "\n",
112
- "\n",
113
- "\n",
114
- "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n",
115
- "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n",
116
- "\n",
117
- "gen_images_path='results/image_gen/'\n",
118
- "\n",
119
- "overrides = {\"bpe_dir\": \"utils/BPE\",\n",
120
- " \"eval_cider\": False,\n",
121
- " \"beam\": 24,\n",
122
- " \"max_len_b\": 1024,\n",
123
- " \"max_len_a\": 0,\n",
124
- " \"min_len\": 1024,\n",
125
- " \"sampling_topk\": 256,\n",
126
- " \"constraint_range\": \"50265,58457\",\n",
127
- " \"clip_model_path\": clip_model_path,\n",
128
- " \"vqgan_model_path\": vqgan_model_path,\n",
129
- " \"vqgan_config_path\": vqgan_config_path,\n",
130
- " \"seed\": 42,\n",
131
- " \"video_model_path\": video_model_path, \n",
132
- " \"resnet_model_path\": resnet_model_path,\n",
133
- " \"gen_images_path\":gen_images_path,\n",
134
- " \"patch_image_size\": 256,\n",
135
- " \"temperature\": 1.5,\n",
136
- " }\n",
137
- "\n",
138
- "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
139
- " utils.split_paths(checkpoint_path),\n",
140
- " arg_overrides=overrides\n",
141
- ")\n",
142
- "\n",
143
- "task.cfg.sampling_times = 2\n",
144
- "# Move models to GPU\n",
145
- "for model in models:\n",
146
- " model.eval()\n",
147
- " if use_fp16:\n",
148
- " model.half()\n",
149
- " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
150
- " model.cuda()\n",
151
- " model.prepare_for_inference_(cfg)\n",
152
- "\n",
153
- "# Initialize generator\n",
154
- "generator = task.build_generator(models, cfg.generation)\n",
155
- "\n",
156
- "# Text preprocess\n",
157
- "bos_item = torch.LongTensor([task.src_dict.bos()])\n",
158
- "eos_item = torch.LongTensor([task.src_dict.eos()])\n",
159
- "pad_idx = task.src_dict.pad()"
160
- ]
161
- },
162
- {
163
- "cell_type": "markdown",
164
- "id": "5e4a45ec-bce1-495b-8033-3b574367b360",
165
- "metadata": {},
166
- "source": [
167
- "### Preprocess"
168
- ]
169
- },
170
- {
171
- "cell_type": "code",
172
- "execution_count": 13,
173
- "id": "9f2e7e32-c9a0-43b3-bf86-2419d9f7dfe0",
174
- "metadata": {},
175
- "outputs": [],
176
- "source": [
177
- "def encode_text(text, length=None, append_bos=False, append_eos=False):\n",
178
- " s = task.tgt_dict.encode_line(\n",
179
- " line=task.bpe.encode(text),\n",
180
- " add_if_not_exist=False,\n",
181
- " append_eos=False\n",
182
- " ).long()\n",
183
- " if length is not None:\n",
184
- " s = s[:length]\n",
185
- " if append_bos:\n",
186
- " s = torch.cat([bos_item, s])\n",
187
- " if append_eos:\n",
188
- " s = torch.cat([s, eos_item])\n",
189
- " return s\n",
190
- "\n",
191
- "\n",
192
- "# Construct input for image generation task\n",
193
- "def construct_sample(query: str):\n",
194
- " code_mask = torch.tensor([True])\n",
195
- " src_text = encode_text(\" what is the complete image? caption: {}\".format(query), append_bos=True,\n",
196
- " append_eos=True).unsqueeze(0)\n",
197
- " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n",
198
- " sample = {\n",
199
- " \"id\": np.array(['42']),\n",
200
- " \"net_input\": {\n",
201
- " \"src_tokens\": src_text,\n",
202
- " \"src_lengths\": src_length,\n",
203
- " \"code_masks\": code_mask\n",
204
- " }\n",
205
- " }\n",
206
- " return sample\n",
207
- "\n",
208
- "\n",
209
- "# Function to turn FP32 to FP16\n",
210
- "def apply_half(t):\n",
211
- " if t.dtype is torch.float32:\n",
212
- " return t.to(dtype=torch.half)\n",
213
- " return t\n",
214
- "\n",
215
- "\n",
216
- "# Function for image generation\n",
217
- "def image_generation(caption):\n",
218
- " sample = construct_sample(caption)\n",
219
- " sample = utils.move_to_cuda(sample) if use_cuda else sample\n",
220
- " sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n",
221
- " print('|Start|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
222
- " with torch.no_grad():\n",
223
- " result, scores = eval_step(task, generator, models, sample)\n",
224
- "\n",
225
- " # return top-4 results (ranked by clip)\n",
226
- " images = [result[i]['image'] for i in range(4)]\n",
227
- " pic_size = 256\n",
228
- " retImage = Image.new('RGB', (pic_size * 2, pic_size * 2))\n",
229
- " print('|FINISHED|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
230
- " for i in range(4):\n",
231
- " loc = ((i % 2) * pic_size, int(i / 2) * pic_size)\n",
232
- " retImage.paste(images[i], loc)\n",
233
- " return retImage"
234
- ]
235
- },
236
- {
237
- "cell_type": "markdown",
238
- "id": "44dec799-c5c2-4d22-8b08-7a7ca2cdf3c9",
239
- "metadata": {},
240
- "source": [
241
- "### Inference"
242
- ]
243
- },
244
- {
245
- "cell_type": "code",
246
- "execution_count": 14,
247
- "id": "02d5cd7a-8d63-4fa4-9da1-d4b79ec01445",
248
- "metadata": {},
249
- "outputs": [
250
- {
251
- "name": "stdout",
252
- "output_type": "stream",
253
- "text": [
254
- "|Start| 2023-06-29 12:57:39 A brown horse in the street\n",
255
- "|FINISHED| 2023-06-29 12:59:03 A brown horse in the street\n"
256
- ]
257
- }
258
- ],
259
- "source": [
260
- "query = \"A brown horse in the street\"\n",
261
- "# query = \"Cattle grazing on grass near a lake surrounded by mountain.\"\n",
262
- "# query = 'A street scene with a double-decker bus on the road.'\n",
263
- "# query = 'A path.'\n",
264
- "\n",
265
- "\n",
266
- "retImage = image_generation(query)\n"
267
- ]
268
- },
269
- {
270
- "cell_type": "code",
271
- "execution_count": null,
272
- "id": "1a8a1654-1f17-41c7-b410-c7491a96dcee",
273
- "metadata": {},
274
- "outputs": [],
275
- "source": [
276
- "retImage.save(f'{query}.png')"
277
- ]
278
- }
279
- ],
280
- "metadata": {
281
- "kernelspec": {
282
- "display_name": "ofa",
283
- "language": "python",
284
- "name": "ofa"
285
- },
286
- "language_info": {
287
- "codemirror_mode": {
288
- "name": "ipython",
289
- "version": 3
290
- },
291
- "file_extension": ".py",
292
- "mimetype": "text/x-python",
293
- "name": "python",
294
- "nbconvert_exporter": "python",
295
- "pygments_lexer": "ipython3",
296
- "version": "3.7.4"
297
- }
298
- },
299
- "nbformat": 4,
300
- "nbformat_minor": 5
301
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README_EncouragingLoss.md DELETED
@@ -1,34 +0,0 @@
1
- # Finetuning with Encouraging Loss (EL)
2
- Below we provide methods for finetuning with label smoothed encouraging loss proposed in [_Well-classified Examples are Underestimated in Classification with Deep Neural Networks_](https://arxiv.org/pdf/2110.06537.pdf) on different downstream tasks.
3
- The implementation is in [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
4
- You can set the `--criterion` to `adjust_label_smoothed_encouraging_loss` to use it. This criterion has a hyper-parameter `--log-end`.
5
- `--log-end < 1` results in a approximated and conservative version of the full encouraging loss.
6
- A high log_end will more strongly weaken the gradient vanishing, enhance the modeling of the data, and increase the growth rate of the margin, but it will also bring a larger gradient norm, which will bring challenges to the existing optimization system.
7
- We recommend higher log_end for cases with higher performance, and 0.75 or 0.5 as your first try.
8
- ## Image Captioning
9
- We provide procedures for image captioning with EL below. The preprocessing is identical to default setting.
10
-
11
- <details>
12
- <summary><b>Finetuning</b></summary>
13
- <p>
14
- We propose two scripts for stage1. </b>
15
- </p>
16
- <pre>
17
- cd run_scripts/caption
18
- nohup sh train_caption_stage1_el.sh > train_stage1_el.out & # stage 1, train with encouraging loss, expected cider 1.403
19
- nohup sh train_caption_stage1_el_db.sh > train_stage1_el.out & # stage 1, train with encouraging loss, and drop best examples, expected cider 1.404
20
- </pre>
21
- </details>
22
-
23
- ## Referring Expression Comprehension
24
- We provide procedures for image captioning with EL below. The preprocessing is identical to default setting.
25
- <details>
26
- <summary><b>Finetuning</b></summary>
27
- <pre>
28
- cd run_scripts/refcoco
29
- nohup sh train_refcoco_el.sh > train_refcoco_el.out & # finetune for refcoco
30
- nohup sh train_refcocoplus_el.sh > train_refcocoplus_el.out & # finetune for refcoco+
31
- nohup sh train_refcocog_el.sh > train_refcocog_el.out & # finetune for refcocog
32
- </pre>
33
- </details>
34
- Evaluation is also the same as the default setting.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
VG.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
VQA.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
Video_Captioning.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
checkpoints.md DELETED
@@ -1,36 +0,0 @@
1
- # Checkpoints
2
-
3
- We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.
4
-
5
- ## Pretraining
6
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
7
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a> (~470M parameters)
8
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"> Pre-trained checkpoint (OFA-Base) </a> (~180M parameters)
9
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt"> Pre-trained checkpoint (OFA-Medium) </a> (~93M parameters)
10
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt"> Pre-trained checkpoint (OFA-Tiny) </a> (~33M parameters)
11
-
12
- ## Finetuning (OFA-Huge)
13
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_huge_best.pt"> Finetuned checkpoint for Caption on COCO </a>
14
-
15
- ## Finetuning (OFA-Large)
16
-
17
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
18
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_stage1_best.pt"> Finetuned checkpoint for Caption on COCO During Stage1 Finetuning </a>
19
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
20
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
21
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
22
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_large_best.pt"> Finetuned checkpoint for VQAv2 </a>
23
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_large_best.pt"> Finetuned checkpoint for SNLI-VE </a>
24
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_large_best.zip"> Finetuned checkpoint for Text-to-Image Generation on COCO && CLIP checkpoint && VQGAN checkpoint </a>
25
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/imagenet_1k_large_best.pt"> Finetuned checkpoint for ImageNet-1K </a>
26
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/gigaword_large_best.pt"> Finetuned checkpoint for Gigaword </a>
27
-
28
-
29
- ## Finetuning (OFA-Base)
30
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_base_best.pt"> Finetuned base checkpoint for Caption on COCO </a>
31
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_base_best.pt"> Finetuned base checkpoint for RefCOCO </a>
32
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_base_best.pt"> Finetuned base checkpoint for RefCOCO+ </a>
33
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_base_best.pt"> Finetuned base checkpoint for RefCOCOg </a>
34
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_base_best.pt"> Finetuned base checkpoint for VQAv2 </a>
35
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_base_best.pt"> Finetuned base checkpoint for SNLI-VE </a>
36
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_base_best.pt"> Finetuned base checkpoint for Text-to-Image Generation on COCO </a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
checkpoints_cn.md DELETED
@@ -1,82 +0,0 @@
1
- # Checkpoints (OFA-CN)
2
-
3
- We provide checkpoints of OFA-CN, which is the Chinese version of OFA. We provide Base-size and Large-size models, including pretrained and finetuned models on image captioning and referring expression comprehension. Note that we translated the texts in the RefCOCO(-/+/g) datasets and finetuned OFA-CN on them. We plan to release the related new datasets in the near future.
4
- <br>
5
-
6
- ## Checkpoints
7
- Below we provide the links for downloading the Chinese OFA checkpoints.
8
-
9
- ### Pretraining
10
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
11
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt "> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)
12
-
13
- ### Finetuning (OFA-Large)
14
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
15
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
16
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
17
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
18
-
19
- ### Finetuning (OFA-Base)
20
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
21
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
22
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
23
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
24
- <br>
25
-
26
- ## Model Card
27
- Below we provide the basic information of the base-size and large-size OFA-CN.
28
-
29
- <table border="1" width="100%">
30
- <tr align="center">
31
- <th>Model</th><th>#Params</th><th>Backbone</th><th>Hidden Size</th><th>Intermediate Size</th><th>#Heads</th><th>#Enc. Layers</th><th>#Dec. Layers</th>
32
- </tr>
33
- <tr align="center">
34
- <td>OFA<sub>Base</sub><td>160M</td><td>ResNet101</td><td>768</td></td><td>3072</td><td>12</td><td>6</td><td>6</td>
35
- </tr>
36
- <tr align="center">
37
- <td>OFA<sub>Large</sub></td><td>443M</td><td>ResNet152</td><td>1024</td></td><td>4096</td><td>16</td><td>12</td><td>12</td>
38
- </tr>
39
- </tr>
40
- </table>
41
- <br>
42
-
43
- ## Results
44
- Below we provide the results of OFA-CN and the baselines for comparison.
45
-
46
- ### [MUGE Caption]("https://tianchi.aliyun.com/muge")
47
- <table border="1" width="100%">
48
- <tr align="center">
49
- <td>Model</td><td>BLEU@4</td><td>ROUGE-L</td><td>CIDEr-D</td>
50
- </tr>
51
- <tr align="center">
52
- <td>Trm </td><td>7.33</td><td>51.51</td><td>11.00</td>
53
- </tr>
54
- <tr align="center">
55
- <td>M6</td><td>16.19</td><td>55.06</td><td>30.75</td>
56
- </tr>
57
- <tr align="center">
58
- <td>OFA<sub>Base</sub></td><td>26.23</td><td>58.95</td><td>50.70</td>
59
- </tr>
60
- <tr align="center">
61
- <td>OFA<sub>Large</sub></td><td><b>27.32</b></td><td><b>59.20</b></td><td><b>53.51</b></td>
62
- </tr>
63
- </table>
64
-
65
- ### RefCOCO-CN Series
66
- <table border="1" width="100%">
67
- <tr align="center">
68
- <td>Model</td><td>RefCOCO(val/testA/testB)</td><td>RefCOCO+(val/testA/testB)</td><td>RefCOCOg(val/test-u)</td>
69
- </tr>
70
- <tr align="center">
71
- <td>OFA<sub>Base</sub>(random-init)</td><td>30.13/35.07/25.03</td><td>17.89/20.90/15.83</td><td>20.30/20.45</td>
72
- </tr>
73
- <tr align="center">
74
- <td>OFA<sub>Base</sub></td><td>82.18/86.07/<b>76.68</b></td><td>69.38/77.26/60.14</td><td><b>73.57/72.53</b></td>
75
- </tr>
76
- <tr align="center">
77
- <td>OFA<sub>Large</sub></td><td><b>82.84/86.54</b>/76.50</td><td><b>71.30/78.56/61.85</b></td><td>71.96/71.30</td>
78
- </tr>
79
- </table>
80
- <br>
81
-
82
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
colab.md DELETED
@@ -1,9 +0,0 @@
1
- # Colab Notebooks
2
-
3
- We provide Colab notebooks of different downstream tasks for you guys to enjoy OFA. See below.
4
-
5
- * [Image Captioning in Huggingface Transformers](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
6
- * [Generic Interface](https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m?usp=sharing#scrollTo=s9Vni6YUZOpC) (using different instructions to perform various tasks with just one model.)
7
- * [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
8
- * [Referring Expression Comprehension](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
9
- * [Open-Domain Visual Question Answering](https://colab.research.google.com/drive/14v6OQe_MxV_HMnsiKfnEeMR1UMqhzZNb?usp=sharing)
 
 
 
 
 
 
 
 
 
 
datasets.md DELETED
@@ -1,44 +0,0 @@
1
- # Datasets
2
-
3
- We provide links to download our preprocessed dataset. If you would like to process the data on your own, we will soon provide scripts for you to do so.
4
-
5
- ## Pretraining
6
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip"> A small subset of the pretraining data </a>
7
-
8
- The pretraining datasets used in OFA are all publicly available. Here we provide the public links to these data, it is recommended that you download the data from the links first, and then process the downloaded dataset into a similar format as the examples we provided.
9
- - _CC12M_: https://github.com/google-research-datasets/conceptual-12m
10
- - _CC3M_: https://github.com/google-research-datasets/conceptual-captions
11
- - _SBU_: https://www.cs.virginia.edu/~vicente/sbucaptions
12
- - _COCO_: https://cocodataset.org/#home
13
- - _VG_: https://visualgenome.org/
14
- - _VQAv2_: https://visualqa.org/
15
- - _GQA_: https://cs.stanford.edu/people/dorarad/gqa/about.html
16
- - _RefCOCO_/_RefCOCO+_/RefCOCOg: https://github.com/lichengunc/refer
17
- - _OpenImages_: https://storage.googleapis.com/openimages/web/index.html
18
- - _Object365_: https://www.objects365.org/overview.html
19
- - _YFCC100M (subset)_: https://github.com/openai/CLIP/blob/main/data/yfcc100m.md
20
- - _ImageNet-21K_: https://image-net.org/index.php
21
- - _Pile_: https://pile.eleuther.ai
22
-
23
- ## Vision & Language Tasks
24
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
25
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
26
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
27
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
28
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/vqa_data/vqa_data.zip"> Dataset for VQAv2 </a> (we have also provided chunked parts of the dataset files for more convenient downloading, please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>)
29
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/snli_ve_data/snli_ve_data.zip"> Dataset for SNLI-VE </a>
30
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen.zip"> Dataset for Text-to-Image Genearion </a>
31
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen_origin_id.zip"> Dataset for Text-to-Image Genearion (with original id) </a>
32
-
33
- ## Vision Tasks
34
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/imagenet_1k_data/imagenet_1k_data.zip"> Dataset for ImageNet-1K </a>
35
-
36
- ## Language Tasks
37
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/cola_data.zip"> Dataset for COLA </a>
38
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mnli_data.zip"> Dataset for MNLI </a>
39
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mrpc_data.zip"> Dataset for MRPC </a>
40
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qnli_data.zip"> Dataset for QNLI </a>
41
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qqp_data.zip"> Dataset for QQP </a>
42
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/rte_data.zip"> Dataset for RTE </a>
43
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/sst2_data.zip"> Dataset for SST2 </a>
44
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/gigaword_data/gigaword_data.zip"> Dataset for Gigaword </a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
evaluate.py DELETED
@@ -1,239 +0,0 @@
1
- #!/usr/bin/env python3 -u
2
- # Copyright 2022 The OFA-Sys Team.
3
- # All rights reserved.
4
- # This source code is licensed under the Apache 2.0 license
5
- # found in the LICENSE file in the root directory.
6
-
7
- import logging
8
- import os
9
- import sys
10
-
11
- import numpy as np
12
- import torch
13
- from fairseq import distributed_utils, options, tasks, utils
14
- from fairseq.dataclass.utils import convert_namespace_to_omegaconf
15
- from fairseq.logging import progress_bar
16
- from fairseq.utils import reset_logging
17
- from omegaconf import DictConfig
18
-
19
- from utils import checkpoint_utils
20
- from utils.eval_utils import eval_step, merge_results
21
- from utils.zero_shot_utils import zero_shot_step
22
-
23
- logging.basicConfig(
24
- format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
25
- datefmt="%Y-%m-%d %H:%M:%S",
26
- level=os.environ.get("LOGLEVEL", "INFO").upper(),
27
- stream=sys.stdout,
28
- )
29
- logger = logging.getLogger("ofa.evaluate")
30
-
31
- from utils.utils import print_trainable_params_percentage, setup_for_distributed
32
-
33
- def apply_half(t):
34
- if t.dtype is torch.float32:
35
- return t.to(dtype=torch.half)
36
- return t
37
-
38
-
39
- def main(cfg: DictConfig, **kwargs):
40
- utils.import_user_module(cfg.common)
41
-
42
- setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))
43
-
44
- reset_logging()
45
- # logger.info(cfg)
46
-
47
- assert (
48
- cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
49
- ), "Must specify batch size either with --max-tokens or --batch-size"
50
-
51
- # Fix seed for stochastic decoding
52
- if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
53
- np.random.seed(cfg.common.seed)
54
- utils.set_torch_seed(cfg.common.seed)
55
-
56
- use_fp16 = cfg.common.fp16
57
- use_cuda = torch.cuda.is_available() and not cfg.common.cpu
58
-
59
- if use_cuda:
60
- torch.cuda.set_device(cfg.distributed_training.device_id)
61
-
62
- # Load ensemble
63
- overrides = eval(cfg.common_eval.model_overrides)
64
- # Deal with beam-search / all-candidate VQA eval
65
- if cfg.task._name == "vqa_gen":
66
- overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"
67
-
68
- logger.info("loading model(s) from {}".format(cfg.common_eval.path))
69
-
70
- # print("cfg", cfg)
71
- # print(kwargs)
72
- # cfg.model.num_frames = kwargs["num_frames"]
73
- # cfg.model.patch_frame_size = kwargs["patch_frame_size"]
74
- # print("cfg.model", cfg.model)
75
- # strict = getattr(kwargs, 'strict', True)
76
- strict = kwargs['strict']
77
- logger.info('load checkpoint, strict:{}'.format(strict))
78
-
79
- if kwargs["zero_shot"]:
80
- for arg_name, arg_val in overrides.items():
81
- cfg.task[arg_name] = arg_val
82
- # print("Zero-shot eval", cfg.task, cfg)
83
-
84
- if hasattr(cfg.task, "add_caption"):
85
- cfg.task.add_caption = False
86
- print("cfg.task", cfg.task)
87
- task = tasks.setup_task(cfg.task)
88
- # cfg.criterion.sample_patch_num = 776
89
-
90
-
91
- models, saved_cfg = checkpoint_utils.load_model_ensemble(
92
- utils.split_paths(cfg.common_eval.path),
93
- arg_overrides=overrides,
94
- task=task,
95
- suffix=cfg.checkpoint.checkpoint_suffix,
96
- strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
97
- num_shards=cfg.checkpoint.checkpoint_shard_count,
98
- )
99
- for m in models:
100
- m.encoder.sample_patch_num = 776
101
- saved_cfg.task = cfg.task
102
- # print("saved_cfg", saved_cfg)
103
- else:
104
- models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
105
- utils.split_paths(cfg.common_eval.path),
106
- arg_overrides=overrides,
107
- suffix=cfg.checkpoint.checkpoint_suffix,
108
- strict=((cfg.checkpoint.checkpoint_shard_count == 1) and strict),
109
- num_shards=cfg.checkpoint.checkpoint_shard_count,
110
- )
111
-
112
-
113
-
114
- # task.cfg['evaluate_cfg'] = cfg.task
115
- # print(task.cfg)
116
- kwargs['evaluate_cfg'] = cfg.task
117
- # print(kwargs)
118
- # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
119
- task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
120
-
121
- # Move models to GPU
122
- for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
123
- if kwargs['ema_eval']:
124
- logger.info("loading EMA weights from {}".format(ckpt_path))
125
- model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
126
- model.eval()
127
- print("use fp16", use_fp16)
128
- if use_fp16:
129
-
130
- model.half()
131
- if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
132
- model.cuda()
133
- model.prepare_for_inference_(cfg)
134
-
135
- # Load dataset (possibly sharded)
136
- itr = task.get_batch_iterator(
137
- dataset=task.dataset(cfg.dataset.gen_subset),
138
- max_tokens=cfg.dataset.max_tokens,
139
- max_sentences=cfg.dataset.batch_size,
140
- max_positions=utils.resolve_max_positions(
141
- task.max_positions(), *[m.max_positions() for m in models]
142
- ),
143
- ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
144
- required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
145
- seed=cfg.common.seed,
146
- num_shards=cfg.distributed_training.distributed_world_size,
147
- shard_id=cfg.distributed_training.distributed_rank,
148
- num_workers=cfg.dataset.num_workers,
149
- data_buffer_size=cfg.dataset.data_buffer_size,
150
- ).next_epoch_itr(shuffle=False)
151
- progress = progress_bar.progress_bar(
152
- itr,
153
- log_format=cfg.common.log_format,
154
- log_interval=cfg.common.log_interval,
155
- default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
156
- )
157
-
158
- # Initialize generator
159
- generator = task.build_generator(models, cfg.generation)
160
-
161
- results = []
162
- score_sum = torch.FloatTensor([0]).cuda()
163
- score_cnt = torch.FloatTensor([0]).cuda()
164
-
165
- score_sum_list = []
166
- score_cnt_list = []
167
- for sample in progress:
168
- if "net_input" not in sample:
169
- continue
170
- sample = utils.move_to_cuda(sample) if use_cuda else sample
171
- sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
172
- with torch.no_grad():
173
- if kwargs["zero_shot"] and kwargs['noconstraints']:
174
- result, scores = zero_shot_step(task, generator, models, sample)
175
- else:
176
- result, scores = eval_step(task, generator, models, sample, **kwargs)
177
- ### else refcoco res, score, other_scores
178
-
179
- # print(scores)
180
- scalar = False
181
- if isinstance(scores, list):
182
- if not isinstance(scores[0], list):
183
- try:
184
- tmp = sum(scores[0])
185
- scalar=False
186
- except:
187
- scalar=True
188
- # print(scalar)
189
- # print(sum(scores[0]))
190
- if isinstance(scores, list) and not scalar:
191
- names = result[0]
192
- result = result[1]
193
- if len(score_sum_list) == 0:
194
- score_sum_list = [torch.FloatTensor([0]).cuda() for i in range(len(scores))]
195
- score_cnt_list = [torch.FloatTensor([0]).cuda() for i in range(len(scores))]
196
-
197
- for i in range(len(scores)):
198
-
199
-
200
- score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
201
- score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
202
- else:
203
- for i in range(len(scores)):
204
- score_sum_list[i] += sum(scores[i]) if scores[i] is not None else 0
205
- score_cnt_list[i] += len(scores[i]) if scores[i] is not None else 0
206
- else:
207
- score_sum += sum(scores) if scores is not None else 0
208
- score_cnt += len(scores) if scores is not None else 0
209
- results += result
210
- progress.log({"sentences": sample["nsentences"]})
211
-
212
-
213
- ### merge per metric
214
- if len(score_sum_list) > 0:
215
- print(names, len(score_sum_list))
216
- for i in range(len(score_sum_list)):
217
- print(names[i])
218
- merge_results(task, cfg, logger, score_cnt_list[i], score_sum_list[i], results)
219
- else:
220
- merge_results(task, cfg, logger, score_cnt, score_sum, results)
221
-
222
-
223
- def cli_main():
224
- parser = options.get_generation_parser()
225
- parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
226
- parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
227
- parser.add_argument("--zero-shot", action='store_true')
228
- parser.add_argument("--strict", action='store_false')
229
- parser.add_argument("--noconstraints", action='store_true')
230
- args = options.parse_args_and_arch(parser)
231
- cfg = convert_namespace_to_omegaconf(args)
232
- distributed_utils.call_main(
233
- cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval,
234
- zero_shot=args.zero_shot, strict=args.strict, noconstraints=args.noconstraints
235
- )
236
-
237
-
238
- if __name__ == "__main__":
239
- cli_main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelscope.md DELETED
@@ -1,23 +0,0 @@
1
- # ModelScope
2
-
3
- ModelScope is a new platform that provides "Model-As-A-Service", where users can use state-of-the-art models with the lowest costs of efforts as possible. We have released:
4
- * The pretrained and finetuned **OFA** models
5
- * **Chinese CLIP** (the CLIP pretrained Chinese data, which was previously released in our organization)
6
-
7
- on the platform, including the English and Chinese ones. Feel free to check these models and use them with ModelScope, and also feel free to send us feedbacks to help us improve the product.
8
-
9
- ## Chinese
10
- * Chinese CLIP \[[Base](https://www.modelscope.cn/#/models/damo/multi-modal_clip-vit-base-patch16_zh/summary) | [Large](https://www.modelscope.cn/#/models/damo/multi-modal_clip-vit-large-patch14_zh/summary)\]
11
- * Finetuned OFA on Visual Grounding (RefCOCO) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_large_zh/summary)\]
12
-
13
- ## English
14
- * Finetuned OFA on Image Captioning \[[Large](https://www.modelscope.cn/#/models/damo/ofa_image-caption_coco_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_image-caption_coco_distilled_en/summary)\]
15
- * Finetuned OFA on Text-to-Image Generation \[[Large](https://www.modelscope.cn/#/models/damo/ofa_text-to-image-synthesis_coco_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_distilled_en/summary)\]
16
- * Finetuned OFA on Visual Question Answering \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-question-answering_pretrain_large_en/summary)\]
17
- * Finetuned OFA on Visual Grounding (RefCOCO) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_large_en/summary)\]
18
- * Finetuned OFA on Visual Entailment \[[Large](https://www.modelscope.cn/#/models/damo/ofa_visual-entailment_snli-ve_large_en/summary) | [Distill](https://modelscope.cn/#/models/damo/ofa_visual-entailment_snli-ve_distilled_v2_en/summary)\]
19
- * Finetuned OFA on Summarization (Gigaword) \[[Large](https://www.modelscope.cn/#/models/damo/ofa_summarization_gigaword_large_en/summary)\]
20
- * Finetuned OFA on Natural Language Entailment (MNLI, can be used to finetune on the GLUE benchmark) \[[Large](https://modelscope.cn/#/models/damo/ofa_text-classification_mnli_large_en/summary)\]
21
- * Finetuned OFA on Image Classification (ImageNet-1k) \[[Large](https://modelscope.cn/#/models/damo/ofa_image-classification_imagenet_large_en/summary)\]
22
-
23
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ofa_test.ipynb DELETED
@@ -1,2499 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Import"
8
- ]
9
- },
10
- {
11
- "cell_type": "code",
12
- "execution_count": 1,
13
- "metadata": {},
14
- "outputs": [],
15
- "source": [
16
- "%load_ext autoreload\n",
17
- "%autoreload 2"
18
- ]
19
- },
20
- {
21
- "cell_type": "code",
22
- "execution_count": 2,
23
- "metadata": {},
24
- "outputs": [],
25
- "source": [
26
- "import os\n",
27
- "import json \n",
28
- "import torch\n",
29
- "# import clip\n",
30
- "from PIL import Image\n",
31
- "# import sng_parser\n",
32
- "from tqdm import tqdm \n",
33
- "import codecs\n",
34
- "import numpy as np\n",
35
- "import csv\n",
36
- "import sys\n",
37
- "\n",
38
- "from io import BytesIO\n",
39
- "import base64"
40
- ]
41
- },
42
- {
43
- "cell_type": "code",
44
- "execution_count": null,
45
- "metadata": {},
46
- "outputs": [],
47
- "source": []
48
- },
49
- {
50
- "cell_type": "markdown",
51
- "metadata": {},
52
- "source": [
53
- "# Data"
54
- ]
55
- },
56
- {
57
- "cell_type": "markdown",
58
- "metadata": {},
59
- "source": [
60
- "## Explore"
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": 16,
66
- "metadata": {},
67
- "outputs": [
68
- {
69
- "name": "stderr",
70
- "output_type": "stream",
71
- "text": [
72
- "100it [00:00, 14325.30it/s]\n"
73
- ]
74
- }
75
- ],
76
- "source": [
77
- "csv.field_size_limit(sys.maxsize)\n",
78
- "\n",
79
- "# path_data = '/data/mshukor/data/ofa/pretrain_example/vision_language_examples.tsv'\n",
80
- "# selected_cols='0,1,2,3,4,5,6,7'\n",
81
- "\n",
82
- "# path_data = '/data/mshukor/data/ofa/pretrain_example/detection_examples.tsv'\n",
83
- "# selected_cols='0,1,2'\n",
84
- "\n",
85
- "# path_data = '/data/mshukor/data/ofa/pretrain_example/image_examples.tsv'\n",
86
- "# selected_cols='0,1,2'\n",
87
- "\n",
88
- "path_data = '/data/mshukor/data/ofa/pretrain_example/text_examples.tsv'\n",
89
- "selected_cols='0,1'\n",
90
- "\n",
91
- "data_example = []\n",
92
- "\n",
93
- "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
94
- "\n",
95
- "with open(path_data) as file:\n",
96
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
97
- " for line in tqdm(tsv_file):\n",
98
- "\n",
99
- " d = [line[i] for i in selected_col_ids]\n",
100
- "# print(d)\n",
101
- " data_example.append(d)\n",
102
- " \n"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": 17,
108
- "metadata": {},
109
- "outputs": [
110
- {
111
- "data": {
112
- "text/plain": [
113
- "['100',\n",
114
- " '...please depart this field clean unless you might be answering the question. do not ask questions you already know the answer to. thanks.retrieved from \" \" ad blocker interference detected! wikia is a single-to-usefulness web site that makes cash from promoting. we\\'ve a experience for viewers using ad blockers wikia shouldn\\'t be if youve made further modifications. take away the custom ad blocker (s) and the page leave timber as expected. categories : un-answered questionsadd class cancelsave per the reddit twine, flac files will be synced to an ios gadget via icloud impel, then accessed through thenew information utility , which will allow for local playback of the excessive-high quality audio files straight by the side of the device. if , it could stamp the first time that apple has offered help for the favored flac format an ios gadget.']"
115
- ]
116
- },
117
- "execution_count": 17,
118
- "metadata": {},
119
- "output_type": "execute_result"
120
- }
121
- ],
122
- "source": [
123
- "line"
124
- ]
125
- },
126
- {
127
- "cell_type": "code",
128
- "execution_count": 12,
129
- "metadata": {},
130
- "outputs": [
131
- {
132
- "data": {
133
- "text/plain": [
134
- "['7',\n",
135
- " 'perhaps the clearest indication of who won and lost came quickly on the heels of the event itself: the democratic post-debate message was that joe biden scored a clear win; the republican message was that joe biden was too mean to paul ryan. the former is a boast of success; the latter is an excuse for failure. in the larger context, it’s hard to overstate how much democrats needed a shot in the arm like this. the surface-level goals of any vice presidential debate is for the candidates to demonstrate a capacity to step up in the event of a crisis, while defending their ticket’s agenda and knocking their rivals’ agenda. but for biden, the overarching benefit was about the basic morale of his party with less than four weeks to go until election day: he wanted to give democratic voters something to feel good about, and he did. who the hell am i! i’m a liberal that is extreme in some ways and not in others. i support president obama and make no apologies for it. i think he has done a phenomenal job, especially when you consider that he inherited a huge mess and has faced unprecedented opposition from a lazy & desperate republican party. i’m a film producer/director/editor, adjunct professor, technician, media critic and photographer when i’m not reading left wing blogs and typing on this one. – on twitter @extremeliberal or email at liberalforreal (at) gmail.com own an important part of american history! cicely tyson narrates this award winning documentary that tells the story of african american migration from the old south to the prosperous north. winner of 5 awards including \"best film\" at the astoria international film festival, the \"paul robeson award\" at the newark black film festival and \"best film relating to the black experience\" at the xxv international black cinema berlin/germany!']"
136
- ]
137
- },
138
- "execution_count": 12,
139
- "metadata": {},
140
- "output_type": "execute_result"
141
- }
142
- ],
143
- "source": [
144
- "data_[6]"
145
- ]
146
- },
147
- {
148
- "cell_type": "code",
149
- "execution_count": null,
150
- "metadata": {},
151
- "outputs": [],
152
- "source": [
153
- "tasks = set()\n",
154
- "datasets = set()\n",
155
- "\n",
156
- "for d in data:\n",
157
- " tasks.add(d[-1])\n",
158
- " datasets.add(d[-2])"
159
- ]
160
- },
161
- {
162
- "cell_type": "code",
163
- "execution_count": null,
164
- "metadata": {},
165
- "outputs": [],
166
- "source": [
167
- "int(data[10][0])"
168
- ]
169
- },
170
- {
171
- "cell_type": "code",
172
- "execution_count": null,
173
- "metadata": {},
174
- "outputs": [],
175
- "source": [
176
- "# len(data[0][2:][0].split(' '))\n",
177
- "# len(data[0][1])\n",
178
- "text = data[10][1]\n",
179
- "print(len(text.split(' ')))\n",
180
- "print(len(text))\n",
181
- "from nltk.tokenize import word_tokenize\n",
182
- "len(word_tokenize(text))"
183
- ]
184
- },
185
- {
186
- "cell_type": "code",
187
- "execution_count": null,
188
- "metadata": {},
189
- "outputs": [],
190
- "source": [
191
- "text"
192
- ]
193
- },
194
- {
195
- "cell_type": "code",
196
- "execution_count": null,
197
- "metadata": {},
198
- "outputs": [],
199
- "source": [
200
- "from nltk.tokenize.treebank import TreebankWordDetokenizer\n",
201
- "TreebankWordDetokenizer().detokenize(word_tokenize(text))"
202
- ]
203
- },
204
- {
205
- "cell_type": "code",
206
- "execution_count": null,
207
- "metadata": {},
208
- "outputs": [],
209
- "source": [
210
- "key = 'refcoco_train'\n",
211
- "index = -2\n",
212
- "for d in data:\n",
213
- " if d[index] == key:\n",
214
- " print(d[2:])\n",
215
- "# break"
216
- ]
217
- },
218
- {
219
- "cell_type": "code",
220
- "execution_count": null,
221
- "metadata": {},
222
- "outputs": [],
223
- "source": [
224
- "d[4].split(',')\n",
225
- "str([287.0, 127.0, 340.0, 162.0])\n",
226
- "'{:.2f},{:.2f},{:.2f},{:.2f}'.format(287.0, 127.0, 340.0, 162.0)"
227
- ]
228
- },
229
- {
230
- "cell_type": "code",
231
- "execution_count": null,
232
- "metadata": {},
233
- "outputs": [],
234
- "source": [
235
- "print(len(data))\n",
236
- "data[0]"
237
- ]
238
- },
239
- {
240
- "cell_type": "code",
241
- "execution_count": null,
242
- "metadata": {},
243
- "outputs": [],
244
- "source": [
245
- "all_captions_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/all_captions.txt'\n",
246
- "all_objects_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/object.txt'\n",
247
- "\n",
248
- "all_object_list = [\n",
249
- " row.strip() for row in open(all_objects_path) if row.strip() != ''\n",
250
- "]\n",
251
- "all_caption_list = [\n",
252
- " row.strip() for row in open(all_captions_path) if row.strip() != ''\n",
253
- "]\n"
254
- ]
255
- },
256
- {
257
- "cell_type": "code",
258
- "execution_count": null,
259
- "metadata": {},
260
- "outputs": [],
261
- "source": [
262
- "len(all_object_list)"
263
- ]
264
- },
265
- {
266
- "cell_type": "code",
267
- "execution_count": null,
268
- "metadata": {},
269
- "outputs": [],
270
- "source": [
271
- "len(all_caption_list)"
272
- ]
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": null,
277
- "metadata": {},
278
- "outputs": [],
279
- "source": [
280
- "all_object_list[:10]"
281
- ]
282
- },
283
- {
284
- "cell_type": "code",
285
- "execution_count": null,
286
- "metadata": {},
287
- "outputs": [],
288
- "source": [
289
- "all_caption_list[:10]"
290
- ]
291
- },
292
- {
293
- "cell_type": "code",
294
- "execution_count": null,
295
- "metadata": {},
296
- "outputs": [],
297
- "source": [
298
- "json_path = '/data/mshukor/data/ofa/pretrain_example/negative_sample/type2ans.json'\n",
299
- "type2ans = json.load(open(json_path,'r'))"
300
- ]
301
- },
302
- {
303
- "cell_type": "code",
304
- "execution_count": null,
305
- "metadata": {},
306
- "outputs": [],
307
- "source": [
308
- "type2ans.keys()\n",
309
- "# type2ans['what color is the']"
310
- ]
311
- },
312
- {
313
- "cell_type": "markdown",
314
- "metadata": {},
315
- "source": [
316
- "### Our data"
317
- ]
318
- },
319
- {
320
- "cell_type": "code",
321
- "execution_count": null,
322
- "metadata": {},
323
- "outputs": [],
324
- "source": []
325
- },
326
- {
327
- "cell_type": "code",
328
- "execution_count": 3,
329
- "metadata": {},
330
- "outputs": [
331
- {
332
- "name": "stderr",
333
- "output_type": "stream",
334
- "text": [
335
- "181767it [00:02, 70482.72it/s]\n"
336
- ]
337
- }
338
- ],
339
- "source": [
340
- "# path_data = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
341
- "# selected_cols='0,1'\n",
342
- "\n",
343
- "path_data = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
344
- "selected_cols='0,1,2'\n",
345
- "\n",
346
- "# path_data = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
347
- "# selected_cols='0,1,2,3,4,5,6,7'\n",
348
- "\n",
349
- "# path_data = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
350
- "# selected_cols='0,1,2'\n",
351
- "\n",
352
- "data = []\n",
353
- "\n",
354
- "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
355
- "\n",
356
- "with open(path_data) as file:\n",
357
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
358
- " for line in tqdm(tsv_file):\n",
359
- "\n",
360
- " d = [line[i] for i in selected_col_ids]\n",
361
- "# print(d)\n",
362
- " data.append(d)"
363
- ]
364
- },
365
- {
366
- "cell_type": "code",
367
- "execution_count": 21,
368
- "metadata": {},
369
- "outputs": [],
370
- "source": [
371
- "# new_data = []\n",
372
- "# for d in data:\n",
373
- "# label_list = d[2].strip().split('&&')\n",
374
- "# new_label_list = []\n",
375
- "# for label in label_list:\n",
376
- "# lab = label.strip().split(',', 5)[:4] # x0, y0, x1, y1, cat_id, cat\n",
377
- " \n",
378
- "# if any([\"&\" in l for l in lab]):\n",
379
- "# lab = [remove_special(l) for l in lab]\n",
380
- " \n",
381
- "# print(lab)\n",
382
- "# lab_ = lab + label.strip().split(',', 5)[4:]\n",
383
- "# lab_ = ','.join(lab_)\n",
384
- "# new_label_list.append(lab_)\n",
385
- "# new_label_list = ['&&'.join(new_label_list)]\n",
386
- "# new_data.append(d[:2]+new_label_list)"
387
- ]
388
- },
389
- {
390
- "cell_type": "code",
391
- "execution_count": 24,
392
- "metadata": {},
393
- "outputs": [
394
- {
395
- "name": "stdout",
396
- "output_type": "stream",
397
- "text": [
398
- "['&40.000', '155.000', '44.000', '164.000']\n"
399
- ]
400
- }
401
- ],
402
- "source": [
403
- "for d in data:\n",
404
- " label_list = d[2].strip().split('&&')\n",
405
- " new_label_list = []\n",
406
- " for label in label_list:\n",
407
- " lab = label.strip().split(',', 5)[:4] # x0, y0, x1, y1, cat_id, cat\n",
408
- " \n",
409
- " if any([\"&\" in l for l in lab]):\n",
410
- " print(lab)\n",
411
- " # lab = [remove_special(l) for l in lab]"
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": 27,
417
- "metadata": {},
418
- "outputs": [
419
- {
420
- "data": {
421
- "text/plain": [
422
- "['0',\n",
423
- " 'coco/train2014/COCO_train2014_000000057870.jpg',\n",
424
- " '1.020,279.960,534.110,480.000,67,dining table&&90.670,271.490,262.510,480.000,62,chair&&233.290,270.450,403.610,473.810,62,chair&&367.820,264.270,506.970,480.000,62,chair&&476.760,261.030,596.490,462.740,62,chair&&263.030,174.370,417.670,299.400,64,potted plant&&539.330,290.160,640.000,469.210,62,chair&&10.790,260.030,125.120,384.070,62,chair&&560.800,413.950,639.090,479.200,67,dining table&&20.540,376.760,103.780,431.890,62,chair&&1.080,373.210,32.360,480.000,62,chair&&298.200,235.170,381.210,269.250,86,vase&&152.170,256.670,230.580,285.780,62,chair&&364.400,256.570,417.060,283.210,62,chair&&296.780,277.790,329.260,289.780,84,book&&292.800,289.310,314.210,300.650,84,book&&285.800,257.460,299.770,273.600,62,chair']"
425
- ]
426
- },
427
- "execution_count": 27,
428
- "metadata": {},
429
- "output_type": "execute_result"
430
- }
431
- ],
432
- "source": [
433
- "new_data[0]"
434
- ]
435
- },
436
- {
437
- "cell_type": "code",
438
- "execution_count": 8,
439
- "metadata": {},
440
- "outputs": [
441
- {
442
- "data": {
443
- "text/plain": [
444
- "['1.020,279.960,534.110,480.000,67,dining table',\n",
445
- " '90.670,271.490,262.510,480.000,62,chair',\n",
446
- " '233.290,270.450,403.610,473.810,62,chair',\n",
447
- " '367.820,264.270,506.970,480.000,62,chair',\n",
448
- " '476.760,261.030,596.490,462.740,62,chair',\n",
449
- " '263.030,174.370,417.670,299.400,64,potted plant',\n",
450
- " '539.330,290.160,640.000,469.210,62,chair',\n",
451
- " '10.790,260.030,125.120,384.070,62,chair',\n",
452
- " '560.800,413.950,639.090,479.200,67,dining table',\n",
453
- " '20.540,376.760,103.780,431.890,62,chair',\n",
454
- " '1.080,373.210,32.360,480.000,62,chair',\n",
455
- " '298.200,235.170,381.210,269.250,86,vase',\n",
456
- " '152.170,256.670,230.580,285.780,62,chair',\n",
457
- " '364.400,256.570,417.060,283.210,62,chair',\n",
458
- " '296.780,277.790,329.260,289.780,84,book',\n",
459
- " '292.800,289.310,314.210,300.650,84,book',\n",
460
- " '285.800,257.460,299.770,273.600,62,chair']"
461
- ]
462
- },
463
- "execution_count": 8,
464
- "metadata": {},
465
- "output_type": "execute_result"
466
- }
467
- ],
468
- "source": [
469
- "label_list"
470
- ]
471
- },
472
- {
473
- "cell_type": "code",
474
- "execution_count": 12,
475
- "metadata": {},
476
- "outputs": [],
477
- "source": [
478
- "def remove_special(input_string):\n",
479
- " final_string = \"\"\n",
480
- " for character in input_string:\n",
481
- " if character == \" \":\n",
482
- " final_string = final_string + character\n",
483
- " else:\n",
484
- " if(character.isalnum()):\n",
485
- " final_string = final_string + character\n",
486
- " return final_string"
487
- ]
488
- },
489
- {
490
- "cell_type": "code",
491
- "execution_count": 60,
492
- "metadata": {},
493
- "outputs": [
494
- {
495
- "name": "stderr",
496
- "output_type": "stream",
497
- "text": [
498
- "100%|█| 5593207/5593207 [00:35<00:0\n"
499
- ]
500
- }
501
- ],
502
- "source": [
503
- "for d in tqdm(data):\n",
504
- " label = d[2]\n",
505
- " d[2] = remove_special(caption)\n"
506
- ]
507
- },
508
- {
509
- "cell_type": "code",
510
- "execution_count": 4,
511
- "metadata": {},
512
- "outputs": [
513
- {
514
- "data": {
515
- "text/plain": [
516
- "['0',\n",
517
- " 'coco/train2014/COCO_train2014_000000057870.jpg',\n",
518
- " '1.020,279.960,534.110,480.000,67,dining table&&90.670,271.490,262.510,480.000,62,chair&&233.290,270.450,403.610,473.810,62,chair&&367.820,264.270,506.970,480.000,62,chair&&476.760,261.030,596.490,462.740,62,chair&&263.030,174.370,417.670,299.400,64,potted plant&&539.330,290.160,640.000,469.210,62,chair&&10.790,260.030,125.120,384.070,62,chair&&560.800,413.950,639.090,479.200,67,dining table&&20.540,376.760,103.780,431.890,62,chair&&1.080,373.210,32.360,480.000,62,chair&&298.200,235.170,381.210,269.250,86,vase&&152.170,256.670,230.580,285.780,62,chair&&364.400,256.570,417.060,283.210,62,chair&&296.780,277.790,329.260,289.780,84,book&&292.800,289.310,314.210,300.650,84,book&&285.800,257.460,299.770,273.600,62,chair']"
519
- ]
520
- },
521
- "execution_count": 4,
522
- "metadata": {},
523
- "output_type": "execute_result"
524
- }
525
- ],
526
- "source": [
527
- "data[0]"
528
- ]
529
- },
530
- {
531
- "cell_type": "code",
532
- "execution_count": 78,
533
- "metadata": {},
534
- "outputs": [
535
- {
536
- "name": "stderr",
537
- "output_type": "stream",
538
- "text": [
539
- "100%|█| 181767/181767 [00:00<00:00,\n"
540
- ]
541
- }
542
- ],
543
- "source": [
544
- "for d in tqdm(data):\n",
545
- " d[2] = d[2].replace('\\\"', '')\n"
546
- ]
547
- },
548
- {
549
- "cell_type": "code",
550
- "execution_count": 49,
551
- "metadata": {},
552
- "outputs": [],
553
- "source": [
554
- "data_ = []\n",
555
- "with open(path_data) as file:\n",
556
- " for i in tqdm(range(6458670)):\n",
557
- " column_l = file.readline().rstrip(\"\\n\").split(\"\\t\")\n",
558
- " data_.append(column_l)\n",
559
- " if len(column_l) < 2:\n",
560
- " break"
561
- ]
562
- },
563
- {
564
- "cell_type": "code",
565
- "execution_count": 64,
566
- "metadata": {},
567
- "outputs": [
568
- {
569
- "name": "stderr",
570
- "output_type": "stream",
571
- "text": [
572
- "5593207it [00:03, 1463300.52it/s]\n"
573
- ]
574
- }
575
- ],
576
- "source": [
577
- "\n",
578
- "data_example = []\n",
579
- "fp = open('/data/mshukor/data/ofa/pretrain_ours/vision_language_mini_.tsv', \"r\")\n",
580
- "data_example = []\n",
581
- "for line in tqdm(fp):\n",
582
- " data_example.append(line)"
583
- ]
584
- },
585
- {
586
- "cell_type": "code",
587
- "execution_count": 74,
588
- "metadata": {},
589
- "outputs": [
590
- {
591
- "name": "stdout",
592
- "output_type": "stream",
593
- "text": [
594
- "2796604\tcc3m/train/8/2d0d96e4ecb8e2e959a3bf10d59b9d05ac114aea.jpg\tthe residential development under construction in district\t\t\t\tcc3m\tcaption\n",
595
- "\n"
596
- ]
597
- }
598
- ],
599
- "source": [
600
- "print(data_example[2796604])"
601
- ]
602
- },
603
- {
604
- "cell_type": "code",
605
- "execution_count": 73,
606
- "metadata": {},
607
- "outputs": [
608
- {
609
- "name": "stdout",
610
- "output_type": "stream",
611
- "text": [
612
- "/val2014/COCO_val2014_000000329789.jpg\tA young man is eating a slice of pizza in his room\t\t\t\tcoco_karp\tcaption\n",
613
- "\n"
614
- ]
615
- }
616
- ],
617
- "source": [
618
- "data_example[2796604]\n",
619
- "fp.seek(2796604)\n",
620
- "for l in fp:\n",
621
- " print(l)\n",
622
- " break"
623
- ]
624
- },
625
- {
626
- "cell_type": "code",
627
- "execution_count": 46,
628
- "metadata": {},
629
- "outputs": [
630
- {
631
- "name": "stdout",
632
- "output_type": "stream",
633
- "text": [
634
- "2317 2317\n",
635
- "2510 2514 2510\n"
636
- ]
637
- }
638
- ],
639
- "source": [
640
- "len(data_[10].rstrip(\"\\n\").split(\"\\t\")[1])# len(line)\n",
641
- "# len(data_[10].rstrip(\"\\n\").split(\"\\t\")[1].encode('utf-8'))\n",
642
- "print(len(data_example[10]), len(data_example[10].encode('utf-8')))\n",
643
- "print(len(data_[10]), len(data_[10].encode('utf-8')), len(data_[10].encode('utf-8').decode(\"utf-8\")))\n",
644
- "\n"
645
- ]
646
- },
647
- {
648
- "cell_type": "code",
649
- "execution_count": null,
650
- "metadata": {},
651
- "outputs": [],
652
- "source": [
653
- "print(data_[10].encode('utf-8'))\n",
654
- "print(data_[10])\n",
655
- "\n",
656
- "print(data_example[10].encode('utf-8'))\n",
657
- "print(data_example[10])"
658
- ]
659
- },
660
- {
661
- "cell_type": "code",
662
- "execution_count": 4,
663
- "metadata": {},
664
- "outputs": [
665
- {
666
- "name": "stderr",
667
- "output_type": "stream",
668
- "text": [
669
- "6458670it [01:45, 61129.36it/s] \n"
670
- ]
671
- }
672
- ],
673
- "source": [
674
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
675
- "\n",
676
- "fp = open(output_path, \"r\")\n",
677
- "data_ = []\n",
678
- "for line in tqdm(fp):\n",
679
- " data_.append(line)\n",
680
- " "
681
- ]
682
- },
683
- {
684
- "cell_type": "code",
685
- "execution_count": 12,
686
- "metadata": {},
687
- "outputs": [
688
- {
689
- "name": "stderr",
690
- "output_type": "stream",
691
- "text": [
692
- "6458670it [04:08, 25941.37it/s]\n"
693
- ]
694
- }
695
- ],
696
- "source": [
697
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
698
- "\n",
699
- "start_id = 0 \n",
700
- "num_max_characters = 2500\n",
701
- "\n",
702
- "with open(output_path, 'w', newline='\\n') as f_output:\n",
703
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
704
- "\n",
705
- " for i, t in tqdm(enumerate(data)):\n",
706
- " text = t[1]\n",
707
- " item = [start_id, text]\n",
708
- " csv_output.writerow(item)\n",
709
- " start_id+=1"
710
- ]
711
- },
712
- {
713
- "cell_type": "code",
714
- "execution_count": 28,
715
- "metadata": {},
716
- "outputs": [
717
- {
718
- "name": "stderr",
719
- "output_type": "stream",
720
- "text": [
721
- "100%|█████████████████████████████████████████████████████████| 181767/181767 [00:03<00:00, 51934.09it/s]\n"
722
- ]
723
- }
724
- ],
725
- "source": [
726
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/detection_mini.tsv'\n",
727
- "\n",
728
- "with open(output_path, 'w', newline='\\n') as f_output:\n",
729
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
730
- "\n",
731
- " for t in tqdm(data):\n",
732
- " csv_output.writerow(t)"
733
- ]
734
- },
735
- {
736
- "cell_type": "code",
737
- "execution_count": 63,
738
- "metadata": {},
739
- "outputs": [
740
- {
741
- "data": {
742
- "text/plain": [
743
- "['5593206',\n",
744
- " 'train2014/COCO_train2014_000000524286.jpg',\n",
745
- " '',\n",
746
- " 'Is that a laptop?',\n",
747
- " '1.0|!+yes',\n",
748
- " '',\n",
749
- " 'vqa_train',\n",
750
- " 'qa']"
751
- ]
752
- },
753
- "execution_count": 63,
754
- "metadata": {},
755
- "output_type": "execute_result"
756
- }
757
- ],
758
- "source": [
759
- "t"
760
- ]
761
- },
762
- {
763
- "cell_type": "markdown",
764
- "metadata": {},
765
- "source": [
766
- "## Create data tsv"
767
- ]
768
- },
769
- {
770
- "cell_type": "code",
771
- "execution_count": null,
772
- "metadata": {},
773
- "outputs": [],
774
- "source": [
775
- "def convert_img_to_str(file_name):\n",
776
- " img = Image.open(file_name) # path to file\n",
777
- " img_buffer = BytesIO()\n",
778
- " img.save(img_buffer, format=img.format)\n",
779
- " byte_data = img_buffer.getvalue()\n",
780
- " base64_str = base64.b64encode(byte_data) # bytes\n",
781
- " base64_str = base64_str.decode(\"utf-8\") # str\n",
782
- " return base64_str"
783
- ]
784
- },
785
- {
786
- "cell_type": "markdown",
787
- "metadata": {},
788
- "source": [
789
- "### Create VL tsv"
790
- ]
791
- },
792
- {
793
- "cell_type": "markdown",
794
- "metadata": {},
795
- "source": [
796
- "#### Caption"
797
- ]
798
- },
799
- {
800
- "cell_type": "code",
801
- "execution_count": null,
802
- "metadata": {},
803
- "outputs": [],
804
- "source": [
805
- "original_data_path = '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json'\n",
806
- "original_data = json.load(open(original_data_path,'r'))"
807
- ]
808
- },
809
- {
810
- "cell_type": "code",
811
- "execution_count": null,
812
- "metadata": {},
813
- "outputs": [],
814
- "source": [
815
- "from preprocess.utils import get_tsv_data_from_jsons\n",
816
- " \n",
817
- "datasets = [\n",
818
- " '/data/mshukor/data/our_albef_data/json_pretrain/coco_karp.json',\n",
819
- " '/data/mshukor/data/our_albef_data/json_pretrain/vg_albef.json',\n",
820
- " '/data/mshukor/data/our_albef_data/json_pretrain/sbu.json',\n",
821
- " '/data/mshukor/data/our_albef_data/json_pretrain/cc3m.json', \n",
822
- " \n",
823
- " ['/data/mshukor/data/refcoco/refcoco+/refs(unc).p', '/data/mshukor/data/refcoco/refcoco+/instances.json'],\n",
824
- " \n",
825
- " '/data/mshukor/data/our_albef_data/data/vqa_train.json',\n",
826
- "]\n",
827
- "\n",
828
- "start_id = 0\n",
829
- "task_types = ['caption',\n",
830
- " 'caption',\n",
831
- " 'caption',\n",
832
- " 'caption',\n",
833
- " 'visual_grounding',\n",
834
- " 'qa',]\n",
835
- "\n",
836
- "tsvs = get_tsv_data_from_jsons(datasets, start_id, task_types, convert_images=False)\n"
837
- ]
838
- },
839
- {
840
- "cell_type": "code",
841
- "execution_count": null,
842
- "metadata": {},
843
- "outputs": [],
844
- "source": [
845
- "len(tsvs)\n",
846
- "# tsvs[-10000]\n",
847
- "tsvs[-1000000]"
848
- ]
849
- },
850
- {
851
- "cell_type": "code",
852
- "execution_count": null,
853
- "metadata": {},
854
- "outputs": [],
855
- "source": [
856
- "import csv\n",
857
- "from io import StringIO\n",
858
- "\n",
859
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/vision_language_mini.tsv'\n",
860
- "\n",
861
- "with open(output_path, 'w', newline='') as f_output:\n",
862
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
863
- "\n",
864
- " for t in tqdm(tsvs):\n",
865
- " csv_output.writerow(t)"
866
- ]
867
- },
868
- {
869
- "cell_type": "code",
870
- "execution_count": null,
871
- "metadata": {},
872
- "outputs": [],
873
- "source": [
874
- "csv.field_size_limit(sys.maxsize)\n",
875
- "\n",
876
- "\n",
877
- "out_data = []\n",
878
- "selected_cols='0,1,2,3,4,5,6,7'\n",
879
- "\n",
880
- "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
881
- "\n",
882
- "with open(output_path) as file:\n",
883
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
884
- " for line in tqdm(tsv_file):\n",
885
- " d = [line[i] for i in selected_col_ids]\n",
886
- " out_data.append(d)\n",
887
- " "
888
- ]
889
- },
890
- {
891
- "cell_type": "code",
892
- "execution_count": null,
893
- "metadata": {},
894
- "outputs": [],
895
- "source": [
896
- "out_data[-1]"
897
- ]
898
- },
899
- {
900
- "cell_type": "markdown",
901
- "metadata": {},
902
- "source": [
903
- "#### VQA"
904
- ]
905
- },
906
- {
907
- "cell_type": "code",
908
- "execution_count": null,
909
- "metadata": {},
910
- "outputs": [],
911
- "source": [
912
- "original_data_path = '/data/mshukor/data/our_albef_data/data/vqa_train.json'\n",
913
- "original_data = json.load(open(original_data_path,'r'))\n"
914
- ]
915
- },
916
- {
917
- "cell_type": "code",
918
- "execution_count": null,
919
- "metadata": {},
920
- "outputs": [],
921
- "source": [
922
- "original_data[100]"
923
- ]
924
- },
925
- {
926
- "cell_type": "code",
927
- "execution_count": null,
928
- "metadata": {},
929
- "outputs": [],
930
- "source": [
931
- "# 1.0|!+horizontal&&0.3|!+south&&0.3|!+straight&&0.3|!+vertical\n",
932
- "\n",
933
- "from preprocess.utils import get_tsv_vqa_data_from_json\n",
934
- "\n",
935
- "\n",
936
- "start_id = 0\n",
937
- "dataset_name = 'vqav2'\n",
938
- "task_type = 'qa'\n",
939
- "\n",
940
- "image_root = '/data/mshukor/data/coco'\n",
941
- "tmp = get_tsv_vqa_data_from_json(original_data, start_id, dataset_name, task_type, image_root=image_root, convert_images=False)"
942
- ]
943
- },
944
- {
945
- "cell_type": "code",
946
- "execution_count": null,
947
- "metadata": {},
948
- "outputs": [],
949
- "source": [
950
- "tmp[10]"
951
- ]
952
- },
953
- {
954
- "cell_type": "code",
955
- "execution_count": null,
956
- "metadata": {},
957
- "outputs": [],
958
- "source": []
959
- },
960
- {
961
- "cell_type": "markdown",
962
- "metadata": {},
963
- "source": [
964
- "#### Visual Grounding "
965
- ]
966
- },
967
- {
968
- "cell_type": "code",
969
- "execution_count": null,
970
- "metadata": {},
971
- "outputs": [],
972
- "source": [
973
- "original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+_train.json'\n",
974
- "original_data = json.load(open(original_data_path,'r'))\n",
975
- "\n",
976
- "original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+/dets.json'\n",
977
- "det_file = json.load(open(original_data_path,'r'))\n",
978
- "\n",
979
- "original_data_path = '/data/mshukor/data/our_albef_data/data/refcoco+/cocos.json'\n",
980
- "coco_file = json.load(open(original_data_path,'r'))"
981
- ]
982
- },
983
- {
984
- "cell_type": "code",
985
- "execution_count": null,
986
- "metadata": {},
987
- "outputs": [],
988
- "source": [
989
- "list(det_file.keys())[:10]"
990
- ]
991
- },
992
- {
993
- "cell_type": "code",
994
- "execution_count": null,
995
- "metadata": {},
996
- "outputs": [],
997
- "source": [
998
- "original_data_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
999
- "original_data = json.load(open(original_data_path,'r'))"
1000
- ]
1001
- },
1002
- {
1003
- "cell_type": "code",
1004
- "execution_count": null,
1005
- "metadata": {},
1006
- "outputs": [],
1007
- "source": [
1008
- "import pickle\n",
1009
- "\n",
1010
- "ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
1011
- "refs = pickle.load(open(ref_path, 'rb'))"
1012
- ]
1013
- },
1014
- {
1015
- "cell_type": "code",
1016
- "execution_count": null,
1017
- "metadata": {},
1018
- "outputs": [],
1019
- "source": [
1020
- "for i, ref in tqdm(enumerate(refs)):\n",
1021
- " \n",
1022
- " "
1023
- ]
1024
- },
1025
- {
1026
- "cell_type": "code",
1027
- "execution_count": null,
1028
- "metadata": {},
1029
- "outputs": [],
1030
- "source": [
1031
- "len(refs)"
1032
- ]
1033
- },
1034
- {
1035
- "cell_type": "code",
1036
- "execution_count": null,
1037
- "metadata": {},
1038
- "outputs": [],
1039
- "source": [
1040
- "refs[500]"
1041
- ]
1042
- },
1043
- {
1044
- "cell_type": "code",
1045
- "execution_count": null,
1046
- "metadata": {},
1047
- "outputs": [],
1048
- "source": [
1049
- "id_to_annot = {}\n",
1050
- "for annot in original_data['annotations']:\n",
1051
- " id_to_annot[annot['id']] = annot\n",
1052
- " \n",
1053
- " "
1054
- ]
1055
- },
1056
- {
1057
- "cell_type": "code",
1058
- "execution_count": null,
1059
- "metadata": {},
1060
- "outputs": [],
1061
- "source": [
1062
- "id_to_images = {}\n",
1063
- "for annot in tqdm(original_data['images']):\n",
1064
- " id_to_images[annot['id']] = annot"
1065
- ]
1066
- },
1067
- {
1068
- "cell_type": "code",
1069
- "execution_count": null,
1070
- "metadata": {},
1071
- "outputs": [],
1072
- "source": [
1073
- "id_to_images[576457]"
1074
- ]
1075
- },
1076
- {
1077
- "cell_type": "code",
1078
- "execution_count": null,
1079
- "metadata": {},
1080
- "outputs": [],
1081
- "source": [
1082
- "list(id_to_annot.keys())[:10]\n",
1083
- "id_to_annot[1640859]['bbox']\n",
1084
- "for r in tqdm(id_to_annot.values()):\n",
1085
- " if r['bbox'][0] > 0:\n",
1086
- " print(r['bbox'])"
1087
- ]
1088
- },
1089
- {
1090
- "cell_type": "code",
1091
- "execution_count": null,
1092
- "metadata": {},
1093
- "outputs": [],
1094
- "source": []
1095
- },
1096
- {
1097
- "cell_type": "code",
1098
- "execution_count": null,
1099
- "metadata": {},
1100
- "outputs": [],
1101
- "source": [
1102
- "list(original_data.keys())[:10]"
1103
- ]
1104
- },
1105
- {
1106
- "cell_type": "code",
1107
- "execution_count": null,
1108
- "metadata": {},
1109
- "outputs": [],
1110
- "source": [
1111
- "ref_path = '/data/mshukor/data/refcoco/refcoco+/refs(unc).p'\n",
1112
- "instances_path = '/data/mshukor/data/refcoco/refcoco+/instances.json'\n",
1113
- "start_id = 0\n",
1114
- "dataset_name='refcoco_train'\n",
1115
- "task_type='visual_grounding'\n",
1116
- "convert_images=False\n",
1117
- "split='train'\n",
1118
- "\n",
1119
- "tmp = get_tsv_from_refcoco(ref_path, instances_path, start_id, dataset_name=dataset_name, task_type=task_type, convert_images=convert_images, split=split)"
1120
- ]
1121
- },
1122
- {
1123
- "cell_type": "code",
1124
- "execution_count": null,
1125
- "metadata": {},
1126
- "outputs": [],
1127
- "source": [
1128
- "tmp[-1]"
1129
- ]
1130
- },
1131
- {
1132
- "cell_type": "code",
1133
- "execution_count": null,
1134
- "metadata": {},
1135
- "outputs": [],
1136
- "source": [
1137
- "Image.open('/data/mshukor/data/coco/train2014/COCO_train2014_000000000072.jpg').convert('RGB')"
1138
- ]
1139
- },
1140
- {
1141
- "cell_type": "code",
1142
- "execution_count": null,
1143
- "metadata": {},
1144
- "outputs": [],
1145
- "source": [
1146
- "original_data['images'][:10]"
1147
- ]
1148
- },
1149
- {
1150
- "cell_type": "code",
1151
- "execution_count": null,
1152
- "metadata": {},
1153
- "outputs": [],
1154
- "source": [
1155
- "# ['third book starting from left', '', '29.1,11.72,66.81,343.41', '', 'refcoco_train', 'visual_grounding']\n",
1156
- "\n",
1157
- "original_data['categories']"
1158
- ]
1159
- },
1160
- {
1161
- "cell_type": "markdown",
1162
- "metadata": {},
1163
- "source": [
1164
- "### Imagenet"
1165
- ]
1166
- },
1167
- {
1168
- "cell_type": "code",
1169
- "execution_count": null,
1170
- "metadata": {},
1171
- "outputs": [],
1172
- "source": [
1173
- "# image-id and image base64 string .txt file \n",
1174
- "# id, image, code in tsv final \n",
1175
- "\n",
1176
- "from preprocesss.utils import create_imagenet_txt_files\n",
1177
- "\n",
1178
- "\n",
1179
- "path_data = '/data/mshukor/data/imagenet/val'\n",
1180
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_val.txt'\n",
1181
- "\n",
1182
- "\n",
1183
- "create_imagenet_txt_files(path_data, output_path)"
1184
- ]
1185
- },
1186
- {
1187
- "cell_type": "code",
1188
- "execution_count": null,
1189
- "metadata": {},
1190
- "outputs": [],
1191
- "source": [
1192
- "start_id\n",
1193
- "len(data)\n",
1194
- "data[0]"
1195
- ]
1196
- },
1197
- {
1198
- "cell_type": "code",
1199
- "execution_count": null,
1200
- "metadata": {},
1201
- "outputs": [],
1202
- "source": [
1203
- "\n",
1204
- "code_path = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train_codes.tsv'\n",
1205
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
1206
- "\n",
1207
- "def save_image_only_tsv_from_code_files(code_path, output_path, start_id=0):\n",
1208
- " selected_col_ids = [0,1]\n",
1209
- " out_data = []\n",
1210
- " with open(code_path) as file:\n",
1211
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1212
- " for line in tqdm(tsv_file):\n",
1213
- " d = [line[i] for i in selected_col_ids]\n",
1214
- " d = [start_id]+d\n",
1215
- " out_data.append(d)\n",
1216
- "\n",
1217
- "\n",
1218
- " with open(output_path, 'w', newline='') as f_output:\n",
1219
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
1220
- "\n",
1221
- " for t in tqdm(out_data):\n",
1222
- " csv_output.writerow(t)\n",
1223
- "\n",
1224
- "save_image_only_tsv_from_code_files(code_path, output_path, start_id=0)"
1225
- ]
1226
- },
1227
- {
1228
- "cell_type": "code",
1229
- "execution_count": null,
1230
- "metadata": {},
1231
- "outputs": [],
1232
- "source": [
1233
- "selected_col_ids = [0,1,2]\n",
1234
- "out_data = []\n",
1235
- "with open(output_path) as file:\n",
1236
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1237
- " for line in tqdm(tsv_file):\n",
1238
- " d = [line[i] for i in selected_col_ids]\n",
1239
- " out_data.append(d)\n",
1240
- " break"
1241
- ]
1242
- },
1243
- {
1244
- "cell_type": "code",
1245
- "execution_count": null,
1246
- "metadata": {},
1247
- "outputs": [],
1248
- "source": [
1249
- "len(out_data[0][2].split(' '))"
1250
- ]
1251
- },
1252
- {
1253
- "cell_type": "markdown",
1254
- "metadata": {},
1255
- "source": [
1256
- "#### Fix image paths"
1257
- ]
1258
- },
1259
- {
1260
- "cell_type": "code",
1261
- "execution_count": 33,
1262
- "metadata": {},
1263
- "outputs": [
1264
- {
1265
- "name": "stderr",
1266
- "output_type": "stream",
1267
- "text": [
1268
- "1281167it [00:16, 79250.80it/s]\n"
1269
- ]
1270
- }
1271
- ],
1272
- "source": [
1273
- "\n",
1274
- "path_data = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
1275
- "selected_cols='0,1,2'\n",
1276
- "\n",
1277
- "data = []\n",
1278
- "\n",
1279
- "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
1280
- "\n",
1281
- "with open(path_data) as file:\n",
1282
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1283
- " for line in tqdm(tsv_file):\n",
1284
- "\n",
1285
- " d = [line[i] for i in selected_col_ids]\n",
1286
- "# print(d)\n",
1287
- " data.append(d)"
1288
- ]
1289
- },
1290
- {
1291
- "cell_type": "code",
1292
- "execution_count": 44,
1293
- "metadata": {},
1294
- "outputs": [
1295
- {
1296
- "name": "stderr",
1297
- "output_type": "stream",
1298
- "text": [
1299
- "1281167it [00:16, 76760.12it/s]\n",
1300
- "1281167it [00:01, 671149.72it/s]\n",
1301
- "100%|█████| 1281167/1281167 [00:01<00:00, 947543.73it/s]\n"
1302
- ]
1303
- }
1304
- ],
1305
- "source": [
1306
- "# from imge-id img-path to \n",
1307
- "def replace_image_id_by_path(input_tsv, output_tsv, mapping_file):\n",
1308
- " selected_cols='0,1,2'\n",
1309
- " data = []\n",
1310
- " selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
1311
- " with open(input_tsv) as file:\n",
1312
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1313
- " for line in tqdm(tsv_file):\n",
1314
- " d = [line[i] for i in selected_col_ids]\n",
1315
- " data.append(d)\n",
1316
- " \n",
1317
- " im_id_to_path = {}\n",
1318
- " with open(mapping_file) as file:\n",
1319
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1320
- " for line in tqdm(tsv_file):\n",
1321
- " d = [line[i] for i in [0, 1]]\n",
1322
- " im_id_to_path[d[0]] = d[1]\n",
1323
- " \n",
1324
- " for d in tqdm(data):\n",
1325
- " im_id = d[1].split('/')[-1].split('.')[0]\n",
1326
- " im_path = im_id_to_path[im_id]\n",
1327
- " d[1] = im_path\n",
1328
- " \n",
1329
- " with open(output_tsv, 'w', newline='') as f_output:\n",
1330
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
1331
- "\n",
1332
- " for t in tqdm(data):\n",
1333
- " csv_output.writerow(t)\n",
1334
- " \n",
1335
- " return data\n",
1336
- "\n",
1337
- "input_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
1338
- "output_tsv = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
1339
- "mapping_file = '/data/mshukor/data/ofa/pretrain_ours/imagenet_train.txt'\n",
1340
- "\n",
1341
- "tmp = replace_image_id_by_path(input_tsv, output_tsv, mapping_file)"
1342
- ]
1343
- },
1344
- {
1345
- "cell_type": "code",
1346
- "execution_count": 45,
1347
- "metadata": {},
1348
- "outputs": [
1349
- {
1350
- "data": {
1351
- "text/plain": [
1352
- "['0',\n",
1353
- " 'RawImages/train/n03146219/n03146219_8050.JPEG',\n",
1354
- " '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
1355
- ]
1356
- },
1357
- "execution_count": 45,
1358
- "metadata": {},
1359
- "output_type": "execute_result"
1360
- }
1361
- ],
1362
- "source": [
1363
- "tmp[0]"
1364
- ]
1365
- },
1366
- {
1367
- "cell_type": "code",
1368
- "execution_count": 36,
1369
- "metadata": {},
1370
- "outputs": [
1371
- {
1372
- "name": "stderr",
1373
- "output_type": "stream",
1374
- "text": [
1375
- "100%|█████| 1281167/1281167 [00:03<00:00, 336250.44it/s]\n"
1376
- ]
1377
- }
1378
- ],
1379
- "source": [
1380
- "# imgage_dir = 'imagenet/RawImages/train/'\n",
1381
- "# for d in tqdm(data):\n",
1382
- "# im_id = d[1]\n",
1383
- "# im_dir = im_id.split('_')[0]\n",
1384
- "# im_path = os.path.join(im_dir, im_id+'.JPEG')\n",
1385
- "# d[1] = os.path.join(imgage_dir, im_path)"
1386
- ]
1387
- },
1388
- {
1389
- "cell_type": "code",
1390
- "execution_count": 39,
1391
- "metadata": {},
1392
- "outputs": [
1393
- {
1394
- "data": {
1395
- "text/plain": [
1396
- "['0',\n",
1397
- " 'imagenet/RawImages/train/n03146219/n03146219_8050.JPEG',\n",
1398
- " '7442 662 7977 1652 6320 650 4376 992 1596 7734 1925 5335 3935 5604 5697 4504 5114 4050 144 215 144 6691 5321 7769 4755 3346 4691 3469 4175 1351 6907 9 6948 7749 7166 215 1026 931 970 4168 2675 6874 6248 2306 6138 8052 2970 6302 5550 2491 6931 969 6574 8014 6588 6639 389 1882 688 4691 4266 675 6248 3938 2387 4365 5999 261 2966 3499 651 5290 970 3526 5583 516 167 2103 1513 198 6657 7442 1118 7207 7307 1792 2078 388 4285 3417 5450 6959 6999 1306 1649 4556 2533 1103 6869 7681 8051 1916 7160 7743 2704 8063 2726 4860 2383 1635 8061 3497 7327 5915 7836 5697 1719 2136 96 970 7184 5167 2250 404 7007 7565 2742 33 7076 5250 7790 1838 1298 2847 3250 1204 1934 5550 4360 5688 1791 3465 634 4663 2991 5352 4066 4157 946 1596 3504 5855 5629 5411 7695 3627 3942 5631 2736 2883 5059 1423 2009 2643 1873 4960 1661 545 1396 3450 3145 211 6869 2226 6780 2724 4606 3702 3667 891 6236 6419 3531 7032 5277 3381 3031 7878 725 1652 1813 5037 949 3087 405 7884 3784 5432 633 4256 235 3182 3686 5450 2419 1593 7948 5741 6237 7233 20 7470 7071 182 1584 6780 7913 2691 7207 5094 5199 4502 5030 2360 448 5129 2713 1094 1678 1934 2458 2970 2133 867 3332 6138 294 3260 5495 4189 5732 3940 5629 4139 7335 7607 3248 4981 2109 3660 4364 7763 3964 7163 6702 691']"
1399
- ]
1400
- },
1401
- "execution_count": 39,
1402
- "metadata": {},
1403
- "output_type": "execute_result"
1404
- }
1405
- ],
1406
- "source": [
1407
- "data[0]"
1408
- ]
1409
- },
1410
- {
1411
- "cell_type": "code",
1412
- "execution_count": 40,
1413
- "metadata": {},
1414
- "outputs": [
1415
- {
1416
- "name": "stderr",
1417
- "output_type": "stream",
1418
- "text": [
1419
- "100%|██████| 1281167/1281167 [00:27<00:00, 46704.02it/s]\n"
1420
- ]
1421
- }
1422
- ],
1423
- "source": [
1424
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/image_mini.tsv'\n",
1425
- "with open(output_path, 'w', newline='') as f_output:\n",
1426
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
1427
- "\n",
1428
- " for t in tqdm(data):\n",
1429
- " csv_output.writerow(t)\n"
1430
- ]
1431
- },
1432
- {
1433
- "cell_type": "markdown",
1434
- "metadata": {},
1435
- "source": [
1436
- "### Object detection"
1437
- ]
1438
- },
1439
- {
1440
- "cell_type": "markdown",
1441
- "metadata": {},
1442
- "source": [
1443
- "#### COCO"
1444
- ]
1445
- },
1446
- {
1447
- "cell_type": "code",
1448
- "execution_count": null,
1449
- "metadata": {},
1450
- "outputs": [],
1451
- "source": [
1452
- "# '505.856,189.994,799.744,450.016,/m/07j7r,tree&&753.664,384.00600000000003,827.392,446.572,/m/0c9ph5,flower'\n",
1453
- "\n",
1454
- "path_json = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
1455
- "\n",
1456
- "data = json.load(open(path_json,'r'))"
1457
- ]
1458
- },
1459
- {
1460
- "cell_type": "code",
1461
- "execution_count": null,
1462
- "metadata": {},
1463
- "outputs": [],
1464
- "source": [
1465
- "def get_tsv_from_coco_detection(instances_path, start_id, convert_images=True, split='train'):\n",
1466
- "\n",
1467
- " instances = json.load(open(instances_path,'r'))\n",
1468
- " imgid_to_annot = {}\n",
1469
- " for annot in tqdm(instances['annotations']):\n",
1470
- " if annot['image_id'] not in imgid_to_annot:\n",
1471
- " imgid_to_annot[annot['image_id']] = [annot]\n",
1472
- " else:\n",
1473
- " imgid_to_annot[annot['image_id']].append(annot)\n",
1474
- "\n",
1475
- " id_to_category = {}\n",
1476
- " for annot in tqdm(instances['categories']):\n",
1477
- " id_to_category[annot['id']] = annot['name']\n",
1478
- "\n",
1479
- " tsv_data = []\n",
1480
- " missied = []\n",
1481
- " for ref in tqdm(instances['images']):\n",
1482
- " ref_split = split\n",
1483
- " image_id = ref['id']\n",
1484
- " file_name = ref['file_name']\n",
1485
- "\n",
1486
- " if ref_split == 'train':\n",
1487
- " file_name = os.path.join('coco/train2014', file_name)\n",
1488
- "\n",
1489
- " if convert_images:\n",
1490
- " img_path = os.path.join('/data/mshukor/data/', file_name)\n",
1491
- " img = convert_img_to_str(img_path)\n",
1492
- " else:\n",
1493
- " img_path = file_name.replace('/data/mshukor/data/', '')\n",
1494
- " img = img_path\n",
1495
- "\n",
1496
- " # ann_id = ref['id']\n",
1497
- " # annot = id_to_annot[ann_id]\n",
1498
- " if image_id not in imgid_to_annot:\n",
1499
- " missied.append(image_id)\n",
1500
- " continue\n",
1501
- " annots = imgid_to_annot[image_id]\n",
1502
- " detections = []\n",
1503
- " areas = []\n",
1504
- " for annot in annots:\n",
1505
- " bbox = annot['bbox'] # x,y,w,h bottom left\n",
1506
- " area = bbox[2]*bbox[3]\n",
1507
- " x1, y1, x2, y2 = bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3] # top left, bottom right \n",
1508
- " # box = '{:.3f},{:.3f},{:.3f},{:.3f}'.format(x1, y1, x2, y2)\n",
1509
- "\n",
1510
- " object_id = annot['category_id']\n",
1511
- " category = id_to_category[object_id]\n",
1512
- "\n",
1513
- " tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
1514
- " areas.append(area)\n",
1515
- " detections.append(tmp)\n",
1516
- "\n",
1517
- " sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
1518
- " detections = [detections[k] for k in sorted_indices]\n",
1519
- " detections = '&&'.join(detections)\n",
1520
- " t = [start_id, img, detections]\n",
1521
- "\n",
1522
- " tsv_data.append(t)\n",
1523
- " start_id+=1\n",
1524
- "\n",
1525
- " return tsv_data\n",
1526
- "\n",
1527
- "instances_path = '/data/mshukor/data/coco/annotations/instances_train2014.json'\n",
1528
- "start_id = 0\n",
1529
- "tmp = get_tsv_from_coco_detection(instances_path, start_id, convert_images=False, split='train')"
1530
- ]
1531
- },
1532
- {
1533
- "cell_type": "code",
1534
- "execution_count": null,
1535
- "metadata": {},
1536
- "outputs": [],
1537
- "source": [
1538
- "list(imgid_to_annot.keys())[:10]\n",
1539
- "len(missied)"
1540
- ]
1541
- },
1542
- {
1543
- "cell_type": "markdown",
1544
- "metadata": {},
1545
- "source": [
1546
- "#### VG"
1547
- ]
1548
- },
1549
- {
1550
- "cell_type": "code",
1551
- "execution_count": null,
1552
- "metadata": {},
1553
- "outputs": [],
1554
- "source": [
1555
- "def get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=True, split='train'):\n",
1556
- " \n",
1557
- " instances = json.load(open(instances_path,'r'))\n",
1558
- " \n",
1559
- " id_to_objects = {}\n",
1560
- " for d in instances:\n",
1561
- " id_to_objects[d['id']] = d\n",
1562
- "\n",
1563
- "\n",
1564
- " \n",
1565
- " id_to_image_path = {}\n",
1566
- " for root, dirs, files, in os.walk(path_images):\n",
1567
- " for d in dirs:\n",
1568
- " dir_path = os.path.join(root, d)\n",
1569
- " for _, _, dir_files in os.walk(dir_path):\n",
1570
- " for f in dir_files:\n",
1571
- " file_path = os.path.join(dir_path, f)\n",
1572
- " file_path = '/'.join(file_path.split('/')[-4:])\n",
1573
- " image_id = f.split('.')[0]\n",
1574
- " id_to_image_path[image_id] = file_path\n",
1575
- "\n",
1576
- " \n",
1577
- "\n",
1578
- "\n",
1579
- " tsv_data = []\n",
1580
- " missied = []\n",
1581
- " negs = []\n",
1582
- " for ref in tqdm(id_to_image_path.keys()):\n",
1583
- " ref_split = split\n",
1584
- " \n",
1585
- " image_id = ref\n",
1586
- " \n",
1587
- " file_name = id_to_image_path[image_id]\n",
1588
- " if convert_images:\n",
1589
- " img_path = os.path.join('/data/mshukor/data/', file_name)\n",
1590
- " img = convert_img_to_str(img_path)\n",
1591
- " else:\n",
1592
- " img_path = file_name.replace('/data/mshukor/data/', '')\n",
1593
- " img = img_path\n",
1594
- "\n",
1595
- " \n",
1596
- " if int(image_id) in id_to_objects:\n",
1597
- " objects = id_to_objects[int(image_id)]['objects']\n",
1598
- " else:\n",
1599
- " missied.append(image_id)\n",
1600
- " continue\n",
1601
- " \n",
1602
- " if len(objects) == 0:\n",
1603
- " missied.append(image_id)\n",
1604
- " continue\n",
1605
- " \n",
1606
- " \n",
1607
- " areas = []\n",
1608
- " detections = []\n",
1609
- " for annot in objects:\n",
1610
- " x,y,w,h = annot['x'], annot['y'], annot['w'], annot['h'] # x,y,w,h bottom left\n",
1611
- " \n",
1612
- " area = w*h\n",
1613
- " \n",
1614
- " x1, y1, x2, y2 = x, y, x + w, y + h # top left, bottom right \n",
1615
- " \n",
1616
- " if x1 < 0 or x2 < 0:\n",
1617
- " negs.append(annot)\n",
1618
- " x1 = max(0, x1)\n",
1619
- " x2 = max(0, x2)\n",
1620
- " \n",
1621
- " \n",
1622
- " category = ','.join(annot['names']).replace('\\x00','')\n",
1623
- " object_id = annot['id']\n",
1624
- " \n",
1625
- " \n",
1626
- " tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
1627
- " detections.append(tmp)\n",
1628
- " areas.append(area)\n",
1629
- "\n",
1630
- " sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
1631
- " detections = [detections[k] for k in sorted_indices]\n",
1632
- " \n",
1633
- " detections = '&&'.join(detections)\n",
1634
- " t = [start_id, img, detections]\n",
1635
- "\n",
1636
- " tsv_data.append(t)\n",
1637
- " start_id+=1\n",
1638
- " print('missed images:', len(missied), 'negs', len(negs))\n",
1639
- " return tsv_data\n",
1640
- "\n",
1641
- "\n",
1642
- "instances_path = '/data/mshukor/data/visual_genome/annotations/objects.json'\n",
1643
- "path_images = '/data/mshukor/data/visual_genome/images'\n",
1644
- "start_id = 0\n",
1645
- "\n",
1646
- "tmp = get_tsv_from_vg_detection(instances_path, path_images, start_id, convert_images=False, split='train')"
1647
- ]
1648
- },
1649
- {
1650
- "cell_type": "code",
1651
- "execution_count": null,
1652
- "metadata": {},
1653
- "outputs": [],
1654
- "source": [
1655
- "image_root = '/data/mshukor/data/'\n",
1656
- "\n",
1657
- "Image.open(image_root+id_to_image_path['1087']).convert('RGB')"
1658
- ]
1659
- },
1660
- {
1661
- "cell_type": "markdown",
1662
- "metadata": {},
1663
- "source": [
1664
- "#### OpenImagesV5"
1665
- ]
1666
- },
1667
- {
1668
- "cell_type": "code",
1669
- "execution_count": null,
1670
- "metadata": {},
1671
- "outputs": [],
1672
- "source": [
1673
- "# data_path = '/data/mshukor/data/OpenImagesV5/train-annotations-bbox.csv'\n",
1674
- "# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable.csv'\n",
1675
- "# data_path = '/data/mshukor/data/OpenImagesV5/train-images-boxable-with-rotation.csv'\n",
1676
- "data_path = '/data/mshukor/data/OpenImagesV5/class-descriptions-boxable.csv'\n",
1677
- "\n",
1678
- "\n",
1679
- "\n",
1680
- "\n",
1681
- "selected_col_ids = [0,1,2]\n",
1682
- "out_data = []\n",
1683
- "with open(data_path) as file:\n",
1684
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1685
- " for i, line in tqdm(enumerate(tsv_file)):\n",
1686
- " # d = [line[i] for i in selected_col_ids]\n",
1687
- " out_data.append(line)\n",
1688
- "# print(line)\n",
1689
- "# if i > 2:\n",
1690
- "# break\n",
1691
- " "
1692
- ]
1693
- },
1694
- {
1695
- "cell_type": "code",
1696
- "execution_count": null,
1697
- "metadata": {},
1698
- "outputs": [],
1699
- "source": [
1700
- "def get_tsv_from_openimages_detection(instances_path, path_images, start_id, convert_images=False, split='train')\n",
1701
- "\n",
1702
- " id_to_image_path = {}\n",
1703
- " for root, dirs, files, in os.walk(path_images):\n",
1704
- " for d in dirs:\n",
1705
- " dir_path = os.path.join(root, d)\n",
1706
- " for _, _, dir_files in os.walk(dir_path):\n",
1707
- " for f in dir_files:\n",
1708
- " file_path = os.path.join(dir_path, f)\n",
1709
- " file_path = '/'.join(file_path.split('/')[-4:])\n",
1710
- " image_id = f.split('.')[0]\n",
1711
- " id_to_image_path[image_id] = file_path\n",
1712
- "\n",
1713
- " image_root = '/gpfsdswork/dataset'\n",
1714
- "\n",
1715
- " def imagepath_to_image_size(path):\n",
1716
- " w, h = Image.open(path).size\n",
1717
- "\n",
1718
- " id_to_annot = {}\n",
1719
- " with open(instances_path) as file:\n",
1720
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1721
- " for i, line in tqdm(enumerate(tsv_file)):\n",
1722
- " img_id = line[0].split(',')[0]\n",
1723
- " if img_id in id_to_annot:\n",
1724
- " id_to_annot[img_id].append(line)\n",
1725
- " else:\n",
1726
- " id_to_annot[img_id] = [line]\n",
1727
- "\n",
1728
- " classid_to_class = {}\n",
1729
- "\n",
1730
- " with open(class_path) as file:\n",
1731
- " tsv_file = csv.reader(file, delimiter=',')\n",
1732
- " for i, line in tqdm(enumerate(tsv_file)):\n",
1733
- " classid_to_class[line[0]] = line[1]\n",
1734
- "\n",
1735
- " tsv_data = []\n",
1736
- " for img_id in id_to_annot.keys():\n",
1737
- " annots = id_to_annot[img_id]\n",
1738
- " img_path = id_to_image_path[img_id]\n",
1739
- " orig_img_path = os.path.join(image_root, img_path)\n",
1740
- " w, h = imagepath_to_image_size(path)\n",
1741
- "\n",
1742
- " if convert_images:\n",
1743
- " img = convert_img_to_str(orig_img_path)\n",
1744
- " else:\n",
1745
- " img = img_path\n",
1746
- "\n",
1747
- " areas = []\n",
1748
- " detections = []\n",
1749
- " for d in annots:\n",
1750
- " d = d[0].split(',')\n",
1751
- "\n",
1752
- " x1, x2, y1, y2 = d[4:8]\n",
1753
- " x1, x2, y1, y2 = x1*w, x2*w, y1*h, y2*h\n",
1754
- " box_w, box_h = x2 - x1, y2 - y1\n",
1755
- " area = box_w*box_h\n",
1756
- " areas.append(area)\n",
1757
- "\n",
1758
- " object_id = d[2]\n",
1759
- " category = classid_to_class[object_id]\n",
1760
- "\n",
1761
- " tmp = '{:.3f},{:.3f},{:.3f},{:.3f},{},{}'.format(x1, y1, x2, y2, object_id, category)\n",
1762
- " detections.append(tmp)\n",
1763
- "\n",
1764
- "\n",
1765
- " sorted_indices = sorted(range(len(areas)), key=lambda k: areas[k], reverse=True)\n",
1766
- " detections = [detections[k] for k in sorted_indices]\n",
1767
- "\n",
1768
- " detections = '&&'.join(detections)\n",
1769
- " t = [start_id, img, detections]\n",
1770
- "\n",
1771
- " tsv_data.append(t)\n",
1772
- " start_id+=1\n",
1773
- " \n",
1774
- " return tsv_data\n",
1775
- "\n",
1776
- " "
1777
- ]
1778
- },
1779
- {
1780
- "cell_type": "code",
1781
- "execution_count": null,
1782
- "metadata": {},
1783
- "outputs": [],
1784
- "source": [
1785
- "e39871fd9fd74f55"
1786
- ]
1787
- },
1788
- {
1789
- "cell_type": "markdown",
1790
- "metadata": {},
1791
- "source": [
1792
- "### Text"
1793
- ]
1794
- },
1795
- {
1796
- "cell_type": "markdown",
1797
- "metadata": {},
1798
- "source": [
1799
- "#### En Wikipedia"
1800
- ]
1801
- },
1802
- {
1803
- "cell_type": "code",
1804
- "execution_count": null,
1805
- "metadata": {},
1806
- "outputs": [],
1807
- "source": [
1808
- "from datasets import load_dataset"
1809
- ]
1810
- },
1811
- {
1812
- "cell_type": "code",
1813
- "execution_count": null,
1814
- "metadata": {},
1815
- "outputs": [],
1816
- "source": [
1817
- "%env http_proxy='http://192.168.0.100:3128' \n",
1818
- "%env https_proxy='http://192.168.0.100:3128'\n",
1819
- "\n",
1820
- "%env HF_DATASETS_CACHE=\"/data/mshukor/data/.cache\"\n",
1821
- "%env HF_DATASETS_OFFLINE=1"
1822
- ]
1823
- },
1824
- {
1825
- "cell_type": "code",
1826
- "execution_count": null,
1827
- "metadata": {},
1828
- "outputs": [],
1829
- "source": [
1830
- "tmp = load_dataset(\"wikipedia\", \"20220301.en\", cache_dir=\"/data/mshukor/data/.cache\")"
1831
- ]
1832
- },
1833
- {
1834
- "cell_type": "code",
1835
- "execution_count": null,
1836
- "metadata": {},
1837
- "outputs": [],
1838
- "source": [
1839
- "len(tmp['train'][0]['text'])\n",
1840
- "tmp['train'][0]['text'][:512]"
1841
- ]
1842
- },
1843
- {
1844
- "cell_type": "code",
1845
- "execution_count": null,
1846
- "metadata": {},
1847
- "outputs": [],
1848
- "source": [
1849
- "def remove_special(input_string):\n",
1850
- " final_string = \"\"\n",
1851
- " for character in input_string:\n",
1852
- " if character == \" \":\n",
1853
- " final_string = final_string + character\n",
1854
- " else:\n",
1855
- " if(character.isalnum()):\n",
1856
- " final_string = final_string + character\n",
1857
- " return final_string\n",
1858
- "\n",
1859
- "def get_tsv_from_text_data(data_name=\"wikipedia\", data_subname=\"20220301.en\", \n",
1860
- " output_path, cache_dir=\"/data/mshukor/data/.cache\", start_id=0, num_max_characters=2500):\n",
1861
- " from datasets import load_dataset\n",
1862
- " tmp = load_dataset(data_name, data_subname, cache_dir=cache_dir)\n",
1863
- "\n",
1864
- " with open(output_path, 'w', newline='') as f_output:\n",
1865
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
1866
- "\n",
1867
- " for i, t in tqdm(enumerate(tmp['train'])):\n",
1868
- " text = t['text'][:num_max_characters].replace('\\t', ' ').replace(\"\\n\", ' ').replace('\\\"', '')\n",
1869
- " text = remove_special(text)\n",
1870
- " item = [start_id, text]\n",
1871
- " csv_output.writerow(item)\n",
1872
- " start_id+=1\n",
1873
- "\n",
1874
- " "
1875
- ]
1876
- },
1877
- {
1878
- "cell_type": "code",
1879
- "execution_count": null,
1880
- "metadata": {},
1881
- "outputs": [],
1882
- "source": []
1883
- },
1884
- {
1885
- "cell_type": "code",
1886
- "execution_count": null,
1887
- "metadata": {},
1888
- "outputs": [],
1889
- "source": [
1890
- "import csv\n",
1891
- "from io import StringIO\n",
1892
- "\n",
1893
- "output_path = '/data/mshukor/data/ofa/pretrain_ours/text_mini.tsv'\n",
1894
- "\n",
1895
- "start_id = 0 \n",
1896
- "num_max_characters = 2500\n",
1897
- "\n",
1898
- "with open(output_path, 'w', newline='') as f_output:\n",
1899
- " csv_output = csv.writer(f_output, delimiter='\\t')\n",
1900
- "\n",
1901
- " for i, t in tqdm(enumerate(tmp['train'])):\n",
1902
- " text = t['text'][:num_max_characters]\n",
1903
- " item = [start_id, text]\n",
1904
- " csv_output.writerow(item)\n",
1905
- " start_id+=1"
1906
- ]
1907
- },
1908
- {
1909
- "cell_type": "code",
1910
- "execution_count": null,
1911
- "metadata": {},
1912
- "outputs": [],
1913
- "source": [
1914
- "out_data = []\n",
1915
- "selected_cols='0,1'\n",
1916
- "\n",
1917
- "selected_col_ids = [int(col_id) for col_id in selected_cols.split(\",\")]\n",
1918
- "\n",
1919
- "with open(output_path) as file:\n",
1920
- " tsv_file = csv.reader(file, delimiter='\\t')\n",
1921
- " for line in tqdm(tsv_file):\n",
1922
- " d = [line[i] for i in selected_col_ids]\n",
1923
- " out_data.append(d)\n",
1924
- " "
1925
- ]
1926
- },
1927
- {
1928
- "cell_type": "code",
1929
- "execution_count": null,
1930
- "metadata": {},
1931
- "outputs": [],
1932
- "source": [
1933
- "out_data"
1934
- ]
1935
- },
1936
- {
1937
- "cell_type": "markdown",
1938
- "metadata": {},
1939
- "source": [
1940
- "### Create from finetuned data"
1941
- ]
1942
- },
1943
- {
1944
- "cell_type": "code",
1945
- "execution_count": null,
1946
- "metadata": {},
1947
- "outputs": [],
1948
- "source": [
1949
- "read from tsv and write to tsv directly \n",
1950
- "same for vqa v2\n",
1951
- "then create ofa_mini 4m, vqa and refcoco for pretraining "
1952
- ]
1953
- },
1954
- {
1955
- "cell_type": "markdown",
1956
- "metadata": {},
1957
- "source": [
1958
- "# Convert weights"
1959
- ]
1960
- },
1961
- {
1962
- "cell_type": "code",
1963
- "execution_count": 3,
1964
- "metadata": {},
1965
- "outputs": [],
1966
- "source": [
1967
- "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
1968
- "from models import ofa_base_architecture, OFAModel\n",
1969
- "from transformers.tokenization_utils_base import BatchEncoding"
1970
- ]
1971
- },
1972
- {
1973
- "cell_type": "markdown",
1974
- "metadata": {},
1975
- "source": [
1976
- "### Explore ofa"
1977
- ]
1978
- },
1979
- {
1980
- "cell_type": "code",
1981
- "execution_count": 4,
1982
- "metadata": {},
1983
- "outputs": [
1984
- {
1985
- "name": "stderr",
1986
- "output_type": "stream",
1987
- "text": [
1988
- "2022-11-15 08:52:08 | INFO | tasks.ofa_task | source dictionary: 59457 types\n",
1989
- "2022-11-15 08:52:08 | INFO | tasks.ofa_task | target dictionary: 59457 types\n"
1990
- ]
1991
- }
1992
- ],
1993
- "source": [
1994
- "import torch\n",
1995
- "import numpy as np\n",
1996
- "from fairseq import utils, tasks\n",
1997
- "from fairseq import checkpoint_utils\n",
1998
- "from utils.eval_utils import eval_step\n",
1999
- "from tasks.mm_tasks.caption import CaptionTask\n",
2000
- "from models.ofa import OFAModel\n",
2001
- "from PIL import Image\n",
2002
- "\n",
2003
- "# Register refcoco task\n",
2004
- "tasks.register_task('caption', CaptionTask)\n",
2005
- "\n",
2006
- "# turn on cuda if GPU is available\n",
2007
- "use_cuda = torch.cuda.is_available()\n",
2008
- "# use fp16 only when GPU is available\n",
2009
- "use_fp16 = False\n",
2010
- "\n",
2011
- "# Load pretrained ckpt & config\n",
2012
- "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":16, \"no_repeat_ngram_size\":3, \"seed\":7}\n",
2013
- "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
2014
- " utils.split_paths('/data/mshukor/logs/ofa/checkpoints/caption/ofa_caption_stage_1/5_0.06_6000/checkpoint_best.pt'),\n",
2015
- " arg_overrides=overrides\n",
2016
- " )\n",
2017
- "\n",
2018
- "# Move models to GPU\n",
2019
- "for model in models:\n",
2020
- " model.eval()\n",
2021
- " if use_fp16:\n",
2022
- " model.half()\n",
2023
- " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
2024
- " model.cuda()\n",
2025
- " model.prepare_for_inference_(cfg)\n",
2026
- "\n",
2027
- "# Initialize generator\n",
2028
- "generator = task.build_generator(models, cfg.generation)"
2029
- ]
2030
- },
2031
- {
2032
- "cell_type": "code",
2033
- "execution_count": 5,
2034
- "metadata": {},
2035
- "outputs": [],
2036
- "source": [
2037
- "model_ofa = models[0]\n",
2038
- "ofa_state = model_ofa.state_dict()"
2039
- ]
2040
- },
2041
- {
2042
- "cell_type": "code",
2043
- "execution_count": null,
2044
- "metadata": {},
2045
- "outputs": [],
2046
- "source": [
2047
- "def get_state_given_key(state, key, excluded_keys=None):\n",
2048
- " new_state = {}\n",
2049
- " for k, v in state.items():\n",
2050
- " if key in k:\n",
2051
- " if excluded_keys is not None:\n",
2052
- " if not any([ek in k for ek in excluded_keys]):\n",
2053
- " new_state[k] = v\n",
2054
- " else:\n",
2055
- " new_state[k] = v\n",
2056
- " return new_state\n",
2057
- "\n",
2058
- "key = 'encoder.layers.0'\n",
2059
- "excluded_keys = ['embed', 'image']\n",
2060
- "ofa_tmp = get_state_given_key(ofa_state, key, excluded_keys=excluded_keys)"
2061
- ]
2062
- },
2063
- {
2064
- "cell_type": "code",
2065
- "execution_count": null,
2066
- "metadata": {},
2067
- "outputs": [],
2068
- "source": [
2069
- "# def get_ofa_args_large(args):\n",
2070
- "# args['encoder_embed_path'] = getattr(args, \"encoder_embed_path\", None)\n",
2071
- "# args['encoder_embed_dim'] = getattr(args, \"encoder_embed_dim\", 1024)\n",
2072
- "# args['encoder_ffn_embed_dim'] = getattr(args, \"encoder_ffn_embed_dim\", 4 * 1024)\n",
2073
- "# args['encoder_layers'] = getattr(args, \"encoder_layers\", 12)\n",
2074
- "# args['encoder_attention_heads'] = getattr(args, \"encoder_attention_heads\", 16)\n",
2075
- "# args['encoder_normalize_before'] = getattr(args, \"encoder_normalize_before\", True)\n",
2076
- "# args['encoder_learned_pos'] = getattr(args, \"encoder_learned_pos\", True)\n",
2077
- "# args['decoder_embed_path'] = getattr(args, \"decoder_embed_path\", None)\n",
2078
- "# args['decoder_embed_dim'] = getattr(args, \"decoder_embed_dim\", args['encoder_embed_dim'])\n",
2079
- "# args['decoder_ffn_embed_dim'] = getattr(\n",
2080
- "# args, \"decoder_ffn_embed_dim\", args['encoder_ffn_embed_dim']\n",
2081
- "# )\n",
2082
- "# args['decoder_layers'] = getattr(args, \"decoder_layers\", 12)\n",
2083
- "# args['decoder_attention_heads'] = getattr(args, \"decoder_attention_heads\", 16)\n",
2084
- "# args['decoder_normalize_before'] = getattr(args, \"decoder_normalize_before\", True)\n",
2085
- "# args['decoder_learned_pos'] = getattr(args, \"decoder_learned_pos\", True)\n",
2086
- "# args['attention_dropout'] = getattr(args, \"attention_dropout\", 0.0)\n",
2087
- "# args['relu_dropout'] = getattr(args, \"relu_dropout\", 0.0)\n",
2088
- "# args['dropout'] = getattr(args, \"dropout\", 0.0)\n",
2089
- "# args['max_target_positions'] = getattr(args, \"max_target_positions\", 1024)\n",
2090
- "# args['max_source_positions'] = getattr(args, \"max_source_positions\", 1024)\n",
2091
- "# args['adaptive_softmax_cutoff'] = getattr(args, \"adaptive_softmax_cutoff\", None)\n",
2092
- "# args['adaptive_softmax_dropout'] = getattr(args, \"adaptive_softmax_dropout\", 0)\n",
2093
- "# args['share_decoder_input_output_embed'] = getattr(\n",
2094
- "# args, \"share_decoder_input_output_embed\", True\n",
2095
- "# )\n",
2096
- "# args['share_all_embeddings'] = getattr(args, \"share_all_embeddings\", True)\n",
2097
- "\n",
2098
- "# args['decoder_output_dim'] = getattr(\n",
2099
- "# args, \"decoder_output_dim\", args['decoder_embed_dim']\n",
2100
- "# )\n",
2101
- "# args['decoder_input_dim'] = getattr(args, \"decoder_input_dim\", args['decoder_embed_dim'])\n",
2102
- "\n",
2103
- "# args['no_scale_embedding'] = getattr(args, \"no_scale_embedding\", True)\n",
2104
- "# args['layernorm_embedding'] = getattr(args, \"layernorm_embedding\", True)\n",
2105
- "\n",
2106
- "# args['activation_fn'] = getattr(args, \"activation_fn\", \"gelu\")\n",
2107
- "# args['pooler_activation_fn'] = getattr(args, \"pooler_activation_fn\", \"tanh\")\n",
2108
- "# args['pooler_dropout'] = getattr(args, \"pooler_dropout\", 0.0)\n",
2109
- "# args['pooler_classifier'] = getattr(args, \"pooler_classifier\", \"mlp\")\n",
2110
- "\n",
2111
- "# args['resnet_drop_path_rate'] = getattr(args, \"resnet_drop_path_rate\", 0.0)\n",
2112
- "# args['encoder_drop_path_rate'] = getattr(args, \"encoder_drop_path_rate\", 0.0)\n",
2113
- "# args['decoder_drop_path_rate'] = getattr(args, \"decoder_drop_path_rate\", 0.0)\n",
2114
- "\n",
2115
- "# args['resnet_type'] = getattr(args, \"resnet_type\", \"resnet152\")\n",
2116
- "# args['token_bucket_size'] = getattr(args, \"token_bucket_size\", 256)\n",
2117
- "# args['image_bucket_size'] = getattr(args, \"image_bucket_size\", 42)\n",
2118
- "\n",
2119
- "# args['freeze_encoder_embedding'] = getattr(args, \"freeze_encoder_embedding\", False)\n",
2120
- "# args['freeze_decoder_embedding'] = getattr(args, \"freeze_decoder_embedding\", False)\n",
2121
- "# args['add_type_embedding'] = getattr(args, \"add_type_embedding\", True)\n",
2122
- "# args['attn_scale_factor'] = getattr(args, \"attn_scale_factor\", 2)\n",
2123
- "\n",
2124
- "# args['code_image_size'] = getattr(args, \"code_image_size\", 128)\n",
2125
- "# args['patch_layernorm_embedding'] = getattr(args, \"patch_layernorm_embedding\", True)\n",
2126
- "# args['code_layernorm_embedding'] = getattr(args, \"code_layernorm_embedding\", True)\n",
2127
- "# args['entangle_position_embedding'] = getattr(args, \"entangle_position_embedding\", False)\n",
2128
- "# args['disable_entangle'] = getattr(args, \"disable_entangle\", False)\n",
2129
- "# args['sync_bn'] = getattr(args, \"sync_bn\", False)\n",
2130
- "\n",
2131
- "# args['scale_attn'] = getattr(args, \"scale_attn\", False)\n",
2132
- "# args['scale_fc'] = getattr(args, \"scale_fc\", False)\n",
2133
- "# args['scale_heads'] = getattr(args, \"scale_heads\", False)\n",
2134
- "# args['scale_resids'] = getattr(args, \"scale_resids\", False)\n",
2135
- "\n",
2136
- "# args['orig_patch_image_size'] = getattr(args, \"orig_patch_image_size\", 256)\n",
2137
- "\n",
2138
- "# return args"
2139
- ]
2140
- },
2141
- {
2142
- "cell_type": "code",
2143
- "execution_count": null,
2144
- "metadata": {},
2145
- "outputs": [],
2146
- "source": [
2147
- "# args = {}\n",
2148
- "# args = get_ofa_args_large(args)\n",
2149
- "# args = BatchEncoding(args)\n",
2150
- "# ofa_base_architecture(args)\n",
2151
- "# data_dir = '/data/mshukor/data/ofa/pretrain_example'\n",
2152
- "\n",
2153
- "# cfg.task.neg_sample_dir = data_dir+'/negative_sample'"
2154
- ]
2155
- },
2156
- {
2157
- "cell_type": "markdown",
2158
- "metadata": {},
2159
- "source": [
2160
- "### convert t5 weights"
2161
- ]
2162
- },
2163
- {
2164
- "cell_type": "code",
2165
- "execution_count": 6,
2166
- "metadata": {},
2167
- "outputs": [],
2168
- "source": [
2169
- "model_t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")"
2170
- ]
2171
- },
2172
- {
2173
- "cell_type": "code",
2174
- "execution_count": null,
2175
- "metadata": {},
2176
- "outputs": [],
2177
- "source": [
2178
- "model_t5"
2179
- ]
2180
- },
2181
- {
2182
- "cell_type": "code",
2183
- "execution_count": 7,
2184
- "metadata": {},
2185
- "outputs": [],
2186
- "source": [
2187
- "t5_state = model_t5.state_dict()"
2188
- ]
2189
- },
2190
- {
2191
- "cell_type": "code",
2192
- "execution_count": 56,
2193
- "metadata": {},
2194
- "outputs": [],
2195
- "source": [
2196
- "import re\n",
2197
- "# line = re.sub(r\"</?\\[\\d+>\", \"\", line)\n",
2198
- "\n",
2199
- "mapping_dict = {\n",
2200
- " ## encoder\n",
2201
- " 'block': 'layers', \n",
2202
- " 'layer.[0-9]+.SelfAttention': 'self_attn', \n",
2203
- " '.q.': '.q_proj.', \n",
2204
- " '.k.weight': '.k_proj.weight', \n",
2205
- " '.v.': '.v_proj.', \n",
2206
- " # '.o.weight': '.out_proj.weight', \n",
2207
- " 'layer.0.layer_norm.': 'self_attn_layer_norm.', \n",
2208
- " 'layer.[0-9]+.DenseReluDense.': '', \n",
2209
- " '.wi.': '.fc1.', \n",
2210
- " '.wo.': '.fc2.', \n",
2211
- " \n",
2212
- " \n",
2213
- " # decoder\n",
2214
- " 'layer.[0-9]+.EncDecAttention': 'encoder_attn', \n",
2215
- " # 'layer.1.layer_norm.': 'encoder_attn_layer_norm.', \n",
2216
- " \n",
2217
- " \n",
2218
- "}\n",
2219
- "\n",
2220
- "encoder_mapping = {\n",
2221
- " 'layer.1.layer_norm.': 'final_layer_norm.', \n",
2222
- "}\n",
2223
- "\n",
2224
- "decoder_mapping = {\n",
2225
- " 'layer.1.layer_norm.': 'encoder_attn_layer_norm.', \n",
2226
- " 'layer.2.layer_norm.': 'final_layer_norm.', \n",
2227
- "}\n",
2228
- "\n",
2229
- "\n",
2230
- "simple_replace_mapping = {\n",
2231
- " \n",
2232
- " '.o.weight': '.out_proj.weight', \n",
2233
- "}\n",
2234
- "def modify_state(state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping):\n",
2235
- " # orig_keys = ['block', 'layer.[0-9]+.SelfAttention', '.q.', '.k.', '.v.', '.o.', '0.layer_norm.', '.DenseReluDense.wi.', '.DenseReluDense.wo.', '.1.layer_norm.']\n",
2236
- " # new_keys = ['layers', 'layer.self_attn', '.q_proj.', '.k_proj.', '.v_proj.', '.out_proj.', '.self_attn_layer_norm.', '.fc1.', '.fc2.', '.final_layer_norm.']\n",
2237
- " \n",
2238
- " new_state = state.copy()\n",
2239
- " old_keys = []\n",
2240
- " for k, v in state.items():\n",
2241
- " \n",
2242
- " new_key = '%s' % k \n",
2243
- " for old, new in simple_replace_mapping.items():\n",
2244
- " new_key = new_key.replace(old, new)\n",
2245
- " \n",
2246
- " for old, new in mapping_dict.items():\n",
2247
- " new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
2248
- " \n",
2249
- " if 'encoder' in new_key:\n",
2250
- " for old, new in encoder_mapping.items():\n",
2251
- " new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
2252
- " \n",
2253
- " if 'decoder' in new_key:\n",
2254
- " for old, new in decoder_mapping.items():\n",
2255
- " new_key = re.sub(r\"{}\".format(old), new, new_key)\n",
2256
- " \n",
2257
- " new_state[new_key] = v\n",
2258
- " old_keys.append(k)\n",
2259
- " \n",
2260
- " \n",
2261
- " \n",
2262
- " \n",
2263
- " for k in old_keys:\n",
2264
- " del new_state[k]\n",
2265
- " \n",
2266
- " final_state = {}\n",
2267
- " final_state['model'] = new_state\n",
2268
- " return final_state\n",
2269
- " \n",
2270
- "new_state = modify_state(t5_state, mapping_dict, encoder_mapping, decoder_mapping, simple_replace_mapping)\n",
2271
- "\n"
2272
- ]
2273
- },
2274
- {
2275
- "cell_type": "code",
2276
- "execution_count": null,
2277
- "metadata": {},
2278
- "outputs": [],
2279
- "source": [
2280
- "new_state['model'].keys()"
2281
- ]
2282
- },
2283
- {
2284
- "cell_type": "code",
2285
- "execution_count": null,
2286
- "metadata": {},
2287
- "outputs": [],
2288
- "source": [
2289
- "def compare_states(state1, state2):\n",
2290
- " different = []\n",
2291
- " for k1 in state1.keys():\n",
2292
- " if k1 not in state2:\n",
2293
- " different.append(k1)\n",
2294
- " return different\n",
2295
- " \n",
2296
- "tmp = compare_states(new_state, ofa_state)"
2297
- ]
2298
- },
2299
- {
2300
- "cell_type": "code",
2301
- "execution_count": 35,
2302
- "metadata": {},
2303
- "outputs": [],
2304
- "source": [
2305
- "output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
2306
- "torch.save(new_state, output_path)"
2307
- ]
2308
- },
2309
- {
2310
- "cell_type": "code",
2311
- "execution_count": 51,
2312
- "metadata": {},
2313
- "outputs": [],
2314
- "source": [
2315
- "output_path = '/data/mshukor/logs/ofa/pretrained_models/t5_base.pt'\n",
2316
- "\n",
2317
- "tmp_state = torch.load(output_path)"
2318
- ]
2319
- },
2320
- {
2321
- "cell_type": "code",
2322
- "execution_count": null,
2323
- "metadata": {},
2324
- "outputs": [],
2325
- "source": [
2326
- "\n",
2327
- "model_ofa.load_state_dict(tmp_state['model'], strict=False)"
2328
- ]
2329
- },
2330
- {
2331
- "cell_type": "code",
2332
- "execution_count": null,
2333
- "metadata": {},
2334
- "outputs": [],
2335
- "source": [
2336
- "tmp_state['model'].keys()"
2337
- ]
2338
- },
2339
- {
2340
- "cell_type": "code",
2341
- "execution_count": 18,
2342
- "metadata": {},
2343
- "outputs": [],
2344
- "source": [
2345
- "tmp_state = torch.load('/data/mshukor/logs/ofa/pretrained_models/ofa_base.pt')\n"
2346
- ]
2347
- },
2348
- {
2349
- "cell_type": "code",
2350
- "execution_count": 19,
2351
- "metadata": {},
2352
- "outputs": [
2353
- {
2354
- "data": {
2355
- "text/plain": [
2356
- "dict_keys(['args', 'cfg', 'model', 'criterion', 'optimizer_history', 'task_state', 'extra_state', 'last_optimizer_state'])"
2357
- ]
2358
- },
2359
- "execution_count": 19,
2360
- "metadata": {},
2361
- "output_type": "execute_result"
2362
- }
2363
- ],
2364
- "source": [
2365
- "tmp_state.keys()"
2366
- ]
2367
- },
2368
- {
2369
- "cell_type": "code",
2370
- "execution_count": null,
2371
- "metadata": {},
2372
- "outputs": [],
2373
- "source": []
2374
- },
2375
- {
2376
- "cell_type": "code",
2377
- "execution_count": null,
2378
- "metadata": {},
2379
- "outputs": [],
2380
- "source": [
2381
- "model_t5.encoder.block[0]"
2382
- ]
2383
- },
2384
- {
2385
- "cell_type": "code",
2386
- "execution_count": null,
2387
- "metadata": {},
2388
- "outputs": [],
2389
- "source": [
2390
- "model_ofa.encoder.layers[0]"
2391
- ]
2392
- },
2393
- {
2394
- "cell_type": "markdown",
2395
- "metadata": {},
2396
- "source": [
2397
- "### convert BART weights"
2398
- ]
2399
- },
2400
- {
2401
- "cell_type": "code",
2402
- "execution_count": 7,
2403
- "metadata": {},
2404
- "outputs": [],
2405
- "source": [
2406
- "weights_path = '/data/mshukor/logs/ofa/pretrained_models/bart.base/model.pt'\n",
2407
- "bart_state = torch.load(weights_path, map_location=torch.device('cpu'))"
2408
- ]
2409
- },
2410
- {
2411
- "cell_type": "code",
2412
- "execution_count": 13,
2413
- "metadata": {},
2414
- "outputs": [
2415
- {
2416
- "data": {
2417
- "text/plain": [
2418
- "<All keys matched successfully>"
2419
- ]
2420
- },
2421
- "execution_count": 13,
2422
- "metadata": {},
2423
- "output_type": "execute_result"
2424
- }
2425
- ],
2426
- "source": [
2427
- "model_ofa.load_state_dict(bart_state['model'], strict=True)"
2428
- ]
2429
- },
2430
- {
2431
- "cell_type": "code",
2432
- "execution_count": 9,
2433
- "metadata": {},
2434
- "outputs": [
2435
- {
2436
- "data": {
2437
- "text/plain": [
2438
- "odict_keys(['encoder.version', 'encoder.embed_tokens.weight', 'encoder.embed_positions.weight', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.v_proj.bias', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.q_proj.bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0.self_attn.out_proj.bias', 'encoder.layers.0.self_attn_layer_norm.weight', 'encoder.layers.0.self_attn_layer_norm.bias', 'encoder.layers.0.fc1.weight', 'encoder.layers.0.fc1.bias', 'encoder.layers.0.fc2.weight', 'encoder.layers.0.fc2.bias', 'encoder.layers.0.final_layer_norm.weight', 'encoder.layers.0.final_layer_norm.bias', 'encoder.layers.1.self_attn.k_proj.weight', 'encoder.layers.1.self_attn.k_proj.bias', 'encoder.layers.1.self_attn.v_proj.weight', 'encoder.layers.1.self_attn.v_proj.bias', 'encoder.layers.1.self_attn.q_proj.weight', 'encoder.layers.1.self_attn.q_proj.bias', 'encoder.layers.1.self_attn.out_proj.weight', 'encoder.layers.1.self_attn.out_proj.bias', 'encoder.layers.1.self_attn_layer_norm.weight', 'encoder.layers.1.self_attn_layer_norm.bias', 'encoder.layers.1.fc1.weight', 'encoder.layers.1.fc1.bias', 'encoder.layers.1.fc2.weight', 'encoder.layers.1.fc2.bias', 'encoder.layers.1.final_layer_norm.weight', 'encoder.layers.1.final_layer_norm.bias', 'encoder.layers.2.self_attn.k_proj.weight', 'encoder.layers.2.self_attn.k_proj.bias', 'encoder.layers.2.self_attn.v_proj.weight', 'encoder.layers.2.self_attn.v_proj.bias', 'encoder.layers.2.self_attn.q_proj.weight', 'encoder.layers.2.self_attn.q_proj.bias', 'encoder.layers.2.self_attn.out_proj.weight', 'encoder.layers.2.self_attn.out_proj.bias', 'encoder.layers.2.self_attn_layer_norm.weight', 'encoder.layers.2.self_attn_layer_norm.bias', 'encoder.layers.2.fc1.weight', 'encoder.layers.2.fc1.bias', 'encoder.layers.2.fc2.weight', 'encoder.layers.2.fc2.bias', 'encoder.layers.2.final_layer_norm.weight', 'encoder.layers.2.final_layer_norm.bias', 'encoder.layers.3.self_attn.k_proj.weight', 'encoder.layers.3.self_attn.k_proj.bias', 'encoder.layers.3.self_attn.v_proj.weight', 'encoder.layers.3.self_attn.v_proj.bias', 'encoder.layers.3.self_attn.q_proj.weight', 'encoder.layers.3.self_attn.q_proj.bias', 'encoder.layers.3.self_attn.out_proj.weight', 'encoder.layers.3.self_attn.out_proj.bias', 'encoder.layers.3.self_attn_layer_norm.weight', 'encoder.layers.3.self_attn_layer_norm.bias', 'encoder.layers.3.fc1.weight', 'encoder.layers.3.fc1.bias', 'encoder.layers.3.fc2.weight', 'encoder.layers.3.fc2.bias', 'encoder.layers.3.final_layer_norm.weight', 'encoder.layers.3.final_layer_norm.bias', 'encoder.layers.4.self_attn.k_proj.weight', 'encoder.layers.4.self_attn.k_proj.bias', 'encoder.layers.4.self_attn.v_proj.weight', 'encoder.layers.4.self_attn.v_proj.bias', 'encoder.layers.4.self_attn.q_proj.weight', 'encoder.layers.4.self_attn.q_proj.bias', 'encoder.layers.4.self_attn.out_proj.weight', 'encoder.layers.4.self_attn.out_proj.bias', 'encoder.layers.4.self_attn_layer_norm.weight', 'encoder.layers.4.self_attn_layer_norm.bias', 'encoder.layers.4.fc1.weight', 'encoder.layers.4.fc1.bias', 'encoder.layers.4.fc2.weight', 'encoder.layers.4.fc2.bias', 'encoder.layers.4.final_layer_norm.weight', 'encoder.layers.4.final_layer_norm.bias', 'encoder.layers.5.self_attn.k_proj.weight', 'encoder.layers.5.self_attn.k_proj.bias', 'encoder.layers.5.self_attn.v_proj.weight', 'encoder.layers.5.self_attn.v_proj.bias', 'encoder.layers.5.self_attn.q_proj.weight', 'encoder.layers.5.self_attn.q_proj.bias', 'encoder.layers.5.self_attn.out_proj.weight', 'encoder.layers.5.self_attn.out_proj.bias', 'encoder.layers.5.self_attn_layer_norm.weight', 'encoder.layers.5.self_attn_layer_norm.bias', 'encoder.layers.5.fc1.weight', 'encoder.layers.5.fc1.bias', 'encoder.layers.5.fc2.weight', 'encoder.layers.5.fc2.bias', 'encoder.layers.5.final_layer_norm.weight', 'encoder.layers.5.final_layer_norm.bias', 'encoder.layernorm_embedding.weight', 'encoder.layernorm_embedding.bias', 'decoder.version', 'decoder.embed_tokens.weight', 'decoder.embed_positions.weight', 'decoder.layers.0.self_attn.k_proj.weight', 'decoder.layers.0.self_attn.k_proj.bias', 'decoder.layers.0.self_attn.v_proj.weight', 'decoder.layers.0.self_attn.v_proj.bias', 'decoder.layers.0.self_attn.q_proj.weight', 'decoder.layers.0.self_attn.q_proj.bias', 'decoder.layers.0.self_attn.out_proj.weight', 'decoder.layers.0.self_attn.out_proj.bias', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.0.self_attn_layer_norm.bias', 'decoder.layers.0.encoder_attn.k_proj.weight', 'decoder.layers.0.encoder_attn.k_proj.bias', 'decoder.layers.0.encoder_attn.v_proj.weight', 'decoder.layers.0.encoder_attn.v_proj.bias', 'decoder.layers.0.encoder_attn.q_proj.weight', 'decoder.layers.0.encoder_attn.q_proj.bias', 'decoder.layers.0.encoder_attn.out_proj.weight', 'decoder.layers.0.encoder_attn.out_proj.bias', 'decoder.layers.0.encoder_attn_layer_norm.weight', 'decoder.layers.0.encoder_attn_layer_norm.bias', 'decoder.layers.0.fc1.weight', 'decoder.layers.0.fc1.bias', 'decoder.layers.0.fc2.weight', 'decoder.layers.0.fc2.bias', 'decoder.layers.0.final_layer_norm.weight', 'decoder.layers.0.final_layer_norm.bias', 'decoder.layers.1.self_attn.k_proj.weight', 'decoder.layers.1.self_attn.k_proj.bias', 'decoder.layers.1.self_attn.v_proj.weight', 'decoder.layers.1.self_attn.v_proj.bias', 'decoder.layers.1.self_attn.q_proj.weight', 'decoder.layers.1.self_attn.q_proj.bias', 'decoder.layers.1.self_attn.out_proj.weight', 'decoder.layers.1.self_attn.out_proj.bias', 'decoder.layers.1.self_attn_layer_norm.weight', 'decoder.layers.1.self_attn_layer_norm.bias', 'decoder.layers.1.encoder_attn.k_proj.weight', 'decoder.layers.1.encoder_attn.k_proj.bias', 'decoder.layers.1.encoder_attn.v_proj.weight', 'decoder.layers.1.encoder_attn.v_proj.bias', 'decoder.layers.1.encoder_attn.q_proj.weight', 'decoder.layers.1.encoder_attn.q_proj.bias', 'decoder.layers.1.encoder_attn.out_proj.weight', 'decoder.layers.1.encoder_attn.out_proj.bias', 'decoder.layers.1.encoder_attn_layer_norm.weight', 'decoder.layers.1.encoder_attn_layer_norm.bias', 'decoder.layers.1.fc1.weight', 'decoder.layers.1.fc1.bias', 'decoder.layers.1.fc2.weight', 'decoder.layers.1.fc2.bias', 'decoder.layers.1.final_layer_norm.weight', 'decoder.layers.1.final_layer_norm.bias', 'decoder.layers.2.self_attn.k_proj.weight', 'decoder.layers.2.self_attn.k_proj.bias', 'decoder.layers.2.self_attn.v_proj.weight', 'decoder.layers.2.self_attn.v_proj.bias', 'decoder.layers.2.self_attn.q_proj.weight', 'decoder.layers.2.self_attn.q_proj.bias', 'decoder.layers.2.self_attn.out_proj.weight', 'decoder.layers.2.self_attn.out_proj.bias', 'decoder.layers.2.self_attn_layer_norm.weight', 'decoder.layers.2.self_attn_layer_norm.bias', 'decoder.layers.2.encoder_attn.k_proj.weight', 'decoder.layers.2.encoder_attn.k_proj.bias', 'decoder.layers.2.encoder_attn.v_proj.weight', 'decoder.layers.2.encoder_attn.v_proj.bias', 'decoder.layers.2.encoder_attn.q_proj.weight', 'decoder.layers.2.encoder_attn.q_proj.bias', 'decoder.layers.2.encoder_attn.out_proj.weight', 'decoder.layers.2.encoder_attn.out_proj.bias', 'decoder.layers.2.encoder_attn_layer_norm.weight', 'decoder.layers.2.encoder_attn_layer_norm.bias', 'decoder.layers.2.fc1.weight', 'decoder.layers.2.fc1.bias', 'decoder.layers.2.fc2.weight', 'decoder.layers.2.fc2.bias', 'decoder.layers.2.final_layer_norm.weight', 'decoder.layers.2.final_layer_norm.bias', 'decoder.layers.3.self_attn.k_proj.weight', 'decoder.layers.3.self_attn.k_proj.bias', 'decoder.layers.3.self_attn.v_proj.weight', 'decoder.layers.3.self_attn.v_proj.bias', 'decoder.layers.3.self_attn.q_proj.weight', 'decoder.layers.3.self_attn.q_proj.bias', 'decoder.layers.3.self_attn.out_proj.weight', 'decoder.layers.3.self_attn.out_proj.bias', 'decoder.layers.3.self_attn_layer_norm.weight', 'decoder.layers.3.self_attn_layer_norm.bias', 'decoder.layers.3.encoder_attn.k_proj.weight', 'decoder.layers.3.encoder_attn.k_proj.bias', 'decoder.layers.3.encoder_attn.v_proj.weight', 'decoder.layers.3.encoder_attn.v_proj.bias', 'decoder.layers.3.encoder_attn.q_proj.weight', 'decoder.layers.3.encoder_attn.q_proj.bias', 'decoder.layers.3.encoder_attn.out_proj.weight', 'decoder.layers.3.encoder_attn.out_proj.bias', 'decoder.layers.3.encoder_attn_layer_norm.weight', 'decoder.layers.3.encoder_attn_layer_norm.bias', 'decoder.layers.3.fc1.weight', 'decoder.layers.3.fc1.bias', 'decoder.layers.3.fc2.weight', 'decoder.layers.3.fc2.bias', 'decoder.layers.3.final_layer_norm.weight', 'decoder.layers.3.final_layer_norm.bias', 'decoder.layers.4.self_attn.k_proj.weight', 'decoder.layers.4.self_attn.k_proj.bias', 'decoder.layers.4.self_attn.v_proj.weight', 'decoder.layers.4.self_attn.v_proj.bias', 'decoder.layers.4.self_attn.q_proj.weight', 'decoder.layers.4.self_attn.q_proj.bias', 'decoder.layers.4.self_attn.out_proj.weight', 'decoder.layers.4.self_attn.out_proj.bias', 'decoder.layers.4.self_attn_layer_norm.weight', 'decoder.layers.4.self_attn_layer_norm.bias', 'decoder.layers.4.encoder_attn.k_proj.weight', 'decoder.layers.4.encoder_attn.k_proj.bias', 'decoder.layers.4.encoder_attn.v_proj.weight', 'decoder.layers.4.encoder_attn.v_proj.bias', 'decoder.layers.4.encoder_attn.q_proj.weight', 'decoder.layers.4.encoder_attn.q_proj.bias', 'decoder.layers.4.encoder_attn.out_proj.weight', 'decoder.layers.4.encoder_attn.out_proj.bias', 'decoder.layers.4.encoder_attn_layer_norm.weight', 'decoder.layers.4.encoder_attn_layer_norm.bias', 'decoder.layers.4.fc1.weight', 'decoder.layers.4.fc1.bias', 'decoder.layers.4.fc2.weight', 'decoder.layers.4.fc2.bias', 'decoder.layers.4.final_layer_norm.weight', 'decoder.layers.4.final_layer_norm.bias', 'decoder.layers.5.self_attn.k_proj.weight', 'decoder.layers.5.self_attn.k_proj.bias', 'decoder.layers.5.self_attn.v_proj.weight', 'decoder.layers.5.self_attn.v_proj.bias', 'decoder.layers.5.self_attn.q_proj.weight', 'decoder.layers.5.self_attn.q_proj.bias', 'decoder.layers.5.self_attn.out_proj.weight', 'decoder.layers.5.self_attn.out_proj.bias', 'decoder.layers.5.self_attn_layer_norm.weight', 'decoder.layers.5.self_attn_layer_norm.bias', 'decoder.layers.5.encoder_attn.k_proj.weight', 'decoder.layers.5.encoder_attn.k_proj.bias', 'decoder.layers.5.encoder_attn.v_proj.weight', 'decoder.layers.5.encoder_attn.v_proj.bias', 'decoder.layers.5.encoder_attn.q_proj.weight', 'decoder.layers.5.encoder_attn.q_proj.bias', 'decoder.layers.5.encoder_attn.out_proj.weight', 'decoder.layers.5.encoder_attn.out_proj.bias', 'decoder.layers.5.encoder_attn_layer_norm.weight', 'decoder.layers.5.encoder_attn_layer_norm.bias', 'decoder.layers.5.fc1.weight', 'decoder.layers.5.fc1.bias', 'decoder.layers.5.fc2.weight', 'decoder.layers.5.fc2.bias', 'decoder.layers.5.final_layer_norm.weight', 'decoder.layers.5.final_layer_norm.bias', 'decoder.layernorm_embedding.weight', 'decoder.layernorm_embedding.bias'])"
2439
- ]
2440
- },
2441
- "execution_count": 9,
2442
- "metadata": {},
2443
- "output_type": "execute_result"
2444
- }
2445
- ],
2446
- "source": [
2447
- "bart_state['model'].keys()"
2448
- ]
2449
- },
2450
- {
2451
- "cell_type": "code",
2452
- "execution_count": 12,
2453
- "metadata": {},
2454
- "outputs": [
2455
- {
2456
- "data": {
2457
- "text/plain": [
2458
- "tensor([[ 0.0125, 0.0014, -0.0096, ..., 0.0022, 0.1057, 0.0103],\n",
2459
- " [-0.0114, -0.0169, -0.0184, ..., -0.0131, -0.0043, -0.0053],\n",
2460
- " [ 0.0842, -0.0389, 0.0096, ..., 0.0583, 0.0082, 0.0357],\n",
2461
- " ...,\n",
2462
- " [-0.0032, -0.0313, -0.1026, ..., 0.0138, 0.0056, -0.0023],\n",
2463
- " [ 0.0104, -0.0045, 0.0263, ..., 0.0158, 0.0324, -0.0111],\n",
2464
- " [-0.0038, -0.0532, -0.0147, ..., 0.0067, 0.0256, 0.0009]])"
2465
- ]
2466
- },
2467
- "execution_count": 12,
2468
- "metadata": {},
2469
- "output_type": "execute_result"
2470
- }
2471
- ],
2472
- "source": [
2473
- "ofa_state.keys()\n",
2474
- "ofa_state['encoder.embed_tokens.weight']"
2475
- ]
2476
- }
2477
- ],
2478
- "metadata": {
2479
- "kernelspec": {
2480
- "display_name": "ofa",
2481
- "language": "python",
2482
- "name": "ofa"
2483
- },
2484
- "language_info": {
2485
- "codemirror_mode": {
2486
- "name": "ipython",
2487
- "version": 3
2488
- },
2489
- "file_extension": ".py",
2490
- "mimetype": "text/x-python",
2491
- "name": "python",
2492
- "nbconvert_exporter": "python",
2493
- "pygments_lexer": "ipython3",
2494
- "version": "3.7.4"
2495
- }
2496
- },
2497
- "nbformat": 4,
2498
- "nbformat_minor": 4
2499
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prompt_tuning.md DELETED
@@ -1,66 +0,0 @@
1
- <!---
2
- Copyright 2022 The OFA-Sys Team.
3
- All rights reserved.
4
- This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory.
5
- -->
6
-
7
- ## Prompt Tuning for Generative Multimodal Pretrained Models
8
-
9
- ### Overview
10
- This is the code for **"Prompt Tuning for Generative Multimodal Pretrained Models"**, [Check our paper on ArXiv](https://arxiv.org/abs/2208.02532). This paper explores prompt tuning for generative multimodal pretrained models, instead of the constrastive learning models. We specifically focuses on the unified sequence-to-sequence learning framework and implement on our OFA models.
11
- <br>
12
-
13
- ### Requirements
14
- * python 3.7.4
15
- * pytorch 1.8.1
16
- * torchvision 0.9.1
17
- * JAVA 1.8 (for COCO evaluation)
18
- <br></br>
19
-
20
- ### Installation
21
- ```bash
22
- pip install -r requirements.txt
23
- ```
24
- <br>
25
-
26
- ### Datasets and Checkpoints
27
- See [datasets.md](datasets.md) and [checkpoints.md](checkpoints.md).
28
- <br>
29
-
30
- ### Training
31
- We provide a demo script (`run_scripts/refcoco/train_refcoco_prefix.sh`) that has all the required parts for training.
32
-
33
- ```sh
34
- sh ./run_scripts/refcoco/train_refcoco_prefix.sh
35
- ```
36
- A few options of note:
37
- * `--encoder-prompt` :: whether to insert prompts to the encoder
38
- * `--decoder-prompt` :: whether to insert prompts to the decoder
39
- * `--encoder-prompt-length` :: encoder prompt length
40
- * `--decoder-prompt-length` :: decoder prompt length
41
- * `--bitfit` :: whether to use bitfit
42
- * `--adapter` :: whether to use adapter
43
- * `--adapter-dim` :: adapter projection dim
44
-
45
- We recommend that your workspace directory should be organized like this:
46
- ```
47
- OFA/
48
- ├── checkpoints/
49
- │   ├── ofa_base.pt
50
- │   ├── ofa_large.pt
51
- │   └── ...
52
- ├── criterions/
53
- ├── data/
54
- ├── dataset/
55
- │   ├── caption_data/
56
- │   ├── refcoco_data/
57
- │   └── ...
58
- ├── fairseq/
59
- ├── models/
60
- ├── run_scripts/
61
- ├── tasks/
62
- ├── train.py
63
- ├── trainer.py
64
- └── utils/
65
- ```
66
- <br>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces.md DELETED
@@ -1,8 +0,0 @@
1
- # Spaces
2
- To provide better experience, we plan to build demos for our OFA models on Huggingface Spaces. Below we provide links to the demos. Have fun!
3
-
4
- * Generic Interface: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Generic_Interface)
5
- * Text-to-Image Generation: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Text2Image_Generation)
6
- * Image Captioning: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)
7
- * Referring Expression Comprehension: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Grounding)
8
- * Visual Question Answering: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Question_Answering)
 
 
 
 
 
 
 
 
 
test.py DELETED
@@ -1,101 +0,0 @@
1
- from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG
2
- import soundfile as sf
3
- import io
4
- import torch
5
- import numpy as np
6
-
7
- AUDIO_CFG = {
8
- "audio_length": 1024,
9
- "clip_samples": 480000,
10
- "mel_bins": 64,
11
- "sample_rate": 48000,
12
- "window_size": 1024,
13
- "hop_size": 480,
14
- "fmin": 50,
15
- "fmax": 14000,
16
- "class_num": 527,
17
- }
18
-
19
-
20
-
21
- audio_cfg = AUDIO_CFG
22
- max_len = 480000
23
- data_path = '/work/NAT/gda2204/mshukor/data/audiocaps/train/--CHY2qO5zc.wav'
24
-
25
- audio_data, orig_sr = sf.read(data_path)
26
- # import librosa
27
- # audio_data, orig_sr = librosa.load(data_path, sr=48000)
28
-
29
- print(orig_sr)
30
- if audio_data.ndim>1:
31
- audio_data = np.mean(audio_data,axis=1)
32
-
33
-
34
- print(audio_data.shape, audio_data)
35
-
36
- audio_data = int16_to_float32(float32_to_int16(audio_data))
37
- audio_data = torch.tensor(audio_data).float()
38
- print(audio_data.dtype)
39
- print(audio_data.shape, audio_data)
40
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
41
- sample = {}
42
-
43
- sample = get_audio_features(
44
- sample, audio_data, max_len,
45
- data_truncating='fusion',
46
- data_filling='repeatpad',
47
- audio_cfg=audio_cfg,
48
- )
49
-
50
- patch_audio = sample['waveform'] #.half()
51
- print(patch_audio.shape, patch_audio.min(), patch_audio.max(), patch_audio)
52
-
53
- patch_audio = torch.zeros(480000)
54
- print(patch_audio.shape)
55
-
56
-
57
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
58
-
59
- AUDIO_CFG = {
60
- "sample_rate": 48000,
61
- "audio_length": 1024,
62
- "clip_samples": 480000,
63
- "mel_bins": 64,
64
- "sample_rate": 48000,
65
- "window_size": 1024,
66
- "hop_size": 480,
67
- "fmin": 50,
68
- "fmax": 14000,
69
- "class_num": 527,
70
- }
71
-
72
- window = 'hann'
73
- center = True
74
- pad_mode = 'reflect'
75
- ref = 1.0
76
- amin = 1e-10
77
- top_db = None
78
-
79
- spectrogram_extractor = Spectrogram(n_fft=AUDIO_CFG['window_size'], hop_length=AUDIO_CFG['hop_size'],
80
- win_length=AUDIO_CFG['window_size'], window=window, center=center, pad_mode=pad_mode,
81
- freeze_parameters=True)
82
-
83
-
84
- logmel_extractor = LogmelFilterBank(sr=AUDIO_CFG['sample_rate'], n_fft=AUDIO_CFG['window_size'],
85
- n_mels=AUDIO_CFG['mel_bins'], fmin=AUDIO_CFG['fmin'], fmax=AUDIO_CFG['fmax'],
86
- ref=ref, amin=amin, top_db=top_db,
87
- freeze_parameters=True)#.half()
88
-
89
-
90
- patch_audio = patch_audio[None, :]
91
- print(patch_audio.shape)
92
- spectro = spectrogram_extractor(patch_audio)
93
-
94
- print(spectro.shape)
95
- print(spectro)
96
-
97
-
98
- mel = logmel_extractor(spectro)
99
-
100
- print(mel.shape)
101
- print(mel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py DELETED
@@ -1,729 +0,0 @@
1
- #!/usr/bin/env python3 -u
2
- # Copyright 2022 The OFA-Sys Team.
3
- # All rights reserved.
4
- # This source code is licensed under the Apache 2.0 license
5
- # found in the LICENSE file in the root directory.
6
-
7
- """
8
- Train a new model on one or across multiple GPUs.
9
- """
10
-
11
- import argparse
12
- import logging
13
- import math
14
- import os
15
- import sys
16
- from typing import Dict, Optional, Any, List, Tuple, Callable
17
-
18
- # We need to setup root logger before importing any fairseq libraries.
19
- logging.basicConfig(
20
- format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
21
- datefmt="%Y-%m-%d %H:%M:%S",
22
- level=os.environ.get("LOGLEVEL", "INFO").upper(),
23
- stream=sys.stdout,
24
- )
25
- logger = logging.getLogger("fairseq_cli.train")
26
-
27
- import numpy as np
28
- import torch
29
- from fairseq import (
30
- # checkpoint_utils,
31
- options,
32
- quantization_utils,
33
- tasks,
34
- utils,
35
- )
36
- from fairseq.data import iterators
37
- from fairseq.data.plasma_utils import PlasmaStore
38
- from fairseq.dataclass.configs import FairseqConfig
39
- from fairseq.dataclass.utils import convert_namespace_to_omegaconf
40
- from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
41
- from fairseq.file_io import PathManager
42
- from fairseq.logging import meters, metrics, progress_bar
43
- from fairseq.model_parallel.megatron_trainer import MegatronTrainer
44
- # from fairseq.trainer import Trainer
45
- from omegaconf import DictConfig, OmegaConf
46
-
47
- from utils import checkpoint_utils
48
- from trainer import Trainer
49
-
50
- from utils.utils import print_trainable_params_percentage, setup_for_distributed
51
-
52
- import psutil
53
-
54
- def main(cfg: FairseqConfig) -> None:
55
- print(distributed_utils.is_master(cfg.distributed_training))
56
- print(cfg.distributed_training)
57
- setup_for_distributed(distributed_utils.is_master(cfg.distributed_training))
58
-
59
- if isinstance(cfg, argparse.Namespace):
60
- cfg = convert_namespace_to_omegaconf(cfg)
61
-
62
- utils.import_user_module(cfg.common)
63
-
64
- if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
65
- # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
66
- logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
67
-
68
- assert (
69
- cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
70
- ), "Must specify batch size either with --max-tokens or --batch-size"
71
- metrics.reset()
72
-
73
- if cfg.common.log_file is not None:
74
- handler = logging.FileHandler(filename=cfg.common.log_file)
75
- logger.addHandler(handler)
76
-
77
- np.random.seed(cfg.common.seed)
78
- utils.set_torch_seed(cfg.common.seed)
79
-
80
- if distributed_utils.is_master(cfg.distributed_training):
81
- print(cfg.checkpoint.save_dir)
82
- checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
83
-
84
-
85
- # Print args
86
- logger.info(cfg)
87
-
88
-
89
- if cfg.checkpoint.write_checkpoints_asynchronously:
90
- try:
91
- import iopath # noqa: F401
92
- except ImportError:
93
- logging.exception(
94
- "Asynchronous checkpoint writing is specified but iopath is "
95
- "not installed: `pip install iopath`"
96
- )
97
- return
98
-
99
- # Setup task, e.g., translation, language modeling, etc.
100
- task = tasks.setup_task(cfg.task)
101
-
102
- assert cfg.criterion, "Please specify criterion to train a model"
103
-
104
- # Build model and criterion
105
- if cfg.distributed_training.ddp_backend == "fully_sharded":
106
- with fsdp_enable_wrap(cfg.distributed_training):
107
- model = fsdp_wrap(task.build_model(cfg.model))
108
- else:
109
- model = task.build_model(cfg.model)
110
-
111
- # bitfit
112
- if cfg.model.bitfit:
113
- for name, param in model.named_parameters():
114
- if ("layer_norm" in name and "bias" in name) or ("fc" in name and "bias" in name):
115
- param.requires_grad = True
116
- else:
117
- param.requires_grad = False
118
-
119
- criterion = task.build_criterion(cfg.criterion)
120
-
121
- logger.info(model)
122
- logger.info("task: {}".format(task.__class__.__name__))
123
- logger.info("model: {}".format(model.__class__.__name__))
124
- logger.info("criterion: {}".format(criterion.__class__.__name__))
125
- logger.info(
126
- "num. shared model params: {:,} (num. trained: {:,})".format(
127
- sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
128
- sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
129
- )
130
- )
131
-
132
- logger.info(
133
- "num. expert model params: {} (num. trained: {})".format(
134
- sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
135
- sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
136
- )
137
- )
138
-
139
- # Load valid dataset (we load training data below, based on the latest checkpoint)
140
- # We load the valid dataset AFTER building the model
141
- # data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
142
- if cfg.dataset.combine_valid_subsets:
143
- task.load_dataset("valid", combine=True, epoch=1)
144
- else:
145
- for valid_sub_split in cfg.dataset.valid_subset.split(","):
146
- task.load_dataset(valid_sub_split, combine=False, epoch=1)
147
-
148
- # (optionally) Configure quantization
149
- if cfg.common.quantization_config_path is not None:
150
- quantizer = quantization_utils.Quantizer(
151
- config_path=cfg.common.quantization_config_path,
152
- max_epoch=cfg.optimization.max_epoch,
153
- max_update=cfg.optimization.max_update,
154
- )
155
- else:
156
- quantizer = None
157
-
158
-
159
-
160
-
161
- # for n, p in model.named_parameters():
162
- # if not p.requires_grad:
163
- # print(n)
164
-
165
- # Build trainer
166
- if cfg.common.model_parallel_size == 1:
167
- trainer = Trainer(cfg, task, model, criterion, quantizer)
168
- else:
169
- trainer = MegatronTrainer(cfg, task, model, criterion)
170
- logger.info(
171
- "training on {} devices (GPUs/TPUs)".format(
172
- cfg.distributed_training.distributed_world_size
173
- )
174
- )
175
- logger.info(
176
- "max tokens per device = {} and max sentences per device = {}".format(
177
- cfg.dataset.max_tokens,
178
- cfg.dataset.batch_size,
179
- )
180
- )
181
-
182
-
183
- # Load the latest checkpoint if one is available and restore the
184
- # corresponding train iterator
185
- strict = getattr(cfg.model, 'strict', True)
186
- logger.info('load checkpoint, strict:{}'.format(strict))
187
- extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
188
- cfg.checkpoint,
189
- trainer,
190
- strict=strict,
191
- # don't cache epoch iterators for sharded datasets
192
- disable_iterator_cache=True,
193
- load_on_cuda=cfg.checkpoint.load_on_cuda,
194
- )
195
- if cfg.common.tpu:
196
- import torch_xla.core.xla_model as xm
197
- xm.rendezvous("load_checkpoint") # wait for all workers
198
-
199
- max_epoch = cfg.optimization.max_epoch or math.inf
200
- if max_epoch > 0 and max_epoch != math.inf:
201
- total_num_updates = sum(
202
- math.ceil(len(epoch_itr) / cfg.optimization.update_freq[i])
203
- if i < len(cfg.optimization.update_freq) else
204
- math.ceil(len(epoch_itr) / cfg.optimization.update_freq[-1])
205
- for i in range(max_epoch)
206
- )
207
- trainer.lr_reinit(total_num_updates, trainer.get_num_updates())
208
-
209
- # if getattr(cfg.model, "freeze_encoder", False):
210
- # for idx, layer in enumerate(model.encoder.layers):
211
- # layer.requires_grad_(False)
212
- # if getattr(cfg.model, "freeze_decoder", False):
213
- # for idx, layer in enumerate(model.decoder.layers):
214
- # layer.requires_grad_(False)
215
-
216
- # if hasattr(cfg.model, 'progressive') or getattr(cfg.model, "freeze_perception", False):
217
- # custom_unfreeze(trainer, epoch_itr, cfg.model)
218
-
219
- # if hasattr(cfg.model, 'only_linear_proj') and getattr(cfg.model, "only_linear_proj", False):
220
- # model.requires_grad_(False)
221
- # model.encoder.embed_tokens.weight.requires_grad = True
222
- # model.decoder.embed_tokens.weight.requires_grad = True
223
-
224
- # if getattr(cfg.model, "freeze_encoder_embedding", False) or getattr(
225
- # cfg.model, "encoder_prompt", False) or getattr(cfg.model, "decoder_prompt", False) or getattr(cfg.model, "adapter", False):
226
- # model.encoder.embed_tokens.weight.requires_grad = False
227
- # if getattr(cfg.model, "freeze_decoder_embedding", False) or getattr(
228
- # cfg.model, "encoder_prompt", False) or getattr(cfg.model, "decoder_prompt", False) or getattr(cfg.model, "adapter", False):
229
- # model.decoder.embed_tokens.weight.requires_grad = False
230
-
231
-
232
- # model.encoder.image_proj.requires_grad_(True)
233
- # if getattr(cfg.model, "video_encoder_name", None):
234
- # model.encoder.video_proj.requires_grad_(True)
235
- # if getattr(cfg.model, "audio_encoder_name", None):
236
- # model.encoder.audio_proj.requires_grad_(True)
237
-
238
-
239
- print_trainable_params_percentage(model)
240
-
241
- lr = trainer.get_lr()
242
-
243
- train_meter = meters.StopwatchMeter()
244
- train_meter.start()
245
- while epoch_itr.next_epoch_idx <= max_epoch:
246
- if lr <= cfg.optimization.stop_min_lr:
247
- logger.info(
248
- f"stopping training because current learning rate ({lr}) is smaller "
249
- "than or equal to minimum learning rate "
250
- f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
251
- )
252
- break
253
-
254
- # train for one epoch
255
-
256
- valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
257
- if should_stop:
258
- break
259
-
260
- # only use first validation loss to update the learning rate
261
- lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
262
-
263
- epoch_itr = trainer.get_train_iterator(
264
- epoch_itr.next_epoch_idx,
265
- # sharded data: get train iterator for next epoch
266
- load_dataset=True,
267
- # don't cache epoch iterators for sharded datasets
268
- disable_iterator_cache=True,
269
- )
270
- train_meter.stop()
271
- logger.info("done training in {:.1f} seconds".format(train_meter.sum))
272
-
273
- # ioPath implementation to wait for all asynchronous file writes to complete.
274
- if cfg.checkpoint.write_checkpoints_asynchronously:
275
- logger.info(
276
- "ioPath PathManager waiting for all asynchronous checkpoint "
277
- "writes to finish."
278
- )
279
- PathManager.async_close()
280
- logger.info("ioPath PathManager finished waiting.")
281
-
282
-
283
- def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
284
- # skip check if no validation was done in the current epoch
285
- if valid_loss is None:
286
- return False
287
- if cfg.checkpoint.patience <= 0:
288
- return False
289
-
290
- def is_better(a, b):
291
- return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
292
-
293
- prev_best = getattr(should_stop_early, "best", None)
294
- if prev_best is None or is_better(valid_loss, prev_best):
295
- should_stop_early.best = valid_loss
296
- should_stop_early.num_runs = 0
297
- return False
298
- else:
299
- should_stop_early.num_runs += 1
300
- if should_stop_early.num_runs >= cfg.checkpoint.patience:
301
- logger.info(
302
- "early stop since valid performance hasn't improved for last {} runs".format(
303
- cfg.checkpoint.patience
304
- )
305
- )
306
- return True
307
- else:
308
- return False
309
-
310
-
311
- @metrics.aggregate("train")
312
- def train(
313
- cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
314
- ) -> Tuple[List[Optional[float]], bool]:
315
- """Train the model for one epoch and return validation losses."""
316
- # Initialize data iterator
317
- itr = epoch_itr.next_epoch_itr(
318
- fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
319
- shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
320
- )
321
- update_freq = (
322
- cfg.optimization.update_freq[epoch_itr.epoch - 1]
323
- if epoch_itr.epoch <= len(cfg.optimization.update_freq)
324
- else cfg.optimization.update_freq[-1]
325
- )
326
- itr = iterators.GroupedIterator(itr, update_freq)
327
- if cfg.common.tpu:
328
- itr = utils.tpu_data_loader(itr)
329
- progress = progress_bar.progress_bar(
330
- itr,
331
- log_format=cfg.common.log_format,
332
- log_file=cfg.common.log_file,
333
- log_interval=cfg.common.log_interval,
334
- epoch=epoch_itr.epoch,
335
- tensorboard_logdir=(
336
- cfg.common.tensorboard_logdir
337
- if distributed_utils.is_master(cfg.distributed_training)
338
- else None
339
- ),
340
- default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
341
- wandb_project=(
342
- cfg.common.wandb_project
343
- if distributed_utils.is_master(cfg.distributed_training)
344
- else None
345
- ),
346
- wandb_run_name=os.environ.get(
347
- "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
348
- ),
349
- azureml_logging=(
350
- cfg.common.azureml_logging
351
- if distributed_utils.is_master(cfg.distributed_training)
352
- else False
353
- ),
354
- )
355
- progress.update_config(_flatten_config(cfg))
356
-
357
- trainer.begin_epoch(epoch_itr.epoch)
358
-
359
- valid_subsets = cfg.dataset.valid_subset.split(",")
360
- should_stop = False
361
- num_updates = trainer.get_num_updates()
362
- logger.info("Start iterating over samples")
363
- for i, samples in enumerate(progress):
364
- with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
365
- "train_step-%d" % i
366
- ):
367
- log_output = trainer.train_step(samples)
368
-
369
- if log_output is not None: # not OOM, overflow, ...
370
- # log mid-epoch stats
371
- num_updates = trainer.get_num_updates()
372
- if num_updates % cfg.common.log_interval == 0:
373
- stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
374
- progress.log(stats, tag="train_inner", step=num_updates)
375
-
376
- # reset mid-epoch stats after each log interval
377
- # the end-of-epoch stats will still be preserved
378
- metrics.reset_meters("train_inner")
379
-
380
- end_of_epoch = not itr.has_next()
381
- valid_losses, should_stop = validate_and_save(
382
- cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
383
- )
384
-
385
- # if (hasattr(cfg.model, 'progressive') or hasattr(cfg.model, 'only_linear_proj') or hasattr(cfg.model, 'freeze_perception')) and end_of_epoch:
386
- # custom_unfreeze(trainer, epoch_itr, cfg.model)
387
- # print_trainable_params_percentage(trainer.model)
388
-
389
- if should_stop:
390
- break
391
-
392
-
393
- # print(i, len(progress))
394
- # if i > 5:
395
- # break
396
- # log end-of-epoch stats
397
- logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
398
- stats = get_training_stats(metrics.get_smoothed_values("train"))
399
- progress.print(stats, tag="train", step=num_updates)
400
- print_trainable_params_percentage(trainer.model)
401
- # reset epoch-level meters
402
- metrics.reset_meters("train")
403
- return valid_losses, should_stop
404
-
405
- # progressive training
406
- def custom_unfreeze(trainer, epoch_itr, cfg):
407
- model = trainer.model
408
- epoch = epoch_itr.epoch
409
- print("Epoch, ", epoch)
410
- ## unfreeze epochs
411
- unfreeze_epoch_encoder = cfg.unfreeze_epoch_encoder
412
- unfreeze_epoch_decoder = cfg.unfreeze_epoch_decoder
413
-
414
- unfreeze_epoch_image = cfg.unfreeze_epoch_image
415
- unfreeze_epoch_video = cfg.unfreeze_epoch_video
416
- unfreeze_epoch_audio = cfg.unfreeze_epoch_audio
417
-
418
-
419
- if getattr(cfg, "only_linear_proj", False):
420
- unfreeze_epoch = cfg.unfreeze_epoch
421
- if epoch >= unfreeze_epoch:
422
- model.requires_grad_(True)
423
-
424
- if getattr(cfg, "freeze_encoder_embedding", False) or getattr(
425
- cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
426
- model.encoder.embed_tokens.weight.requires_grad = False
427
- if getattr(cfg, "freeze_decoder_embedding", False) or getattr(
428
- cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
429
- model.decoder.embed_tokens.weight.requires_grad = False
430
-
431
- if trainer._ema is not None:
432
- trainer._ema.requires_grad_(False)
433
- print_trainable_params_percentage(model)
434
- return
435
-
436
- if getattr(cfg, "freeze_perception", False):
437
-
438
- if hasattr(model.encoder, 'embed_images'):
439
- if epoch >= unfreeze_epoch_image:
440
- grad = True
441
- else:
442
- grad = False
443
- model.encoder.embed_images.requires_grad_(grad)
444
- print('model.encoder.embed_images.requires_grad', grad)
445
- if hasattr(model.encoder, 'embed_videos'):
446
- if epoch >= unfreeze_epoch_video:
447
- grad = True
448
- else:
449
- grad = False
450
- model.encoder.embed_videos.requires_grad_(grad)
451
- print('model.encoder.embed_videos.requires_grad', grad)
452
-
453
- if hasattr(model.encoder, 'embed_audios'):
454
- if epoch >= unfreeze_epoch_audio:
455
- grad = True
456
- else:
457
- grad = False
458
- model.encoder.embed_audios.requires_grad_(grad)
459
- print('model.encoder.embed_audios.requires_grad', grad)
460
-
461
- if trainer._ema is not None:
462
- trainer._ema.requires_grad_(False)
463
- return
464
-
465
- if epoch >= unfreeze_epoch_encoder:
466
- grad=True
467
- else:
468
- grad=False
469
- for l in model.encoder.layers:
470
- l.requires_grad_(grad)
471
- print('model.encoder.layers.requires_grad', grad)
472
-
473
- if epoch >= unfreeze_epoch_decoder:
474
- grad=True
475
- else:
476
- grad=False
477
- for l in model.decoder.layers:
478
- l.requires_grad_(grad)
479
- print('model.decoder.layers.requires_grad', grad)
480
-
481
- if getattr(cfg, "freeze_encoder_embedding", False) or getattr(
482
- cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
483
- model.encoder.embed_tokens.weight.requires_grad = False
484
- if getattr(cfg, "freeze_decoder_embedding", False) or getattr(
485
- cfg, "encoder_prompt", False) or getattr(cfg, "decoder_prompt", False) or getattr(cfg, "adapter", False):
486
- model.decoder.embed_tokens.weight.requires_grad = False
487
-
488
- if getattr(cfg, "encoder_prompt", False):
489
- model.encoder.encoder_prompt_encoder.requires_grad_(True)
490
- if getattr(cfg, "decoder_prompt", False):
491
- model.decoder.decoder_prompt_encoder.requires_grad_(True)
492
- if getattr(cfg, "adapter", False):
493
- for idx, layer in enumerate(model.encoder.layers):
494
- layer.adapter.requires_grad_(True)
495
- for idx, layer in enumerate(model.decoder.layers):
496
- layer.adapter.requires_grad_(True)
497
-
498
- if hasattr(model.encoder, 'embed_images'):
499
- if epoch >= unfreeze_epoch_image:
500
- grad = True
501
- else:
502
- grad = False
503
- model.encoder.embed_images.requires_grad_(grad)
504
- print('model.encoder.embed_images.requires_grad', grad)
505
- if hasattr(model.encoder, 'embed_videos'):
506
- if epoch >= unfreeze_epoch_video:
507
- grad = True
508
- else:
509
- grad = False
510
- model.encoder.embed_videos.requires_grad_(grad)
511
- print('model.encoder.embed_videos.requires_grad', grad)
512
-
513
- if hasattr(model.encoder, 'embed_audios'):
514
- if epoch >= unfreeze_epoch_audio:
515
- grad = True
516
- else:
517
- grad = False
518
- model.encoder.embed_audios.requires_grad_(grad)
519
- print('model.encoder.embed_audios.requires_grad', grad)
520
-
521
- if trainer._ema is not None:
522
- trainer._ema.requires_grad_(False)
523
-
524
- def _flatten_config(cfg: DictConfig):
525
- config = OmegaConf.to_container(cfg)
526
- # remove any legacy Namespaces and replace with a single "args"
527
- namespace = None
528
- for k, v in list(config.items()):
529
- if isinstance(v, argparse.Namespace):
530
- namespace = v
531
- del config[k]
532
- if namespace is not None:
533
- config["args"] = vars(namespace)
534
- return config
535
-
536
-
537
- def validate_and_save(
538
- cfg: DictConfig,
539
- trainer: Trainer,
540
- task: tasks.FairseqTask,
541
- epoch_itr,
542
- valid_subsets: List[str],
543
- end_of_epoch: bool,
544
- ) -> Tuple[List[Optional[float]], bool]:
545
- num_updates = trainer.get_num_updates()
546
- max_update = cfg.optimization.max_update or math.inf
547
-
548
- # Stopping conditions (and an additional one based on validation loss later
549
- # on)
550
- should_stop = False
551
- if num_updates >= max_update:
552
- should_stop = True
553
- logger.info(
554
- f"Stopping training due to "
555
- f"num_updates: {num_updates} >= max_update: {max_update}"
556
- )
557
-
558
- training_time_hours = trainer.cumulative_training_time() / (60 * 60)
559
- if (
560
- cfg.optimization.stop_time_hours > 0
561
- and training_time_hours > cfg.optimization.stop_time_hours
562
- ):
563
- should_stop = True
564
- logger.info(
565
- f"Stopping training due to "
566
- f"cumulative_training_time: {training_time_hours} > "
567
- f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
568
- )
569
-
570
- do_save = (
571
- (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
572
- or should_stop
573
- or (
574
- cfg.checkpoint.save_interval_updates > 0
575
- and num_updates > 0
576
- and num_updates % cfg.checkpoint.save_interval_updates == 0
577
- and num_updates >= cfg.dataset.validate_after_updates
578
- )
579
- )
580
- do_validate = (
581
- (not end_of_epoch and do_save) # validate during mid-epoch saves
582
- or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
583
- or should_stop
584
- or (
585
- cfg.dataset.validate_interval_updates > 0
586
- and num_updates > 0
587
- and num_updates % cfg.dataset.validate_interval_updates == 0
588
- )
589
- ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates
590
-
591
- # Validate
592
- valid_losses = [None]
593
- if do_validate:
594
- valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
595
-
596
- should_stop |= should_stop_early(cfg, valid_losses[0])
597
-
598
- # Save checkpoint
599
- if do_save or should_stop:
600
- checkpoint_utils.save_checkpoint(
601
- cfg.checkpoint, trainer, epoch_itr, valid_losses[0], save_on_cuda=cfg.checkpoint.save_on_cuda,
602
- )
603
-
604
- return valid_losses, should_stop
605
-
606
-
607
- def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
608
- stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
609
- return stats
610
-
611
-
612
- def validate(
613
- cfg: DictConfig,
614
- trainer: Trainer,
615
- task: tasks.FairseqTask,
616
- epoch_itr,
617
- subsets: List[str],
618
- ) -> List[Optional[float]]:
619
- """Evaluate the model on the validation set(s) and return the losses."""
620
-
621
- if cfg.dataset.fixed_validation_seed is not None:
622
- # set fixed seed for every validation
623
- utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
624
-
625
- trainer.begin_valid_epoch(epoch_itr.epoch)
626
- valid_losses = []
627
- for subset in subsets:
628
- logger.info('begin validation on "{}" subset'.format(subset))
629
-
630
- # Initialize data iterator
631
- itr = trainer.get_valid_iterator(subset).next_epoch_itr(
632
- shuffle=False, set_dataset_epoch=False # use a fixed valid set
633
- )
634
- if cfg.common.tpu:
635
- itr = utils.tpu_data_loader(itr)
636
- progress = progress_bar.progress_bar(
637
- itr,
638
- log_format=cfg.common.log_format,
639
- log_interval=cfg.common.log_interval,
640
- epoch=epoch_itr.epoch,
641
- prefix=f"valid on '{subset}' subset",
642
- tensorboard_logdir=(
643
- cfg.common.tensorboard_logdir
644
- if distributed_utils.is_master(cfg.distributed_training)
645
- else None
646
- ),
647
- default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
648
- wandb_project=(
649
- cfg.common.wandb_project
650
- if distributed_utils.is_master(cfg.distributed_training)
651
- else None
652
- ),
653
- wandb_run_name=os.environ.get(
654
- "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
655
- ),
656
- )
657
-
658
- # create a new root metrics aggregator so validation metrics
659
- # don't pollute other aggregators (e.g., train meters)
660
- with metrics.aggregate(new_root=True) as agg:
661
- for i, sample in enumerate(progress):
662
- if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
663
- break
664
- try:
665
- # print(sample)
666
- trainer.valid_step(sample)
667
- except IndexError:
668
- # print(sample)
669
- print('didnt pass')
670
- trainer.valid_step(sample)
671
- continue
672
-
673
- # log validation stats
674
- if hasattr(task, 'get_valid_stats'):
675
- stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
676
- else:
677
- stats = agg.get_smoothed_values()
678
- stats = get_valid_stats(cfg, trainer, stats)
679
-
680
- if hasattr(task, "post_validate"):
681
- task.post_validate(trainer.get_model(), stats, agg)
682
-
683
-
684
- progress.print(stats, tag=subset, step=trainer.get_num_updates())
685
-
686
- valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
687
- return valid_losses
688
-
689
-
690
- def get_valid_stats(
691
- cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
692
- ) -> Dict[str, Any]:
693
- stats["num_updates"] = trainer.get_num_updates()
694
- if hasattr(checkpoint_utils.save_checkpoint, "best"):
695
- key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
696
- best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
697
- stats[key] = best_function(
698
- checkpoint_utils.save_checkpoint.best,
699
- stats[cfg.checkpoint.best_checkpoint_metric],
700
- )
701
- return stats
702
-
703
-
704
- def cli_main(
705
- modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
706
- ) -> None:
707
- parser = options.get_training_parser()
708
- args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
709
- print(args)
710
- cfg = convert_namespace_to_omegaconf(args)
711
-
712
- if cfg.common.use_plasma_view:
713
- server = PlasmaStore(path=cfg.common.plasma_path)
714
- logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")
715
-
716
- if args.profile:
717
- with torch.cuda.profiler.profile():
718
- with torch.autograd.profiler.emit_nvtx():
719
- distributed_utils.call_main(cfg, main)
720
- else:
721
- distributed_utils.call_main(cfg, main)
722
-
723
- # if cfg.common.use_plasma_view:
724
- # server.server.kill()
725
-
726
-
727
- if __name__ == "__main__":
728
- cli_main()
729
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trainer.py DELETED
@@ -1,1569 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- """
7
- Train a network across multiple GPUs.
8
- """
9
-
10
- import contextlib
11
- import logging
12
- import sys
13
- import time
14
- from argparse import Namespace
15
- from itertools import chain
16
- from typing import Any, Dict, List
17
-
18
- import torch
19
- from fairseq import models, optim, utils
20
- from fairseq.dataclass.configs import FairseqConfig
21
- from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
- from fairseq.distributed import utils as distributed_utils
23
- from fairseq.file_io import PathManager
24
- from fairseq.logging import meters, metrics
25
- from fairseq.models.ema import build_ema
26
- from fairseq.nan_detector import NanDetector
27
- from fairseq.optim import lr_scheduler
28
- from omegaconf import OmegaConf
29
-
30
- from utils import checkpoint_utils
31
- import torch.nn as nn
32
-
33
- logger = logging.getLogger(__name__)
34
-
35
-
36
- class Trainer(object):
37
- """Main class for data parallel training.
38
-
39
- This class supports synchronous distributed data parallel training,
40
- where multiple workers each have a full model replica and gradients
41
- are accumulated across workers before each update. We use
42
- :class:`~torch.nn.parallel.DistributedDataParallel` to handle
43
- communication of the gradients across workers.
44
- """
45
-
46
- def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
47
-
48
- if isinstance(cfg, Namespace):
49
- logger.warning(
50
- "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
51
- )
52
- cfg = convert_namespace_to_omegaconf(cfg)
53
-
54
- self.cfg = cfg
55
- self.task = task
56
-
57
- # catalog shared parameters
58
- shared_params = _catalog_shared_params(model)
59
- self.tpu = cfg.common.tpu
60
- self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
61
- if self.cuda:
62
- self.device = torch.device("cuda")
63
- elif self.tpu:
64
- self.device = utils.get_tpu_device()
65
- else:
66
- self.device = torch.device("cpu")
67
-
68
- if self.is_fsdp:
69
- import fairscale
70
- if self.cfg.common.bf16:
71
- raise ValueError(
72
- "FullyShardedDataParallel is not compatible with --bf16 or "
73
- "--memory-efficient-bf16"
74
- )
75
- if self.cfg.distributed_training.zero_sharding != "none":
76
- raise ValueError(
77
- "FullyShardedDataParallel is not compatible with --zero-sharding "
78
- "option (it's already built in)"
79
- )
80
- if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0":
81
- raise RuntimeError(
82
- "Please update to fairscale 0.4.0 or newer when combining "
83
- "--update-freq with FullyShardedDataParallel"
84
- )
85
- else:
86
- if (
87
- hasattr(self.cfg.distributed_training, "cpu_offload")
88
- and self.cfg.distributed_training.cpu_offload
89
- ):
90
- raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
91
-
92
- # copy model and criterion to current device/dtype
93
- self._criterion = criterion
94
- self._model = model
95
- if not self.is_fsdp:
96
- if cfg.common.fp16:
97
- assert not cfg.common.amp, "Cannot use fp16 and AMP together"
98
- self._criterion = self._criterion.half()
99
- self._model = self._model.half()
100
-
101
- if hasattr(self._model.encoder, 'embed_audios'):
102
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
103
- for layer in self._model.modules():
104
- if isinstance(layer, LogmelFilterBank) or isinstance(layer, Spectrogram):
105
- layer.float()
106
- print(layer)
107
- # for layer in self._model.modules():
108
- # if isinstance(layer, nn.BatchNorm2d):
109
- # layer.float()
110
- # print(layer)
111
-
112
-
113
- elif cfg.common.bf16:
114
- self._criterion = self._criterion.to(dtype=torch.bfloat16)
115
- self._model = self._model.to(dtype=torch.bfloat16)
116
- elif cfg.common.amp:
117
- self._amp_retries = 0
118
- if (
119
- not cfg.distributed_training.pipeline_model_parallel
120
- # the DistributedFairseqModel wrapper will handle moving to device,
121
- # so only handle cases which don't use the wrapper
122
- and not self.use_distributed_wrapper
123
- ):
124
- self._criterion = self._criterion.to(device=self.device)
125
- self._model = self._model.to(device=self.device)
126
- self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
127
- self.last_device = None
128
- if self.cuda and self.pipeline_model_parallel:
129
- self.last_device = torch.device(
130
- cfg.distributed_training.pipeline_devices[-1]
131
- )
132
-
133
- # check that shared parameters are preserved after device transfer
134
- for shared_param in shared_params:
135
- ref = _get_module_by_path(self._model, shared_param[0])
136
- for path in shared_param[1:]:
137
- logger.info(
138
- "detected shared parameter: {} <- {}".format(shared_param[0], path)
139
- )
140
- _set_module_by_path(self._model, path, ref)
141
-
142
- self._dummy_batch = None # indicates we don't have a dummy batch at first
143
- self._lr_scheduler = None
144
- self._num_updates = 0
145
- self._num_xla_compiles = 0 # for TPUs
146
- self._optim_history = None
147
- self._optimizer = None
148
- self._warn_once = set()
149
- self._wrapped_criterion = None
150
- self._wrapped_model = None
151
- self._ema = None
152
-
153
- # TODO(myleott): support tpu
154
- if self.cuda and self.data_parallel_world_size > 1:
155
- self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
156
- else:
157
- self._grad_norm_buf = None
158
-
159
- self.quantizer = quantizer
160
- if self.quantizer is not None:
161
- self.quantizer.set_trainer(self)
162
-
163
- # get detailed cuda environment
164
- if self.cuda:
165
- self.cuda_env = utils.CudaEnvironment()
166
- if self.data_parallel_world_size > 1:
167
- self.cuda_env_arr = distributed_utils.all_gather_list(
168
- self.cuda_env, group=distributed_utils.get_global_group()
169
- )
170
- else:
171
- self.cuda_env_arr = [self.cuda_env]
172
- if self.data_parallel_rank == 0:
173
- utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
174
- else:
175
- self.cuda_env = None
176
- self.cuda_env_arr = None
177
-
178
- metrics.log_start_time("wall", priority=790, round=0)
179
-
180
- self._start_time = time.time()
181
- self._previous_training_time = 0
182
- self._cumulative_training_time = None
183
-
184
- def reinitialize(self):
185
- """Reinitialize the Trainer, typically after model params change."""
186
- self._lr_scheduler = None
187
- self._optimizer = None
188
- self._wrapped_criterion = None
189
- self._wrapped_model = None
190
-
191
- @property
192
- def data_parallel_world_size(self):
193
- if self.cfg.distributed_training.distributed_world_size == 1:
194
- return 1
195
- return distributed_utils.get_data_parallel_world_size()
196
-
197
- @property
198
- def data_parallel_process_group(self):
199
- return distributed_utils.get_data_parallel_group()
200
-
201
- @property
202
- def data_parallel_rank(self):
203
- if self.cfg.distributed_training.distributed_world_size == 1:
204
- return 0
205
- return distributed_utils.get_data_parallel_rank()
206
-
207
- @property
208
- def is_data_parallel_master(self):
209
- # NOTE: this returns true for all model parallel replicas with data
210
- # parallel rank 0
211
- return self.data_parallel_rank == 0
212
-
213
- @property
214
- def use_distributed_wrapper(self) -> bool:
215
- return (
216
- self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf
217
- ) or (
218
- self.is_fsdp and self.cfg.distributed_training.cpu_offload
219
- )
220
-
221
- @property
222
- def should_save_checkpoint_on_current_rank(self) -> bool:
223
- """Indicates whether to save checkpoints on the current DDP rank."""
224
- if (
225
- self.is_fsdp and self.cfg.distributed_training.use_sharded_state
226
- ) or getattr(self.cfg.model, "base_layers", 0) > 0:
227
- return True
228
- else:
229
- return self.is_data_parallel_master
230
-
231
- @property
232
- def always_call_state_dict_during_save_checkpoint(self) -> bool:
233
- if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state:
234
- # FSDP calls communication collective when consolidating checkpoints
235
- return True
236
- else:
237
- return False
238
-
239
- @property
240
- def checkpoint_suffix(self) -> str:
241
- """Suffix to add to the checkpoint file name."""
242
- if self.is_fsdp and self.cfg.distributed_training.use_sharded_state:
243
- return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(
244
- self.data_parallel_rank
245
- )
246
- else:
247
- return self.cfg.checkpoint.checkpoint_suffix or ""
248
-
249
- @property
250
- def criterion(self):
251
- if self._wrapped_criterion is None:
252
- if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
253
- self._wrapped_criterion = models.DistributedFairseqModel(
254
- self.cfg.distributed_training,
255
- self._criterion,
256
- process_group=self.data_parallel_process_group,
257
- device=self.device,
258
- )
259
- else:
260
- self._wrapped_criterion = self._criterion
261
- return self._wrapped_criterion
262
-
263
- @property
264
- def model(self):
265
- if self._wrapped_model is None:
266
- if self.use_distributed_wrapper:
267
- self._wrapped_model = models.DistributedFairseqModel(
268
- self.cfg.distributed_training,
269
- self._model,
270
- process_group=self.data_parallel_process_group,
271
- device=self.device,
272
- )
273
- else:
274
- self._wrapped_model = self._model
275
- return self._wrapped_model
276
-
277
- @property
278
- def ema(self):
279
- if self._ema is None:
280
- self._build_ema()
281
- return self._ema
282
-
283
- def _build_ema(self):
284
- if self.cfg.ema.store_ema:
285
- self._ema = build_ema(self._model, self.cfg.ema, self.device)
286
- logger.info(
287
- "Exponential Moving Average Shadow Model is initialized."
288
- )
289
-
290
- @property
291
- def optimizer(self):
292
- if self._optimizer is None:
293
- self._build_optimizer()
294
- return self._optimizer
295
-
296
- @property
297
- def lr_scheduler(self):
298
- if self._lr_scheduler is None:
299
- self._build_optimizer() # this will initialize self._lr_scheduler
300
- return self._lr_scheduler
301
-
302
- def _build_optimizer(self):
303
- # params = list(self.model.parameters())
304
- # print("len of model param:", len(params))
305
- # params += list(
306
- # filter(
307
- # lambda p: p.requires_grad,
308
- # chain(self.criterion.parameters()),
309
- # )
310
- # )
311
-
312
-
313
- params = list(
314
- filter(
315
- lambda p: p.requires_grad,
316
- chain(self.model.parameters(), self.criterion.parameters()),
317
- )
318
- )
319
- print("len of optim param:", len(params))
320
-
321
- if self.is_fsdp and self.cfg.common.fp16:
322
- # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
323
- # mostly for the grad scaling. But if we don't have the
324
- # --memory-efficient-fp16 flag set, then we're effectively doing
325
- # regular --fp16 and can allow the use of optimizers that would
326
- # otherwise be unsupported by MemoryEfficientFP16Optimizer.
327
- allow_unsupported = not self.cfg.common.memory_efficient_fp16
328
- self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
329
- self.cfg, params, allow_unsupported=allow_unsupported
330
- )
331
- elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
332
- if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
333
- logger.info(
334
- "NOTE: your device does NOT support faster training with --fp16 or --amp, "
335
- "please switch to FP32 which is likely to be faster"
336
- )
337
- if (
338
- self.cfg.common.memory_efficient_fp16
339
- or self.cfg.common.memory_efficient_bf16
340
- ):
341
- self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
342
- self.cfg, params
343
- )
344
- elif self.cfg.common.amp:
345
- self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
346
- else:
347
- self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
348
- else:
349
- if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
350
- logger.info("NOTE: your device may support faster training with --fp16 or --amp")
351
- self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
352
-
353
- if self.is_fsdp:
354
- assert (
355
- not self.cfg.optimization.use_bmuf
356
- ), "--ddp-backend=fully_sharded is not compatible with BMUF"
357
- assert self._optimizer.supports_flat_params, (
358
- "--ddp-backend=fully_sharded is only compatible with pointwise "
359
- "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
360
- "However, the sharding will result in slightly different results when "
361
- "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
362
- )
363
-
364
- if self.cfg.optimization.use_bmuf:
365
- self._optimizer = optim.FairseqBMUF(
366
- self.cfg.bmuf,
367
- self._optimizer,
368
- )
369
-
370
- if self.cfg.distributed_training.zero_sharding == "os":
371
- if (
372
- self.cfg.common.fp16
373
- and not self.cfg.common.memory_efficient_fp16
374
- and not self.cfg.common.memory_efficient_bf16
375
- ) and not self.cfg.common.fp16_no_flatten_grads:
376
- raise ValueError(
377
- "ZeRO is incomptabile with fp16 and flattened grads. "
378
- "Please use --fp16-no-flatten-grads"
379
- )
380
- else:
381
- optim.shard_(self._optimizer, self.data_parallel_process_group)
382
-
383
- # We should initialize the learning rate scheduler immediately after
384
- # building the optimizer, so that the initial learning rate is set.
385
- self._lr_scheduler = lr_scheduler.build_lr_scheduler(
386
- self.cfg.lr_scheduler,
387
- self.optimizer,
388
- )
389
- self._lr_scheduler.step_update(0)
390
-
391
- @property
392
- def is_fsdp(self):
393
- return self.cfg.distributed_training.ddp_backend == "fully_sharded"
394
-
395
- def consolidate_optimizer(self):
396
- """For OSS, we need to consolidate the state dict."""
397
- if self.cfg.checkpoint.no_save_optimizer_state:
398
- return
399
- self._gathered_optim_state = None
400
- if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
401
- self.optimizer.optimizer.consolidate_state_dict()
402
- elif self.is_fsdp and not self.model.use_sharded_state:
403
- st = self.model.gather_full_optim_state_dict(
404
- self.optimizer
405
- ) # only returns on rank 0
406
- self._gathered_optim_state = st
407
-
408
- def state_dict(self):
409
- state_dict = {
410
- "args": None, # legacy
411
- "cfg": (
412
- OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
413
- if OmegaConf.is_config(self.cfg)
414
- else self.cfg
415
- ),
416
- "model": self.model.state_dict(),
417
- "criterion": (
418
- self.criterion.state_dict()
419
- if utils.has_parameters(self.criterion)
420
- else None
421
- ),
422
- "optimizer_history": (self._optim_history or [])
423
- + [
424
- {
425
- "criterion_name": self.get_criterion().__class__.__name__,
426
- "optimizer_name": self.optimizer.__class__.__name__,
427
- "lr_scheduler_state": self.lr_scheduler.state_dict(),
428
- "num_updates": self.get_num_updates(),
429
- }
430
- ],
431
- "task_state": self.task.state_dict() if self.task is not None else {},
432
- "extra_state": {
433
- "metrics": metrics.state_dict(),
434
- "previous_training_time": self.cumulative_training_time(),
435
- },
436
- }
437
- if self.cfg.ema.store_ema:
438
- # Save EMA model state as extra state
439
- state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
440
- if self.cfg.ema.ema_fp32:
441
- # Save EMA params in fp32
442
- state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
443
- if not self.cfg.checkpoint.no_save_optimizer_state:
444
- if self._gathered_optim_state is not None:
445
- state_dict["last_optimizer_state"] = self._gathered_optim_state
446
- self._gathered_optim_state = None
447
- else:
448
- state_dict["last_optimizer_state"] = self.optimizer.state_dict()
449
- if self.is_fsdp:
450
- # save meta data for recombining checkpoint upon loading
451
- state_dict["fsdp_metadata"] = self.model.local_metadata_dict()
452
- return state_dict
453
-
454
- def save_checkpoint(self, filename, extra_state, save_on_cuda=False):
455
- """Save all training state in a checkpoint file."""
456
- logger.info(f"Saving checkpoint to {filename}")
457
- # call state_dict on all ranks in case it needs internal communication
458
- if not save_on_cuda:
459
- state_dict = utils.move_to_cpu(self.state_dict())
460
- else:
461
- print("Save on cuda")
462
- state_dict = self.state_dict()
463
- state_dict["extra_state"].update(extra_state)
464
- if self.should_save_checkpoint_on_current_rank:
465
- checkpoint_utils.torch_persistent_save(
466
- state_dict,
467
- filename,
468
- async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
469
- )
470
- logger.info(f"Finished saving checkpoint to {filename}")
471
-
472
- def load_checkpoint(
473
- self,
474
- filename,
475
- reset_optimizer=False,
476
- reset_lr_scheduler=False,
477
- optimizer_overrides=None,
478
- reset_meters=False,
479
- strict=True,
480
- load_on_cuda=False,
481
- ):
482
- """
483
- Load all training state from a checkpoint file.
484
- rank = 0 will load the checkpoint, and then broadcast it to all
485
- other ranks.
486
- """
487
- extra_state, self._optim_history, last_optim_state = None, [], None
488
-
489
- logger.info(f"Preparing to load checkpoint {filename}")
490
- is_distributed = self.data_parallel_world_size > 1
491
- bexists = PathManager.isfile(filename)
492
- if bexists:
493
- load_on_all_ranks = (
494
- self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
495
- # TPUs don't support broadcast yet, so load checkpoints
496
- # on every worker for now
497
- or self.tpu
498
- # FSDP requires loading checkpoint shards on all ranks
499
- or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state)
500
- or getattr(self.cfg.model, "base_layers", 0) > 0
501
- )
502
-
503
- if load_on_all_ranks or self.data_parallel_rank == 0:
504
- state = checkpoint_utils.load_checkpoint_to_cpu(
505
- filename, load_on_all_ranks=load_on_all_ranks, strict=strict,load_on_cuda=load_on_cuda,
506
- )
507
- last_optim_state = state.get("last_optimizer_state", None)
508
-
509
- # If doing zero_sharding, do not broadcast global optimizer
510
- # state. Later we will broadcast sharded states to each rank
511
- # to avoid memory from exploding.
512
- if (
513
- not load_on_all_ranks
514
- and self.cfg.distributed_training.zero_sharding == "os"
515
- and "last_optimizer_state" in state
516
- and is_distributed
517
- ):
518
- state["last_optimizer_state"] = "SHARDED"
519
- else:
520
- last_optim_state = None
521
- state = None
522
-
523
- if is_distributed and not load_on_all_ranks: # .contiguous()
524
- state = distributed_utils.broadcast_object(
525
- state,
526
- src_rank=0,
527
- group=self.data_parallel_process_group,
528
- dist_device=self.device,
529
- )
530
- if self.data_parallel_rank > 0:
531
- last_optim_state = state.get("last_optimizer_state", None)
532
-
533
- # load model parameters
534
- try:
535
- if self.cfg.checkpoint.use_ema_weights_to_init_param and "extra_state" in state and "ema" in state["extra_state"]:
536
- logger.info("use_ema_weights_to_init_param = True, will use EMA weights in the ckpt to init the model param...")
537
- ema_state_dict = state["extra_state"]["ema_fp32_params"] if "ema_fp32_params" in state["extra_state"] else state["extra_state"]["ema"]
538
- msg = self.model.load_state_dict(
539
- ema_state_dict, strict=strict, model_cfg=self.cfg.model
540
- )
541
- else:
542
- msg = self.model.load_state_dict(
543
- state["model"], strict=strict, model_cfg=self.cfg.model
544
- )
545
- logger.info(msg)
546
-
547
- # save memory for later steps
548
- if not (self.cfg.ema.store_ema and (self.cfg.checkpoint.use_latest_weights_to_init_ema or not ("extra_state" in state and "ema" in state["extra_state"]))):
549
- del state["model"]
550
- if utils.has_parameters(self.get_criterion()) and 'criterion' in state:
551
- self.get_criterion().load_state_dict(
552
- state["criterion"], strict=strict
553
- )
554
- del state["criterion"]
555
-
556
- except Exception:
557
- raise Exception(
558
- "Cannot load model parameters from checkpoint {}; "
559
- "please ensure that the architectures match.".format(filename)
560
- )
561
- extra_state = state.get("extra_state", None)
562
- self._optim_history = state.get("optimizer_history", None)
563
-
564
- if last_optim_state is not None and not reset_optimizer:
565
- # rebuild optimizer after loading model, since params may have changed
566
- self._build_optimizer()
567
-
568
- # only reload optimizer and lr_scheduler if they match
569
- last_optim = self._optim_history[-1]
570
- assert (
571
- last_optim["criterion_name"] == self.get_criterion().__class__.__name__
572
- ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
573
- assert (
574
- last_optim["optimizer_name"] == self.optimizer.__class__.__name__
575
- ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
576
-
577
- if not reset_lr_scheduler:
578
- self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
579
-
580
- if self.is_fsdp and not self.model.use_sharded_state:
581
- # if use_sharded_state, the last_optim_state is already sharded, skip this
582
- last_optim_state = self.model.get_shard_from_optim_state_dict(
583
- last_optim_state
584
- )
585
- elif not load_on_all_ranks and is_distributed:
586
- last_optim_state = self.optimizer.broadcast_global_state_dict(
587
- last_optim_state
588
- )
589
-
590
- self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
591
-
592
- self.set_num_updates(last_optim["num_updates"])
593
-
594
- if extra_state is not None:
595
- itr_state = extra_state["train_iterator"]
596
- epoch = itr_state["epoch"]
597
-
598
- if "previous_training_time" in extra_state:
599
- self._previous_training_time = extra_state["previous_training_time"]
600
- self._start_time = time.time()
601
-
602
- self.lr_step(epoch)
603
-
604
- if (
605
- itr_state.get("version", 1) >= 2
606
- and itr_state["iterations_in_epoch"] == 0
607
- ):
608
- # reset meters at start of epoch
609
- reset_meters = True
610
-
611
- if "metrics" in extra_state and not reset_meters:
612
- metrics.load_state_dict(extra_state["metrics"])
613
-
614
- # reset TimeMeters, since their start times don't make sense anymore
615
- for meter in metrics.get_meters("default"):
616
- if isinstance(meter, meters.TimeMeter):
617
- meter.reset()
618
-
619
- if self.cfg.ema.store_ema:
620
- if self.cfg.checkpoint.use_latest_weights_to_init_ema or "ema" not in extra_state:
621
- if "ema" not in extra_state:
622
- logger.warn(
623
- "EMA not found in checkpoint. But store_ema is True. "
624
- "EMA is re-initialized from checkpoint."
625
- )
626
- elif self.cfg.checkpoint.use_latest_weights_to_init_ema:
627
- logger.info(
628
- "use_latest_weights_to_init_ema = True. EMA is re-initialized from checkpoint."
629
- )
630
- self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
631
- del state["model"]
632
- else:
633
- logger.info(
634
- "Loading EMA from checkpoint"
635
- )
636
- self.ema.restore(extra_state["ema"], build_fp32_params=False)
637
-
638
- if self.cfg.ema.ema_fp32:
639
- if "ema_fp32_params" in extra_state:
640
- logger.info(
641
- "Loading EMA fp32 params from checkpoint"
642
- )
643
- self.ema.build_fp32_params(extra_state["ema_fp32_params"])
644
- else:
645
- logger.info(
646
- "Building EMA fp32 params from EMA model in checkpoint"
647
- )
648
- self.ema.build_fp32_params()
649
-
650
- logger.info(
651
- "Loaded checkpoint {} (epoch {} @ {} updates)".format(
652
- filename, epoch, self.get_num_updates()
653
- )
654
- )
655
-
656
- else:
657
- logger.info("No existing checkpoint found {}".format(filename))
658
-
659
- # print("delete state ...")
660
- # del state # dereference seems crucial
661
- # torch.cuda.empty_cache()
662
-
663
- return extra_state
664
-
665
- def get_train_iterator(
666
- self,
667
- epoch,
668
- combine=True,
669
- load_dataset=True,
670
- data_selector=None,
671
- shard_batch_itr=True,
672
- disable_iterator_cache=False,
673
- ):
674
- """Return an EpochBatchIterator over the training set for a given epoch."""
675
- if load_dataset:
676
- logger.info("loading train data for epoch {}".format(epoch))
677
- self.task.load_dataset(
678
- self.cfg.dataset.train_subset,
679
- epoch=epoch,
680
- combine=combine,
681
- data_selector=data_selector,
682
- tpu=self.tpu,
683
- )
684
- batch_iterator = self.task.get_batch_iterator(
685
- dataset=self.task.dataset(self.cfg.dataset.train_subset),
686
- max_tokens=self.cfg.dataset.max_tokens,
687
- max_sentences=self.cfg.dataset.batch_size,
688
- max_positions=utils.resolve_max_positions(
689
- self.task.max_positions(),
690
- self.model.max_positions(),
691
- self.cfg.dataset.max_tokens,
692
- ),
693
- ignore_invalid_inputs=True,
694
- required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
695
- seed=self.cfg.common.seed,
696
- num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
697
- shard_id=self.data_parallel_rank if shard_batch_itr else 0,
698
- num_workers=self.cfg.dataset.num_workers,
699
- epoch=epoch,
700
- data_buffer_size=self.cfg.dataset.data_buffer_size,
701
- disable_iterator_cache=disable_iterator_cache,
702
- )
703
- self.reset_dummy_batch(batch_iterator.first_batch)
704
- batch_iterator.dataset.dataset._seek()
705
- return batch_iterator
706
-
707
- def get_valid_iterator(
708
- self,
709
- subset,
710
- disable_iterator_cache=False,
711
- ):
712
- """Return an EpochBatchIterator over given validation subset for a given epoch."""
713
- self.task.dataset(subset).dataset._seek()
714
- batch_iterator = self.task.get_batch_iterator(
715
- dataset=self.task.dataset(subset),
716
- max_tokens=self.cfg.dataset.max_tokens_valid,
717
- max_sentences=self.cfg.dataset.batch_size_valid,
718
- max_positions=utils.resolve_max_positions(
719
- self.task.max_positions(),
720
- self.model.max_positions(),
721
- ),
722
- ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
723
- required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
724
- seed=self.cfg.common.seed,
725
- num_shards=self.data_parallel_world_size,
726
- shard_id=self.data_parallel_rank,
727
- num_workers=self.cfg.dataset.num_workers,
728
- # always pass a fixed "epoch" to keep validation data consistent
729
- # across training epochs
730
- epoch=1,
731
- data_buffer_size=self.cfg.dataset.data_buffer_size,
732
- disable_iterator_cache=disable_iterator_cache,
733
- )
734
- self.reset_dummy_batch(batch_iterator.first_batch)
735
- batch_iterator.dataset.dataset._seek()
736
- return batch_iterator
737
-
738
- def begin_epoch(self, epoch):
739
- """Called at the beginning of each epoch."""
740
- logger.info("begin training epoch {}".format(epoch))
741
-
742
- self.lr_step_begin_epoch(epoch)
743
-
744
- if self.quantizer is not None:
745
- self.quantizer.begin_epoch(epoch)
746
-
747
- # task specific setup per epoch
748
- self.task.begin_epoch(epoch, self.get_model())
749
-
750
- if self.tpu:
751
- import torch_xla.core.xla_model as xm
752
-
753
- xm.rendezvous("begin_epoch") # wait for all workers
754
- xm.mark_step()
755
-
756
- def begin_valid_epoch(self, epoch):
757
- """Called at the beginning of each validation epoch."""
758
-
759
- # task specific setup per validation epoch
760
- self.task.begin_valid_epoch(epoch, self.get_model())
761
-
762
- def reset_dummy_batch(self, batch):
763
- self._dummy_batch = batch
764
-
765
- @metrics.aggregate("train")
766
- def train_step(self, samples, raise_oom=False):
767
- """Do forward, backward and parameter update."""
768
- self._set_seed()
769
- self.model.train()
770
- self.criterion.train()
771
- self.zero_grad()
772
-
773
- metrics.log_start_time("train_wall", priority=800, round=0)
774
-
775
- # If EMA is enabled through store_ema=True
776
- # and task.uses_ema is True, pass the EMA model as a keyword
777
- # argument to the task.
778
- extra_kwargs = {}
779
- if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
780
- extra_kwargs["ema_model"] = self.ema.get_model()
781
-
782
- # forward and backward pass
783
- logging_outputs, sample_size, ooms = [], 0, 0
784
- for i, sample in enumerate(samples): # delayed update loop
785
- sample, is_dummy_batch = self._prepare_sample(sample)
786
-
787
- def maybe_no_sync():
788
- """
789
- Whenever *samples* contains more than one mini-batch, we
790
- want to accumulate gradients locally and only call
791
- all-reduce in the last backwards pass.
792
- """
793
- if (
794
- self.data_parallel_world_size > 1
795
- and hasattr(self.model, "no_sync")
796
- and i < len(samples) - 1
797
- # The no_sync context manager results in increased memory
798
- # usage with FSDP, since full-size gradients will be
799
- # accumulated on each GPU. It's typically a better tradeoff
800
- # to do the extra communication with FSDP.
801
- and not self.is_fsdp
802
- ):
803
- return self.model.no_sync()
804
- else:
805
- return contextlib.ExitStack() # dummy contextmanager
806
-
807
- try:
808
- with maybe_no_sync():
809
- # forward and backward
810
- loss, sample_size_i, logging_output = self.task.train_step(
811
- sample=sample,
812
- model=self.model,
813
- criterion=self.criterion,
814
- optimizer=self.optimizer,
815
- update_num=self.get_num_updates(),
816
- ignore_grad=is_dummy_batch,
817
- **extra_kwargs,
818
- )
819
- del loss
820
-
821
- logging_outputs.append(logging_output)
822
- sample_size += sample_size_i
823
-
824
- # emptying the CUDA cache after the first step can
825
- # reduce the chance of OOM
826
- if self.cuda and self.get_num_updates() == 0:
827
- torch.cuda.empty_cache()
828
- except RuntimeError as e:
829
- if "out of memory" in str(e):
830
- self._log_oom(e)
831
- if raise_oom:
832
- raise e
833
- logger.warning(
834
- "attempting to recover from OOM in forward/backward pass"
835
- )
836
- ooms += 1
837
- self.zero_grad()
838
- if self.cuda:
839
- torch.cuda.empty_cache()
840
- if self.cfg.distributed_training.distributed_world_size == 1:
841
- return None
842
- else:
843
- raise e
844
-
845
- if self.tpu and i < len(samples) - 1:
846
- # tpu-comment: every XLA operation before marking step is
847
- # appended to the IR graph, and processing too many batches
848
- # before marking step can lead to OOM errors.
849
- # To handle gradient accumulation use case, we explicitly
850
- # mark step here for every forward pass without a backward pass
851
- self._xla_markstep_and_send_to_cpu()
852
-
853
- if is_dummy_batch:
854
- if torch.is_tensor(sample_size):
855
- sample_size.zero_()
856
- else:
857
- sample_size *= 0.0
858
-
859
- if torch.is_tensor(sample_size):
860
- sample_size = sample_size.float()
861
- else:
862
- sample_size = float(sample_size)
863
-
864
- # gather logging outputs from all replicas
865
- if self._sync_stats():
866
- train_time = self._local_cumulative_training_time()
867
- logging_outputs, (
868
- sample_size,
869
- ooms,
870
- total_train_time,
871
- ) = self._aggregate_logging_outputs(
872
- logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
873
- )
874
- self._cumulative_training_time = (
875
- total_train_time / self.data_parallel_world_size
876
- )
877
-
878
- overflow = False
879
- try:
880
- with torch.autograd.profiler.record_function("reduce-grads"):
881
- # reduce gradients across workers
882
- self.optimizer.all_reduce_grads(self.model)
883
- if utils.has_parameters(self.criterion):
884
- self.optimizer.all_reduce_grads(self.criterion)
885
-
886
- with torch.autograd.profiler.record_function("multiply-grads"):
887
- # multiply gradients by (data_parallel_size / sample_size) since
888
- # DDP normalizes by the number of data parallel workers for
889
- # improved fp16 precision.
890
- # Thus we get (sum_of_gradients / sample_size) at the end.
891
- # In case of fp16, this step also undoes loss scaling.
892
- # (Debugging note: Some optimizers perform this scaling on the
893
- # fly, so inspecting model.parameters() or optimizer.params may
894
- # still show the original, unscaled gradients.)
895
- numer = (
896
- self.data_parallel_world_size
897
- if not self.cfg.optimization.use_bmuf or self._sync_stats()
898
- else 1
899
- )
900
- self.optimizer.multiply_grads(numer / (sample_size or 1.0))
901
- # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
902
- # way that avoids CPU/device transfers in case sample_size is a GPU or
903
- # TPU object. The assumption is that the gradient itself is also 0.
904
-
905
- with torch.autograd.profiler.record_function("clip-grads"):
906
- # clip grads
907
- grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
908
-
909
- # check that grad norms are consistent across workers
910
- # on tpu check tensor is slow
911
- if not self.tpu:
912
- if (
913
- not self.cfg.optimization.use_bmuf
914
- and self.cfg.distributed_training.ddp_backend != "slow_mo"
915
- ):
916
- self._check_grad_norms(grad_norm)
917
- if not torch.isfinite(grad_norm).all():
918
- # in case of AMP, if gradients are Nan/Inf then
919
- # optimizer step is still required
920
- if self.cfg.common.amp:
921
- overflow = True
922
- else:
923
- # check local gradnorm single GPU case, trigger NanDetector
924
- raise FloatingPointError("gradients are Nan/Inf")
925
-
926
- with torch.autograd.profiler.record_function("optimizer"):
927
- # take an optimization step
928
- self.task.optimizer_step(
929
- self.optimizer, model=self.model, update_num=self.get_num_updates()
930
- )
931
- if self.cfg.common.amp and overflow:
932
- if self._amp_retries == self.cfg.common.amp_batch_retries:
933
- logger.info("AMP: skipping this batch.")
934
- self._amp_retries = 0
935
- else:
936
- self._amp_retries += 1
937
- return self.train_step(samples, raise_oom) # recursion to feed in same batch
938
-
939
- except FloatingPointError:
940
- # re-run the forward and backward pass with hooks attached to print
941
- # out where it fails
942
- self.zero_grad()
943
- with NanDetector(self.get_model()):
944
- for _, sample in enumerate(samples):
945
- sample, _ = self._prepare_sample(sample)
946
- self.task.train_step(
947
- sample,
948
- self.model,
949
- self.criterion,
950
- self.optimizer,
951
- self.get_num_updates(),
952
- ignore_grad=False,
953
- **extra_kwargs,
954
- )
955
- raise
956
- except OverflowError as e:
957
- overflow = True
958
- logger.info(
959
- f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
960
- )
961
- grad_norm = torch.tensor(0.0).cuda()
962
- self.zero_grad()
963
- except RuntimeError as e:
964
- if "out of memory" in str(e):
965
- self._log_oom(e)
966
- logger.error("OOM during optimization, irrecoverable")
967
- raise e
968
-
969
- # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
970
- # after the step
971
- if hasattr(self.model, "perform_additional_optimizer_actions"):
972
- if hasattr(self.optimizer, "fp32_params"):
973
- self.model.perform_additional_optimizer_actions(
974
- self.optimizer.optimizer, self.optimizer.fp32_params
975
- )
976
- else:
977
- self.model.perform_additional_optimizer_actions(
978
- self.optimizer.optimizer
979
- )
980
-
981
- logging_output = None
982
- if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
983
- self.set_num_updates(self.get_num_updates() + 1)
984
-
985
- if self.cfg.ema.store_ema:
986
- # Step EMA forward with new model.
987
- self.ema.step(
988
- self.get_model(),
989
- self.get_num_updates(),
990
- )
991
- metrics.log_scalar(
992
- "ema_decay",
993
- self.ema.get_decay(),
994
- priority=10000,
995
- round=5,
996
- weight=0,
997
- )
998
-
999
- if self.tpu:
1000
- import torch_xla.core.xla_model as xm
1001
-
1002
- # mark step on TPUs
1003
- self._xla_markstep_and_send_to_cpu()
1004
-
1005
- # only log stats every log_interval steps
1006
- # this causes wps to be misreported when log_interval > 1
1007
- logging_output = {}
1008
- if self.get_num_updates() % self.cfg.common.log_interval == 0:
1009
- # log memory usage
1010
- mem_info = xm.get_memory_info(self.device)
1011
- gb_free = mem_info["kb_free"] / 1024 / 1024
1012
- gb_total = mem_info["kb_total"] / 1024 / 1024
1013
- metrics.log_scalar(
1014
- "gb_free", gb_free, priority=1500, round=1, weight=0
1015
- )
1016
- metrics.log_scalar(
1017
- "gb_total", gb_total, priority=1600, round=1, weight=0
1018
- )
1019
- logging_outputs = self._xla_markstep_and_send_to_cpu(
1020
- logging_outputs
1021
- )
1022
- logging_output = self._reduce_and_log_stats(
1023
- logging_outputs, sample_size, grad_norm
1024
- )
1025
-
1026
- # log whenever there's an XLA compilation, since these
1027
- # slow down training and may indicate opportunities for
1028
- # optimization
1029
- self._check_xla_compilation()
1030
- else:
1031
- if self.cuda and self.cuda_env is not None:
1032
- # log minimum free memory over the iteration
1033
- gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
1034
- torch.cuda.reset_peak_memory_stats()
1035
- gb_free = self.cuda_env.total_memory_in_GB - gb_used
1036
- metrics.log_scalar(
1037
- "gb_free", gb_free, priority=1500, round=1, weight=0
1038
- )
1039
-
1040
- # log stats
1041
- logging_output = self._reduce_and_log_stats(
1042
- logging_outputs, sample_size, grad_norm
1043
- )
1044
-
1045
- # clear CUDA cache to reduce memory fragmentation
1046
- if (
1047
- self.cuda
1048
- and self.cfg.common.empty_cache_freq > 0
1049
- and (
1050
- (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
1051
- % self.cfg.common.empty_cache_freq
1052
- )
1053
- == 0
1054
- ):
1055
- torch.cuda.empty_cache()
1056
-
1057
- if self.cfg.common.fp16 or self.cfg.common.amp:
1058
- metrics.log_scalar(
1059
- "loss_scale",
1060
- (
1061
- self.optimizer.scaler.loss_scale
1062
- if self.cfg.common.fp16
1063
- else self.optimizer.scaler.get_scale()
1064
- ),
1065
- priority=700,
1066
- round=4,
1067
- weight=0,
1068
- )
1069
-
1070
- metrics.log_stop_time("train_wall")
1071
- return logging_output
1072
-
1073
- @metrics.aggregate("valid")
1074
- def valid_step(self, sample, raise_oom=False):
1075
- """Do forward pass in evaluation mode."""
1076
- if self.tpu:
1077
- import torch_xla.core.xla_model as xm
1078
-
1079
- xm.rendezvous("valid_step") # wait for all workers
1080
-
1081
- # If EMA is enabled through store_ema=True
1082
- # and task.uses_ema is True, pass the EMA model as a keyword
1083
- # argument to the task.
1084
- extra_kwargs = {}
1085
- if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
1086
- extra_kwargs["ema_model"] = self.ema.get_model()
1087
-
1088
- with torch.no_grad():
1089
- self.model.eval()
1090
- self.criterion.eval()
1091
-
1092
- sample, is_dummy_batch = self._prepare_sample(sample)
1093
-
1094
- try:
1095
- _loss, sample_size, logging_output = self.task.valid_step(
1096
- sample, self.model, self.criterion, **extra_kwargs
1097
- )
1098
- except RuntimeError as e:
1099
- if "out of memory" in str(e):
1100
- self._log_oom(e)
1101
- if not raise_oom:
1102
- logger.warning(
1103
- "ran out of memory in validation step, retrying batch"
1104
- )
1105
- for p in self.model.parameters():
1106
- if p.grad is not None:
1107
- p.grad = None # free some memory
1108
- if self.cuda:
1109
- torch.cuda.empty_cache()
1110
- return self.valid_step(sample, raise_oom=True)
1111
- raise e
1112
-
1113
-
1114
- logging_outputs = [logging_output]
1115
- if is_dummy_batch:
1116
- if torch.is_tensor(sample_size):
1117
- sample_size.zero_()
1118
- else:
1119
- sample_size *= 0.0
1120
-
1121
- # gather logging outputs from all replicas
1122
- if self.data_parallel_world_size > 1:
1123
- logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
1124
- logging_outputs,
1125
- sample_size,
1126
- ignore=is_dummy_batch,
1127
- )
1128
-
1129
- # log validation stats
1130
- if self.tpu:
1131
- logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
1132
- logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
1133
-
1134
- return logging_output
1135
-
1136
- def zero_grad(self):
1137
- self.optimizer.zero_grad()
1138
-
1139
- def lr_step_begin_epoch(self, epoch):
1140
- """Adjust the learning rate at the beginning of the epoch."""
1141
- self.lr_scheduler.step_begin_epoch(epoch)
1142
- # prefer updating the LR based on the number of steps
1143
- return self.lr_step_update()
1144
-
1145
- def lr_reinit(self, total_updates, num_updates):
1146
- self.lr_scheduler.reinit(total_updates, num_updates)
1147
-
1148
- def lr_step(self, epoch, val_loss=None):
1149
- """Adjust the learning rate at the end of the epoch."""
1150
- self.lr_scheduler.step(epoch, val_loss)
1151
- # prefer updating the LR based on the number of steps
1152
- return self.lr_step_update()
1153
-
1154
- def lr_step_update(self):
1155
- """Update the learning rate after each update."""
1156
- new_lr = self.lr_scheduler.step_update(self.get_num_updates())
1157
- if isinstance(new_lr, dict):
1158
- for k, v in new_lr.items():
1159
- metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
1160
- new_lr = new_lr.get("default", next(iter(new_lr.values())))
1161
- else:
1162
- metrics.log_scalar("lr", new_lr, weight=0, priority=300)
1163
- return new_lr
1164
-
1165
- def get_lr(self):
1166
- """Get the current learning rate."""
1167
- return self.optimizer.get_lr()
1168
-
1169
- def get_model(self):
1170
- """Get the (non-wrapped) model instance."""
1171
- return self._model
1172
-
1173
- def get_criterion(self):
1174
- """Get the (non-wrapped) criterion instance."""
1175
- return self._criterion
1176
-
1177
- def get_meter(self, name):
1178
- """[deprecated] Get a specific meter by name."""
1179
- from fairseq import meters
1180
-
1181
- if "get_meter" not in self._warn_once:
1182
- self._warn_once.add("get_meter")
1183
- utils.deprecation_warning(
1184
- "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
1185
- )
1186
-
1187
- train_meters = metrics.get_meters("train")
1188
- if train_meters is None:
1189
- train_meters = {}
1190
-
1191
- if name == "train_loss" and "loss" in train_meters:
1192
- return train_meters["loss"]
1193
- elif name == "train_nll_loss":
1194
- # support for legacy train.py, which assumed this meter is
1195
- # always initialized
1196
- m = train_meters.get("nll_loss", None)
1197
- return m or meters.AverageMeter()
1198
- elif name == "wall":
1199
- # support for legacy train.py, which assumed this meter is
1200
- # always initialized
1201
- m = metrics.get_meter("default", "wall")
1202
- return m or meters.TimeMeter()
1203
- elif name == "wps":
1204
- m = metrics.get_meter("train", "wps")
1205
- return m or meters.TimeMeter()
1206
- elif name in {"valid_loss", "valid_nll_loss"}:
1207
- # support for legacy train.py, which assumed these meters
1208
- # are always initialized
1209
- k = name[len("valid_") :]
1210
- m = metrics.get_meter("valid", k)
1211
- return m or meters.AverageMeter()
1212
- elif name == "oom":
1213
- return meters.AverageMeter()
1214
- elif name in train_meters:
1215
- return train_meters[name]
1216
- return None
1217
-
1218
- def get_num_updates(self):
1219
- """Get the number of parameters updates."""
1220
- return self._num_updates
1221
-
1222
- def set_num_updates(self, num_updates):
1223
- """Set the number of parameters updates."""
1224
- self._num_updates = num_updates
1225
- self.lr_step_update()
1226
- if self.quantizer:
1227
- self.quantizer.step_update(self._num_updates)
1228
- metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
1229
-
1230
- def clip_grad_norm(self, clip_norm):
1231
- def agg_norm_fn(total_norm):
1232
- total_norm = total_norm.cuda().float() ** 2
1233
- total_norm = distributed_utils.all_reduce(
1234
- total_norm, group=self.data_parallel_process_group
1235
- )
1236
- return total_norm ** 0.5
1237
-
1238
- should_agg_norm = (
1239
- self.is_fsdp
1240
- and (
1241
- self.data_parallel_process_group is not None
1242
- or torch.distributed.is_initialized()
1243
- )
1244
- )
1245
- return self.optimizer.clip_grad_norm(
1246
- clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None
1247
- )
1248
-
1249
- def cumulative_training_time(self):
1250
- if self._cumulative_training_time is None:
1251
- # single GPU
1252
- return self._local_cumulative_training_time()
1253
- else:
1254
- return self._cumulative_training_time
1255
-
1256
- def _local_cumulative_training_time(self):
1257
- """Aggregate training time in seconds."""
1258
- return time.time() - self._start_time + self._previous_training_time
1259
-
1260
- def _fp_convert_sample(self, sample):
1261
- def apply_half(t):
1262
- if t.dtype is torch.float32:
1263
- return t.to(dtype=torch.half)
1264
- return t
1265
-
1266
- def apply_bfloat16(t):
1267
- if t.dtype is torch.float32:
1268
- return t.to(dtype=torch.bfloat16)
1269
- return t
1270
-
1271
- if self.cfg.common.fp16:
1272
- sample = utils.apply_to_sample(apply_half, sample)
1273
-
1274
- if self.cfg.common.bf16:
1275
- sample = utils.apply_to_sample(apply_bfloat16, sample)
1276
-
1277
- return sample
1278
-
1279
- def _prepare_sample(self, sample, is_dummy=False):
1280
- if sample == "DUMMY":
1281
- raise Exception(
1282
- "Trying to use an uninitialized 'dummy' batch. This usually indicates "
1283
- "that the total number of batches is smaller than the number of "
1284
- "participating GPUs. Try reducing the batch size or using fewer GPUs."
1285
- )
1286
-
1287
- if sample is None or len(sample) == 0:
1288
- assert (
1289
- self._dummy_batch is not None and len(self._dummy_batch) > 0
1290
- ), "Invalid dummy batch: {}".format(self._dummy_batch)
1291
- sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
1292
- return sample, True
1293
-
1294
- # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth
1295
- # it makes sense to do the format conversion on the CPU and then transfer
1296
- # a smaller buffer to the device. This also saves GPU memory capacity.
1297
-
1298
- if self.cfg.common.on_cpu_convert_precision:
1299
- sample = self._fp_convert_sample(sample)
1300
-
1301
- if self.cuda:
1302
- if self.pipeline_model_parallel:
1303
- if 'target' in sample:
1304
- sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
1305
- else:
1306
- sample = utils.move_to_cuda(sample)
1307
- elif self.tpu and is_dummy:
1308
- # the dummy batch may not be on the appropriate device
1309
- sample = utils.move_to_cuda(sample, device=self.device)
1310
-
1311
- if not self.cfg.common.on_cpu_convert_precision:
1312
- sample = self._fp_convert_sample(sample)
1313
-
1314
- if self._dummy_batch == "DUMMY":
1315
- self._dummy_batch = sample
1316
-
1317
- return sample, False
1318
-
1319
- def _set_seed(self):
1320
- # Set seed based on args.seed and the update number so that we get
1321
- # reproducible results when resuming from checkpoints
1322
- seed = self.cfg.common.seed + self.get_num_updates()
1323
- utils.set_torch_seed(seed)
1324
-
1325
- def _sync_stats(self):
1326
- # Return True if it's using multiple GPUs and DDP or multiple GPUs with
1327
- # BMUF and it's a bmuf sync with warmup iterations completed before.
1328
- if self.data_parallel_world_size == 1:
1329
- return False
1330
- elif self.cfg.optimization.use_bmuf:
1331
- return (
1332
- self.get_num_updates() + 1
1333
- ) % self.cfg.bmuf.global_sync_iter == 0 and (
1334
- self.get_num_updates() + 1
1335
- ) > self.cfg.bmuf.warmup_iterations
1336
- else:
1337
- return True
1338
-
1339
- def _log_oom(self, exc):
1340
- msg = "OOM: Ran out of memory with exception: {}".format(exc)
1341
- logger.warning(msg)
1342
- if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
1343
- for device_idx in range(torch.cuda.device_count()):
1344
- logger.warning(torch.cuda.memory_summary(device=device_idx))
1345
- sys.stderr.flush()
1346
-
1347
- def _aggregate_logging_outputs(
1348
- self,
1349
- logging_outputs: List[Dict[str, Any]],
1350
- *extra_stats_to_sum,
1351
- ignore=False,
1352
- ):
1353
- if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
1354
- return self._fast_stat_sync_sum(
1355
- logging_outputs, *extra_stats_to_sum, ignore=ignore
1356
- )
1357
- else:
1358
- return self._all_gather_list_sync(
1359
- logging_outputs, *extra_stats_to_sum, ignore=ignore
1360
- )
1361
-
1362
- def _all_gather_list_sync(
1363
- self,
1364
- logging_outputs: List[Dict[str, Any]],
1365
- *extra_stats_to_sum,
1366
- ignore=False,
1367
- ):
1368
- """
1369
- Sync logging outputs across workers. all_gather_list_sync is
1370
- suitable when logging outputs are complex types.
1371
- """
1372
- if self.tpu:
1373
- raise NotImplementedError
1374
- if ignore:
1375
- logging_outputs = []
1376
- results = list(
1377
- zip(
1378
- *distributed_utils.all_gather_list(
1379
- [logging_outputs] + list(extra_stats_to_sum),
1380
- max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
1381
- group=self.data_parallel_process_group,
1382
- )
1383
- )
1384
- )
1385
- logging_outputs, extra_stats_to_sum = results[0], results[1:]
1386
- logging_outputs = list(chain.from_iterable(logging_outputs))
1387
- extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
1388
- return logging_outputs, extra_stats_to_sum
1389
-
1390
- def _fast_stat_sync_sum(
1391
- self,
1392
- logging_outputs: List[Dict[str, Any]],
1393
- *extra_stats_to_sum,
1394
- ignore=False,
1395
- ):
1396
- """
1397
- Sync logging outputs across workers. fast_stat_sync_sum is
1398
- faster than all_gather_list_sync, but is only suitable when
1399
- logging outputs are scalars and can be summed. Note that
1400
- *logging_outputs* cannot contain any nested dicts/lists.
1401
- """
1402
- data = {}
1403
- for i, stat in enumerate(extra_stats_to_sum):
1404
- data["extra_stats_" + str(i)] = stat
1405
- if len(logging_outputs) > 0:
1406
- log_keys = list(logging_outputs[0].keys())
1407
- for k in log_keys:
1408
- if not ignore:
1409
- v = sum(log[k] for log in logging_outputs if k in log)
1410
- else:
1411
- v = logging_outputs[0][k]
1412
- v = torch.zeros_like(v) if torch.is_tensor(v) else 0
1413
- data["logging_outputs_" + k] = v
1414
- else:
1415
- log_keys = None
1416
-
1417
- data = distributed_utils.all_reduce_dict(
1418
- data, device=self.device, group=self.data_parallel_process_group
1419
- )
1420
-
1421
- extra_stats_to_sum = [
1422
- data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
1423
- ]
1424
- if log_keys is not None:
1425
- logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
1426
- else:
1427
- logging_outputs = []
1428
- return logging_outputs, extra_stats_to_sum
1429
-
1430
- def _check_grad_norms(self, grad_norm):
1431
- """Check that grad norms are consistent across workers."""
1432
- if self._grad_norm_buf is not None:
1433
- self._grad_norm_buf.zero_()
1434
- self._grad_norm_buf[self.data_parallel_rank] = grad_norm
1435
- distributed_utils.all_reduce(
1436
- self._grad_norm_buf, group=self.data_parallel_process_group
1437
- )
1438
-
1439
- def is_consistent(tensor):
1440
- max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
1441
- return (
1442
- (torch.isfinite(tensor).all()
1443
- and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
1444
- or
1445
- (self.cfg.common.amp and not torch.isfinite(tensor).all())
1446
- # in case of amp non-finite grads are fine
1447
- )
1448
-
1449
- if not is_consistent(self._grad_norm_buf):
1450
- pretty_detail = "\n".join(
1451
- "rank {:3d} = {:.8f}".format(r, n)
1452
- for r, n in enumerate(self._grad_norm_buf.tolist())
1453
- )
1454
- error_detail = "grad_norm across the workers:\n{}\n".format(
1455
- pretty_detail
1456
- )
1457
- # use FloatingPointError to trigger NanDetector
1458
- raise FloatingPointError(
1459
- "Fatal error: gradients are inconsistent between workers. "
1460
- "Try --ddp-backend=legacy_ddp. "
1461
- "Or are you mixing up different generation of GPUs in training?"
1462
- + "\n"
1463
- + "-" * 80
1464
- + "\n{}\n".format(error_detail)
1465
- + "-" * 80
1466
- )
1467
-
1468
- def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
1469
- if grad_norm is not None and (
1470
- not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
1471
- ):
1472
- metrics.log_speed("ups", 1.0, priority=100, round=2)
1473
- metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
1474
- if self.cfg.optimization.clip_norm > 0:
1475
- metrics.log_scalar(
1476
- "clip",
1477
- torch.where(
1478
- grad_norm > self.cfg.optimization.clip_norm,
1479
- grad_norm.new_tensor(100),
1480
- grad_norm.new_tensor(0),
1481
- ),
1482
- priority=500,
1483
- round=1,
1484
- )
1485
-
1486
- with metrics.aggregate() as agg:
1487
- if logging_outputs is not None:
1488
- self.task.reduce_metrics(logging_outputs, self.get_criterion())
1489
- del logging_outputs
1490
-
1491
- # extra warning for criterions that don't properly log a loss value
1492
- if "loss" not in agg:
1493
- if "loss" not in self._warn_once:
1494
- self._warn_once.add("loss")
1495
- logger.warning(
1496
- "Criterion.reduce_metrics did not log a 'loss' value, "
1497
- "which may break some functionality"
1498
- )
1499
- metrics.log_scalar("loss", -1)
1500
-
1501
- # support legacy interface
1502
- if self.tpu:
1503
- logging_output = {}
1504
- else:
1505
- logging_output = agg.get_smoothed_values()
1506
- logging_output["sample_size"] = sample_size
1507
- for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
1508
- if key_to_delete in logging_output:
1509
- del logging_output[key_to_delete]
1510
- return logging_output
1511
-
1512
- def _check_xla_compilation(self):
1513
- import torch_xla.debug.metrics as met
1514
-
1515
- compile_stats = met.metric_data("CompileTime")
1516
- if compile_stats is None:
1517
- return
1518
- num_xla_compiles = compile_stats[0]
1519
- if num_xla_compiles > self._num_xla_compiles:
1520
- logger.warning(
1521
- "XLA compilation detected on device #{}; too many of these can lead "
1522
- "to slow training, but we expect a few in the beginning".format(
1523
- self.cfg.distributed_training.distributed_rank
1524
- )
1525
- )
1526
- self._num_xla_compiles = num_xla_compiles
1527
-
1528
- def _xla_markstep_and_send_to_cpu(self, data=None):
1529
- import torch_xla.core.xla_model as xm
1530
-
1531
- xm.mark_step()
1532
- if data is not None:
1533
- from fairseq.utils import xla_device_to_cpu
1534
-
1535
- return xla_device_to_cpu(data)
1536
-
1537
-
1538
- def _catalog_shared_params(module, memo=None, prefix=""):
1539
- if memo is None:
1540
- first_call = True
1541
- memo = {}
1542
- else:
1543
- first_call = False
1544
- for name, param in module._parameters.items():
1545
- param_prefix = prefix + ("." if prefix else "") + name
1546
- if param not in memo:
1547
- memo[param] = []
1548
- memo[param].append(param_prefix)
1549
- for name, m in module._modules.items():
1550
- if m is None:
1551
- continue
1552
- submodule_prefix = prefix + ("." if prefix else "") + name
1553
- _catalog_shared_params(m, memo, submodule_prefix)
1554
- if first_call:
1555
- return [x for x in memo.values() if len(x) > 1]
1556
-
1557
-
1558
- def _get_module_by_path(module, path):
1559
- path = path.split(".")
1560
- for name in path:
1561
- module = getattr(module, name)
1562
- return module
1563
-
1564
-
1565
- def _set_module_by_path(module, path, value):
1566
- path = path.split(".")
1567
- for name in path[:-1]:
1568
- module = getattr(module, name)
1569
- setattr(module, path[-1], value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transformers.md DELETED
@@ -1,69 +0,0 @@
1
- # Use in huggingface transformers (Beta)
2
-
3
- [**Colab Notebook**](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
4
- ![image](https://user-images.githubusercontent.com/27664428/190052470-56679999-571b-4d46-a9a8-e567b78e20d1.png)
5
-
6
-
7
- We now support inference of OFA on the huggingface transformers. In the near future, we will provide the codes for training.
8
-
9
- Model checkpoints are stored in our [huggingface models](https://huggingface.co/OFA-Sys). Specifically, 5 versions of the pretrained OFA models, namely OFA-tiny, OFA-medium, OFA-base, OFA-large, and OFA-huge have been already uploaded. For more information about the models, please refer to the Model Card on our [README](https://github.com/OFA-Sys/OFA).
10
- Note that each directory includes 4 files, namely `config.json` which consists of model configuration, `vocab.json` and `merge.txt` for our OFA tokenizer, and lastly `pytorch_model.bin` which consists of model weights. There is no need to worry about the mismatch between Fairseq and transformers, since we have addressed the issue yet.
11
-
12
- To use it in transformers, you can first refer to this notebook ([link](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)). For more information, you can find codes in this branch https://github.com/OFA-Sys/OFA/tree/feature/add_transformers.
13
-
14
- In the following, we introduce the details in our provided notebook and illustrate how to use OFA in Transformers.
15
-
16
- First, install the transformers and download the models (take OFA-tiny as an example) as shown below.
17
-
18
- ```
19
- git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git
20
- pip install OFA/transformers/
21
- git clone https://huggingface.co/OFA-Sys/OFA-tiny
22
- ```
23
-
24
- Next, refer the path to OFA-tiny to `ckpt_dir`, and prepare an image for the testing example below. Also, ensure that you have pillow and torchvision in your environment. Check if there is the directory `generate` in your model directory `transformers/src/transformers/models/ofa` to ensure that you can use the sequence generator that we provide.
25
-
26
- ```
27
- >>> from PIL import Image
28
- >>> from torchvision import transforms
29
- >>> from transformers import OFATokenizer, OFAModel
30
- >>> from transformers.models.ofa.generate import sequence_generator
31
-
32
- >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
33
- >>> resolution = 256
34
- >>> patch_resize_transform = transforms.Compose([
35
- lambda image: image.convert("RGB"),
36
- transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
37
- transforms.ToTensor(),
38
- transforms.Normalize(mean=mean, std=std)
39
- ])
40
-
41
-
42
- >>> tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
43
-
44
- >>> txt = " what does the image describe?"
45
- >>> inputs = tokenizer([txt], return_tensors="pt").input_ids
46
- >>> img = Image.open(path_to_image)
47
- >>> patch_img = patch_resize_transform(img).unsqueeze(0)
48
-
49
-
50
- >>> # using the generator of fairseq version
51
- >>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
52
- >>> generator = sequence_generator.SequenceGenerator(
53
- tokenizer=tokenizer,
54
- beam_size=5,
55
- max_len_b=16,
56
- min_len=0,
57
- no_repeat_ngram_size=3,
58
- )
59
- >>> data = {}
60
- >>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
61
- >>> gen_output = generator.generate([model], data)
62
- >>> gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))]
63
-
64
- >>> # using the generator of huggingface version
65
- >>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
66
- >>> gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3)
67
-
68
- >>> print(tokenizer.batch_decode(gen, skip_special_tokens=True))
69
- ```