tsujuifu committed
Commit bdbb79e · 0 Parent(s)

archive v1
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
1
+ ---
2
+ title: MLLM-guided Image Editing (MGIE)
3
+ emoji: 👩‍🎨
4
+ colorFrom: blue
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.37.0
8
+ app_file: app.py
9
+ license: other
10
+ ---
_input/0.jpg ADDED
_input/1.jpg ADDED
_input/10.jpg ADDED
_input/11.jpg ADDED
_input/12.jpg ADDED
_input/13.jpg ADDED
_input/14.jpg ADDED
_input/15.jpg ADDED
_input/16.jpg ADDED
_input/17.jpg ADDED
_input/18.jpg ADDED
_input/19.jpg ADDED
_input/2.jpg ADDED
_input/3.jpg ADDED
_input/4.jpg ADDED
_input/5.jpg ADDED
_input/6.jpg ADDED
_input/7.jpg ADDED
_input/8.jpg ADDED
_input/9.jpg ADDED
app.py ADDED
@@ -0,0 +1,144 @@
1
+
2
+ import os
3
+ # os.system('cp -r ./_ckpt/LLaVA-7B-v1 /data/LLaVA-7B-v1'), os.system('cp -r ./_ckpt/mgie_7b /data/mgie_7b')
4
+ os.system('ls /data'), os.system('df -h /data')
5
+ [os.system('mv llava.py /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/llava/model/llava.py'),
6
+ os.system('mv train.py /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/llava/train/train.py')]
7
+
8
+ from PIL import Image
9
+
10
+ import numpy as np
11
+ import torch as T
12
+ import transformers, diffusers
13
+
14
+ from llava.conversation import conv_templates
15
+ from llava.model import *
16
+
17
+ import gradio as gr
18
+
19
+ def crop_resize(f, sz=512):
20
+ w, h = f.size
21
+ if w>h:
22
+ p = (w-h)//2
23
+ f = f.crop([p, 0, p+h, h])
24
+ elif h>w:
25
+ p = (h-w)//2
26
+ f = f.crop([0, p, w, p+w])
27
+ f = f.resize([sz, sz])
28
+ return f
29
+ def remove_alter(s): # hack expressive instruction
30
+ if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:')+10:].strip()
31
+ if '</s>' in s: s = s[:s.index('</s>')].strip()
32
+ if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]
33
+ if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
34
+ s = '.'.join([s.strip() for s in s.split('.')[:2]])
35
+ if s[-1]!='.': s += '.'
36
+ return s.strip()
37
+
38
+ DEFAULT_IMAGE_TOKEN = '<image>'
39
+ DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
40
+ DEFAULT_IM_START_TOKEN = '<im_start>'
41
+ DEFAULT_IM_END_TOKEN = '<im_end>'
42
+ PATH_LLAVA = '/data/LLaVA-7B-v1'
43
+
44
+ tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
45
+ model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
46
+ image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)
47
+
48
+ tokenizer.padding_side = 'left'
49
+ tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
50
+ model.resize_token_embeddings(len(tokenizer))
51
+ ckpt = T.load('/data/mgie_7b/mllm.pt', map_location='cpu')
52
+ model.load_state_dict(ckpt, strict=False)
53
+
54
+ mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
55
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
56
+ if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
57
+
58
+ vision_tower = model.get_model().vision_tower[0]
59
+ vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
60
+ model.get_model().vision_tower[0] = vision_tower
61
+ vision_config = vision_tower.config
62
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
63
+ vision_config.use_im_start_end = mm_use_im_start_end
64
+ if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
65
+ image_token_len = (vision_config.image_size//vision_config.patch_size)**2
66
+
67
+ _ = model.eval()
68
+ EMB = ckpt['emb'].cuda()
69
+ with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
70
+
71
+ pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
72
+ pipe.set_progress_bar_config(disable=True)
73
+ pipe.unet.load_state_dict(T.load('/data/mgie_7b/unet.pt', map_location='cpu'))
74
+ print('--init MGIE--')
75
+
76
+ def go_mgie(img, txt, seed, cfg_txt, cfg_img):
77
+ img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
78
+ inp = img
79
+
80
+ img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
81
+ txt = "what will this image be like if '%s'"%(txt)
82
+ txt = txt+'\n'+DEFAULT_IM_START_TOKEN+DEFAULT_IMAGE_PATCH_TOKEN*image_token_len+DEFAULT_IM_END_TOKEN
83
+ conv = conv_templates['vicuna_v1_1'].copy()
84
+ conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
85
+ txt = conv.get_prompt()
86
+ txt = tokenizer(txt)
87
+ txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])
88
+
89
+ with T.inference_mode():
90
+ out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
91
+ do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
92
+ return_dict_in_generate=True, output_hidden_states=True)
93
+ out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]
94
+
95
+ if 32003 in out: p = out.index(32003)-1
96
+ else: p = len(hid)-9
97
+ p = min(p, len(hid)-9)
98
+ hid = hid[p:p+8]
99
+
100
+ out = remove_alter(tokenizer.decode(out))
101
+ emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
102
+ res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
103
+ generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
104
+
105
+ return res, out
106
+
107
+ def go_example(seed, cfg_txt, cfg_img):
108
+ txt = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
109
+ 'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
110
+ 'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
111
+ 'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
112
+ 'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
113
+ i = T.randint(len(txt), (1, )).item()
114
+
115
+ return './_input/%d.jpg'%(i), txt[i], seed, cfg_txt, cfg_img
116
+
117
+ go_mgie(np.array(Image.open('./_input/0.jpg').convert('RGB')), 'make the frame red', 13331, 7.5, 1.5)
118
+ print('--init GO--')
119
+
120
+ with gr.Blocks() as app:
121
+ gr.Markdown(
122
+ """
123
+ 🔔 we will have maintenance at 3 a.m. (PST)
124
+ # [ICLR\'24] Guiding Instruction-based Image Editing via Multimodal Large Language Models<br>
125
+ 🔔 this demo is hosted by [Tsu-Jui Fu](https://github.com/tsujuifu/pytorch_mgie)<br>
126
+ 🔔 a black image means that the output did not pass the [safety checker](https://huggingface.co/CompVis/stable-diffusion-safety-checker)<br>
127
+ 🔔 if the queue is full (*this app is too busy*), you can also try it [here](http://128.111.41.13:7122)<br>
128
+ 🔔 if the building process takes too long, please try refreshing the page
129
+ """
130
+ )
131
+ with gr.Row(): inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
132
+ gr.Image(height=384, width=384, label='Goal Image', interactive=False)]
133
+ with gr.Row(): txt, out = [gr.Textbox(label='Instruction', interactive=True),
134
+ gr.Textbox(label='Expressive Instruction', interactive=False)]
135
+ with gr.Row(): seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
136
+ gr.Number(value=7.5, label='Text CFG', interactive=True),
137
+ gr.Number(value=1.5, label='Image CFG', interactive=True)]
138
+ with gr.Row(): btn_sub, btn_exp = [gr.Button('Submit'),
139
+ gr.Button('Example')]
140
+
141
+ btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
142
+ btn_exp.click(fn=go_example, inputs=[seed, cfg_txt, cfg_img], outputs=[inp, txt, seed, cfg_txt, cfg_img])
143
+
144
+ app.queue(concurrency_count=1, max_size=75), app.launch()
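For reference, a minimal usage sketch of the inference entry point above. It assumes it runs inside app.py after the `--init MGIE--` initialization (so the tokenizer, MLLM, EMB/NULL embeddings, and the InstructPix2Pix pipeline already exist) and that the /data checkpoints are mounted; the output path edited.jpg is just illustrative.

```python
# Hedged sketch, not part of the commit: invoke go_mgie() directly once app.py has loaded.
import numpy as np
from PIL import Image

img = np.array(Image.open('./_input/0.jpg').convert('RGB'))   # any RGB image; go_mgie crop-resizes it to 512x512
res, expressive = go_mgie(img, 'make the frame red', 13331, 7.5, 1.5)
res.save('edited.jpg')   # res is the edited PIL image returned by the diffusion pipeline
print(expressive)        # the MLLM-derived expressive instruction shown in the UI
```

The Gradio Submit button wires this same function to the UI with the same (image, instruction, seed, text CFG, image CFG) arguments.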
llava.py ADDED
@@ -0,0 +1,404 @@
1
+
2
+ # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py
3
+
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import CrossEntropyLoss
10
+
11
+ from transformers import AutoConfig, AutoModelForCausalLM, \
12
+ LlamaConfig, LlamaModel, LlamaForCausalLM, \
13
+ CLIPVisionModel, CLIPImageProcessor
14
+
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ import os, diffusers
18
+
19
+ DEFAULT_IMAGE_TOKEN = "<image>"
20
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
21
+ DEFAULT_IM_START_TOKEN = "<im_start>"
22
+ DEFAULT_IM_END_TOKEN = "<im_end>"
23
+
24
+
25
+ class LlavaConfig(LlamaConfig):
26
+ model_type = "llava"
27
+
28
+
29
+ class LlavaLlamaModel(LlamaModel):
30
+ config_class = LlavaConfig
31
+
32
+ def __init__(self, config: LlamaConfig):
33
+ super(LlavaLlamaModel, self).__init__(config)
34
+
35
+ if hasattr(config, "mm_vision_tower"):
36
+ # HACK: for FSDP
37
+ self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
38
+ # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
39
+
40
+ if hasattr(config, "use_mm_proj"):
41
+ self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
42
+
43
+ def get_vision_tower(self):
44
+ vision_tower = getattr(self, 'vision_tower', None)
45
+ if type(vision_tower) is list:
46
+ vision_tower = vision_tower[0]
47
+ return vision_tower
48
+
49
+ def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
50
+ pretrain_mm_mlp_adapter=None, fsdp=None):
51
+ self.config.mm_vision_tower = vision_tower
52
+
53
+ image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
54
+
55
+ if not hasattr(self, 'vision_tower'):
56
+ vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
57
+ else:
58
+ vision_tower = self.vision_tower[0]
59
+ vision_tower.requires_grad_(False)
60
+
61
+ if fsdp is not None and len(fsdp) > 0:
62
+ self.vision_tower = [vision_tower]
63
+ else:
64
+ self.vision_tower = vision_tower
65
+
66
+ vision_config = vision_tower.config
67
+ num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
68
+
69
+ self.config.use_mm_proj = True
70
+ self.config.mm_hidden_size = vision_config.hidden_size
71
+ self.config.mm_vision_select_layer = mm_vision_select_layer
72
+
73
+ if not hasattr(self, 'mm_projector'):
74
+ self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
75
+
76
+ if pretrain_mm_mlp_adapter is not None:
77
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
78
+ self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
79
+
80
+ return dict(
81
+ image_processor=image_processor,
82
+ image_token_len=num_patches,
83
+ vision_config=vision_config
84
+ )
85
+
86
+ def forward(
87
+ self,
88
+ input_ids: torch.LongTensor = None,
89
+ attention_mask: Optional[torch.Tensor] = None,
90
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
91
+ inputs_embeds: Optional[torch.FloatTensor] = None,
92
+ use_cache: Optional[bool] = None,
93
+ output_attentions: Optional[bool] = None,
94
+ output_hidden_states: Optional[bool] = None,
95
+ images: Optional[torch.FloatTensor] = None,
96
+ return_dict: Optional[bool] = None,
97
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
98
+
99
+ # HACK: replace back original embeddings for LLaVA pretraining
100
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
101
+ # if orig_embeds_params is not None:
102
+ # orig_embeds_params = orig_embeds_params[0]
103
+ # with torch.no_grad():
104
+ # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
105
+
106
+ if inputs_embeds is None:
107
+ inputs_embeds = self.embed_tokens(input_ids)
108
+
109
+ vision_tower = self.get_vision_tower()
110
+ if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
111
+ # TODO: this is a modified multimodal LLM -- Haotian Liu
112
+ with torch.no_grad():
113
+ if type(images) is list:
114
+ # variable length images
115
+ image_features = []
116
+ for image in images:
117
+ image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
118
+ select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
119
+ select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
120
+ image_feature = select_hidden_state[:, 1:]
121
+ image_features.append(image_feature)
122
+ else:
123
+ image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
124
+ select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
125
+ select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
126
+ image_features = select_hidden_state[:, 1:].to(images.dtype)
127
+ if type(images) is list:
128
+ image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
129
+ else:
130
+ image_features = self.mm_projector(image_features)
131
+ dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
132
+ dummy_image_features = self.mm_projector(dummy_image_features)
133
+
134
+ new_input_embeds = []
135
+ cur_image_idx = 0
136
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
137
+ if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
138
+ # multimodal LLM, but the current sample is not multimodal
139
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
140
+ new_input_embeds.append(cur_input_embeds)
141
+ cur_image_idx += 1
142
+ continue
143
+ if vision_tower.config.use_im_start_end:
144
+ cur_image_features = image_features[cur_image_idx]
145
+ num_patches = cur_image_features.shape[0]
146
+ if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
147
+ raise ValueError("The number of image start tokens and image end tokens should be the same.")
148
+ image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
149
+ for image_start_token_pos in image_start_tokens:
150
+ cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
151
+ num_patches = cur_image_features.shape[0]
152
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
153
+ raise ValueError("The image end token should follow the image start token.")
154
+ if orig_embeds_params is not None:
155
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
156
+ else:
157
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
158
+ cur_image_idx += 1
159
+ new_input_embeds.append(cur_new_input_embeds)
160
+ else:
161
+ cur_image_features = image_features[cur_image_idx]
162
+ num_patches = cur_image_features.shape[0]
163
+ if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
164
+ raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
165
+ masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
166
+ mask_index_start = masked_indices[0]
167
+ if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
168
+ raise ValueError("The image patch tokens should be consecutive.")
169
+ if orig_embeds_params is not None:
170
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
171
+ else:
172
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
173
+ new_input_embeds.append(cur_new_input_embeds)
174
+ cur_image_idx += 1
175
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
176
+
177
+ return super(LlavaLlamaModel, self).forward(
178
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
179
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
180
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
181
+ return_dict=return_dict
182
+ )
183
+
184
+ class EditMapper(nn.Module):
185
+ def __init__(self):
186
+ super().__init__()
187
+
188
+ self.llm2hid = nn.Linear(4096, 512)
189
+ self.query = nn.Parameter(torch.randn(1, 77, 512))
190
+ self.mapper = nn.Transformer(batch_first=True, norm_first=True,
191
+ d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
192
+ dim_feedforward=2048, dropout=0.0)
193
+ self.hid2feat = nn.Linear(512, 768)
194
+
195
+ def forward(self, llm, emb):
196
+ hid = self.llm2hid(llm+emb)
197
+ hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
198
+ feat = self.hid2feat(hid)
199
+
200
+ return feat
201
+
202
+ class LlavaLlamaForCausalLM(LlamaForCausalLM):
203
+ config_class = LlavaConfig
204
+
205
+ def __init__(self, config):
206
+ super(LlamaForCausalLM, self).__init__(config)
207
+ self.model = LlavaLlamaModel(config)
208
+
209
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
210
+
211
+ self.edit_head = EditMapper()
212
+
213
+ '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
214
+ diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
215
+ diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
216
+ self.vae.requires_grad_(False)
217
+ self.unet.register_to_config(in_channels=8)
218
+ with torch.no_grad():
219
+ conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
220
+ conv.weight.zero_()
221
+ conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
222
+ self.unet.conv_in = conv'''
223
+
224
+ # Initialize weights and apply final processing
225
+ self.post_init()
226
+
227
+ def get_model(self):
228
+ return self.model
229
+
230
+ def get_vision_tower(self):
231
+ return self.get_model().get_vision_tower()
232
+
233
+ def get_vision_tower(self):
234
+ model = self.get_model()
235
+ vision_tower = model.vision_tower
236
+ if type(vision_tower) is list:
237
+ vision_tower = vision_tower[0]
238
+ return vision_tower
239
+
240
+ def forward(
241
+ self,
242
+ input_ids: torch.LongTensor = None,
243
+ attention_mask: Optional[torch.Tensor] = None,
244
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
245
+ inputs_embeds: Optional[torch.FloatTensor] = None,
246
+ labels: Optional[torch.LongTensor] = None,
247
+ use_cache: Optional[bool] = None,
248
+ output_attentions: Optional[bool] = None,
249
+ output_hidden_states: Optional[bool] = None,
250
+ images: Optional[torch.FloatTensor] = None,
251
+ return_dict: Optional[bool] = None,
252
+ p2p_inp=None, p2p_ans=None
253
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
254
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
255
+ output_hidden_states = (
256
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
257
+ )
258
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
259
+
260
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
261
+ outputs = self.model(
262
+ input_ids=input_ids,
263
+ attention_mask=attention_mask,
264
+ past_key_values=past_key_values,
265
+ inputs_embeds=inputs_embeds,
266
+ use_cache=use_cache,
267
+ output_attentions=output_attentions,
268
+ output_hidden_states=output_hidden_states,
269
+ return_dict=return_dict,
270
+ images=images
271
+ )
272
+
273
+ hidden_states = outputs[0]
274
+ logits = self.lm_head(hidden_states)
275
+
276
+ loss = None
277
+ if labels is not None:
278
+ # Shift so that tokens < n predict n
279
+ shift_logits = logits[..., :-1, :].contiguous()
280
+ shift_labels = labels[..., 1:].contiguous()
281
+ # Flatten the tokens
282
+ loss_fct = CrossEntropyLoss()
283
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
284
+ shift_labels = shift_labels.view(-1)
285
+ # Enable model/pipeline parallelism
286
+ shift_labels = shift_labels.to(shift_logits.device)
287
+ loss = loss_fct(shift_logits, shift_labels)
288
+
289
+ if labels is not None:
290
+ llm = []
291
+ for i in range(labels.shape[0]):
292
+ try: p = labels[i].data.cpu().tolist().index(32003)-1
293
+ except: p = len(labels[i])-9
294
+ p = min(len(hidden_states[i])-9, p)
295
+ llm.append(hidden_states[i][p:p+8].unsqueeze(0))
296
+ llm = torch.cat(llm, dim=0)
297
+ hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
298
+
299
+ B, DROP = labels.shape[0], 0.05
300
+
301
+ hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
302
+ self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
303
+
304
+ with torch.no_grad():
305
+ lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
306
+ lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
307
+ torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
308
+
309
+ noise = torch.randn_like(lat_ans)
310
+ ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
311
+ lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
312
+
313
+ prob = torch.rand(B, device=lat_ans.device)
314
+ mask = (prob<(DROP*2)).reshape(B, 1, 1)
315
+ hid_edit = torch.where(mask, hid_null, hid_edit)
316
+ mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
317
+ lat_inp *= mask
318
+
319
+ out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
320
+
321
+ loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
322
+ if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
323
+ loss = loss_ce+loss_edit*0.5
324
+
325
+ if not return_dict:
326
+ output = (logits,) + outputs[1:]
327
+ return (loss,) + output if loss is not None else output
328
+
329
+ return CausalLMOutputWithPast(
330
+ loss=loss,
331
+ logits=logits,
332
+ past_key_values=outputs.past_key_values,
333
+ hidden_states=outputs.hidden_states,
334
+ attentions=outputs.attentions,
335
+ )
336
+
337
+ def prepare_inputs_for_generation(
338
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
339
+ ):
340
+ if past_key_values:
341
+ input_ids = input_ids[:, -1:]
342
+
343
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
344
+ if inputs_embeds is not None and past_key_values is None:
345
+ model_inputs = {"inputs_embeds": inputs_embeds}
346
+ else:
347
+ model_inputs = {"input_ids": input_ids}
348
+
349
+ model_inputs.update(
350
+ {
351
+ "past_key_values": past_key_values,
352
+ "use_cache": kwargs.get("use_cache"),
353
+ "attention_mask": attention_mask,
354
+ "images": kwargs.get("images", None),
355
+ }
356
+ )
357
+ return model_inputs
358
+
359
+ def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
360
+ tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
361
+ vision_config = self.get_vision_tower().config
362
+ vision_config.use_im_start_end = mm_use_im_start_end
363
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
364
+ self.resize_token_embeddings(len(tokenizer))
365
+
366
+ if mm_use_im_start_end:
367
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
368
+ self.resize_token_embeddings(len(tokenizer))
369
+ vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
370
+
371
+ if num_new_tokens > 0:
372
+ input_embeddings = self.get_input_embeddings().weight.data
373
+ output_embeddings = self.get_output_embeddings().weight.data
374
+
375
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
376
+ dim=0, keepdim=True)
377
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
378
+ dim=0, keepdim=True)
379
+
380
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
381
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
382
+
383
+ if tune_mm_mlp_adapter:
384
+ self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
385
+ for p in self.get_input_embeddings().parameters():
386
+ p.requires_grad = True
387
+ for p in self.get_output_embeddings().parameters():
388
+ p.requires_grad = False
389
+
390
+ if pretrain_mm_mlp_adapter:
391
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
392
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
393
+ assert num_new_tokens == 2
394
+ if input_embeddings.shape == embed_tokens_weight.shape:
395
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
396
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
397
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
398
+ else:
399
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
400
+
401
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
402
+
403
+ AutoConfig.register("llava", LlavaConfig)
404
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
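To make the EditMapper contract above concrete, here is a shape-check sketch with random tensors rather than trained weights. The import path assumes llava.py has been copied over the installed llava package, as app.py does; the batch size of 2 is arbitrary.

```python
# Hedged sketch: EditMapper maps the 8 [IMG0]-[IMG7] hidden states (plus their token
# embeddings) to the 77x768 prompt embedding consumed by the InstructPix2Pix UNet.
import torch
from llava.model.llava import EditMapper  # assumes this file has replaced the installed module

head = EditMapper()
llm_hidden = torch.randn(2, 8, 4096)   # hidden states at the 8 visual-token positions
tok_embeds = torch.randn(1, 8, 4096)   # embeddings of those tokens, broadcast over the batch
prompt = head(llm_hidden, tok_embeds)
print(prompt.shape)                    # torch.Size([2, 77, 768]) -> prompt_embeds for the pipeline
```

In app.py the same call appears as model.edit_head(hid.unsqueeze(0), EMB), where EMB is ckpt['emb'].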
pre-requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ sentencepiece
2
+ transformers
3
+ diffusers
4
+ tokenizers
5
+ datasets
6
+ accelerate
7
+ evaluate
8
+ gradio
9
+ git+https://github.com/haotian-liu/LLaVA@7ace501
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ -i https://download.pytorch.org/whl/cu113
2
+ torch==1.12.0
3
+ torchvision==0.13.0
4
+ torchaudio==0.12.0
train.py ADDED
@@ -0,0 +1,831 @@
1
+
2
+ # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/train/train.py
3
+
4
+ import os
5
+ import copy
6
+ from dataclasses import dataclass, field
7
+ import json
8
+ import logging
9
+ import pathlib
10
+ from typing import Dict, Optional, Sequence, List
11
+
12
+ import torch
13
+
14
+ import transformers
15
+ from torch.utils.data import Dataset
16
+ from llava.train.llava_trainer import LLaVATrainer
17
+
18
+ from llava import conversation as conversation_lib
19
+ from llava.model import *
20
+
21
+ from PIL import Image
22
+ import torch.nn as nn
23
+
24
+ # TODO: import and use code from ../data/dataset.py
25
+
26
+ IGNORE_INDEX = -100
27
+ DEFAULT_PAD_TOKEN = "[PAD]"
28
+ DEFAULT_EOS_TOKEN = "</s>"
29
+ DEFAULT_BOS_TOKEN = "<s>"
30
+ DEFAULT_UNK_TOKEN = "<unk>"
31
+ DEFAULT_IMAGE_TOKEN = "<image>"
32
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
33
+ DEFAULT_IM_START_TOKEN = "<im_start>"
34
+ DEFAULT_IM_END_TOKEN = "<im_end>"
35
+
36
+ import io, base64, pickle, random
37
+ from tqdm import tqdm
38
+ import numpy as np
39
+
40
+ def b2f(b): return Image.open(io.BytesIO(base64.b64decode(b))).convert('RGB')
41
+ def resize(f):
42
+ w, h = f.size
43
+ if w>h:
44
+ p = (w-h)//2
45
+ f = f.crop([p, 0, p+h, h])
46
+ elif h>w:
47
+ p = (h-w)//2
48
+ f = f.crop([0, p, w, p+w])
49
+ f = f.resize([512, 512])
50
+ return f
51
+ def img2npy(f): return (2.0*np.array(f)/255.0-1.0).transpose((2, 0, 1)).astype(np.float32)
52
+
53
+ @dataclass
54
+ class ModelArguments:
55
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
56
+ version: Optional[str] = field(default="v0")
57
+ freeze_backbone: bool = field(default=False)
58
+ tune_mm_mlp_adapter: bool = field(default=False)
59
+ vision_tower: Optional[str] = field(default=None)
60
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
61
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
62
+ mm_use_im_start_end: bool = field(default=False)
63
+
64
+
65
+ @dataclass
66
+ class DataArguments:
67
+ data_path: str = field(default=None,
68
+ metadata={"help": "Path to the training data."})
69
+ lazy_preprocess: bool = False
70
+ is_multimodal: bool = False
71
+ sep_image_conv_front: bool = False
72
+ image_token_len: int = 0
73
+ image_folder: Optional[str] = field(default=None)
74
+ image_aspect_ratio: str = 'square'
75
+
76
+
77
+ @dataclass
78
+ class TrainingArguments(transformers.TrainingArguments):
79
+ cache_dir: Optional[str] = field(default=None)
80
+ optim: str = field(default="adamw_torch")
81
+ remove_unused_columns: bool = field(default=False)
82
+ freeze_mm_mlp_adapter: bool = field(default=False)
83
+ force_fsdp: bool = field(default=False)
84
+ model_max_length: int = field(
85
+ default=512,
86
+ metadata={
87
+ "help":
88
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
89
+ },
90
+ )
91
+ double_quant: bool = field(
92
+ default=True,
93
+ metadata={"help": "Compress the quantization statistics through double quantization."}
94
+ )
95
+ quant_type: str = field(
96
+ default="nf4",
97
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
98
+ )
99
+ bits: int = field(
100
+ default=16,
101
+ metadata={"help": "How many bits to use."}
102
+ )
103
+ lora_enable: bool = False
104
+ lora_r: int = 64
105
+ lora_alpha: int = 16
106
+ lora_dropout: float = 0.05
107
+ lora_weight_path: str = ""
108
+ lora_bias: str = "none"
109
+
110
+
111
+ def maybe_zero_3(param, ignore_status=False, name=None):
112
+ from deepspeed import zero
113
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
114
+ if hasattr(param, "ds_id"):
115
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
116
+ if not ignore_status:
117
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
118
+ with zero.GatheredParameters([param]):
119
+ param = param.data.detach().cpu().clone()
120
+ else:
121
+ param = param.detach().cpu().clone()
122
+ return param
123
+
124
+
125
+ # Borrowed from peft.utils.get_peft_model_state_dict
126
+ def get_peft_state_maybe_zero_3(named_params, bias):
127
+ if bias == "none":
128
+ to_return = {k: t for k, t in named_params if "lora_" in k}
129
+ elif bias == "all":
130
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
131
+ elif bias == "lora_only":
132
+ to_return = {}
133
+ maybe_lora_bias = {}
134
+ lora_bias_names = set()
135
+ for k, t in named_params:
136
+ if "lora_" in k:
137
+ to_return[k] = t
138
+ bias_name = k.split("lora_")[0] + "bias"
139
+ lora_bias_names.add(bias_name)
140
+ elif "bias" in k:
141
+ maybe_lora_bias[k] = t
142
+ for k, t in maybe_lora_bias.items():
143
+ if k in lora_bias_names:
144
+ to_return[k] = t
145
+ else:
146
+ raise NotImplementedError
147
+ to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
148
+ return to_return
149
+
150
+
151
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
152
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
153
+ if require_grad_only:
154
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
155
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
156
+ return to_return
157
+
158
+
159
+ def find_all_linear_names(model):
160
+ cls = torch.nn.Linear
161
+ lora_module_names = set()
162
+ for name, module in model.named_modules():
163
+ if isinstance(module, cls):
164
+ names = name.split('.')
165
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
166
+
167
+
168
+ if 'lm_head' in lora_module_names: # needed for 16-bit
169
+ lora_module_names.remove('lm_head')
170
+ return list(lora_module_names)
171
+
172
+
173
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
174
+ output_dir: str):
175
+ """Collects the state dict and dump to disk."""
176
+ if trainer.deepspeed:
177
+ torch.cuda.synchronize()
178
+ trainer.save_model(output_dir)
179
+ return
180
+
181
+ state_dict = trainer.model.state_dict()
182
+ if trainer.args.should_save:
183
+ cpu_state_dict = {
184
+ key: value.cpu()
185
+ for key, value in state_dict.items()
186
+ }
187
+ del state_dict
188
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
189
+
190
+
191
+ def smart_tokenizer_and_embedding_resize(
192
+ special_tokens_dict: Dict,
193
+ tokenizer: transformers.PreTrainedTokenizer,
194
+ model: transformers.PreTrainedModel,
195
+ ):
196
+ """Resize tokenizer and embedding.
197
+
198
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
199
+ """
200
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
201
+ model.resize_token_embeddings(len(tokenizer))
202
+
203
+ if num_new_tokens > 0:
204
+ input_embeddings = model.get_input_embeddings().weight.data
205
+ output_embeddings = model.get_output_embeddings().weight.data
206
+
207
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
208
+ dim=0, keepdim=True)
209
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
210
+ dim=0, keepdim=True)
211
+
212
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
213
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
214
+
215
+
216
+ def _tokenize_fn(strings: Sequence[str],
217
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
218
+ """Tokenize a list of strings."""
219
+ tokenized_list = [
220
+ tokenizer(
221
+ text,
222
+ return_tensors="pt",
223
+ padding="longest",
224
+ max_length=tokenizer.model_max_length,
225
+ truncation=True,
226
+ ) for text in strings
227
+ ]
228
+ input_ids = labels = [
229
+ tokenized.input_ids[0] for tokenized in tokenized_list
230
+ ]
231
+ input_ids_lens = labels_lens = [
232
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
233
+ for tokenized in tokenized_list
234
+ ]
235
+ return dict(
236
+ input_ids=input_ids,
237
+ labels=labels,
238
+ input_ids_lens=input_ids_lens,
239
+ labels_lens=labels_lens,
240
+ )
241
+
242
+
243
+ def _mask_targets(target, tokenized_lens, speakers):
244
+ # cur_idx = 0
245
+ cur_idx = tokenized_lens[0]
246
+ tokenized_lens = tokenized_lens[1:]
247
+ target[:cur_idx] = IGNORE_INDEX
248
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
249
+ if speaker == "human":
250
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
251
+ cur_idx += tokenized_len
252
+
253
+
254
+ def _add_speaker_and_signal(header, source, get_conversation=True):
255
+ """Add speaker and start/end signal on each round."""
256
+ BEGIN_SIGNAL = "### "
257
+ END_SIGNAL = "\n"
258
+ conversation = header
259
+ for sentence in source:
260
+ from_str = sentence["from"]
261
+ if from_str.lower() == "human":
262
+ from_str = conversation_lib.default_conversation.roles[0]
263
+ elif from_str.lower() == "gpt":
264
+ from_str = conversation_lib.default_conversation.roles[1]
265
+ else:
266
+ from_str = 'unknown'
267
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
268
+ sentence["value"] + END_SIGNAL)
269
+ if get_conversation:
270
+ conversation += sentence["value"]
271
+ conversation += BEGIN_SIGNAL
272
+ return conversation
273
+
274
+
275
+ def preprocess_multimodal(
276
+ sources: Sequence[str],
277
+ multimodal_cfg: dict,
278
+ cur_token_len: int,
279
+ ) -> Dict:
280
+ is_multimodal = multimodal_cfg['is_multimodal']
281
+ # image_token_len = multimodal_cfg['image_token_len']
282
+ image_token_len = cur_token_len
283
+ if not is_multimodal:
284
+ return sources
285
+
286
+ for source in sources:
287
+ if multimodal_cfg['sep_image_conv_front']:
288
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
289
+ source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
290
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
291
+ for sentence in source:
292
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
293
+ if multimodal_cfg['use_im_start_end']:
294
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
295
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
296
+
297
+ return sources
298
+
299
+
300
+ def preprocess_v1(
301
+ sources,
302
+ tokenizer: transformers.PreTrainedTokenizer,
303
+ ) -> Dict:
304
+ conv = conversation_lib.default_conversation.copy()
305
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
306
+
307
+ # Apply prompt templates
308
+ conversations = []
309
+ for i, source in enumerate(sources):
310
+ if roles[source[0]["from"]] != conv.roles[0]:
311
+ # Skip the first one if it is not from human
312
+ source = source[1:]
313
+
314
+ conv.messages = []
315
+ for j, sentence in enumerate(source):
316
+ role = roles[sentence["from"]]
317
+ assert role == conv.roles[j % 2], f"{i}"
318
+ conv.append_message(role, sentence["value"])
319
+ conversations.append(conv.get_prompt())
320
+
321
+ # Tokenize conversations
322
+ input_ids = tokenizer(
323
+ conversations,
324
+ return_tensors="pt",
325
+ padding="longest",
326
+ max_length=tokenizer.model_max_length,
327
+ truncation=True,
328
+ ).input_ids
329
+ targets = input_ids.clone()
330
+
331
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
332
+
333
+ # Mask targets
334
+ sep = conv.sep + conv.roles[1] + ": "
335
+ for conversation, target in zip(conversations, targets):
336
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
337
+
338
+ rounds = conversation.split(conv.sep2)
339
+ cur_len = 1
340
+ target[:cur_len] = IGNORE_INDEX
341
+ for i, rou in enumerate(rounds):
342
+ if rou == "":
343
+ break
344
+
345
+ parts = rou.split(sep)
346
+ if len(parts) != 2:
347
+ break
348
+ parts[0] += sep
349
+ round_len = len(tokenizer(rou).input_ids)
350
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
351
+
352
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
353
+
354
+ cur_len += round_len
355
+ target[cur_len:] = IGNORE_INDEX
356
+
357
+ if cur_len < tokenizer.model_max_length:
358
+ if cur_len != total_len:
359
+ target[:] = IGNORE_INDEX
360
+ print(
361
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
362
+ f" (ignored)"
363
+ )
364
+
365
+ return dict(
366
+ input_ids=input_ids,
367
+ labels=targets,
368
+ )
369
+
370
+ def preprocess_mpt(
371
+ sources,
372
+ tokenizer: transformers.PreTrainedTokenizer,
373
+ ) -> Dict:
374
+ conv = conversation_lib.default_conversation.copy()
375
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
376
+
377
+ # Apply prompt templates
378
+ conversations = []
379
+ for i, source in enumerate(sources):
380
+ if roles[source[0]["from"]] != conv.roles[0]:
381
+ # Skip the first one if it is not from human
382
+ source = source[1:]
383
+
384
+ conv.messages = []
385
+ for j, sentence in enumerate(source):
386
+ role = roles[sentence["from"]]
387
+ assert role == conv.roles[j % 2], f"{i}"
388
+ conv.append_message(role, sentence["value"])
389
+ conversations.append(conv.get_prompt())
390
+
391
+ # Tokenize conversations
392
+ input_ids = tokenizer(
393
+ conversations,
394
+ return_tensors="pt",
395
+ padding="longest",
396
+ max_length=tokenizer.model_max_length,
397
+ truncation=True,
398
+ ).input_ids
399
+ targets = input_ids.clone()
400
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
401
+
402
+ # Mask targets
403
+ sep = conv.sep + conv.roles[1]
404
+ for conversation, target in zip(conversations, targets):
405
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
406
+
407
+ rounds = conversation.split(conv.sep)
408
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
409
+ for conv_idx in range(3, len(rounds), 2):
410
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
411
+ cur_len = 0
412
+ target[:cur_len] = IGNORE_INDEX
413
+ for i, rou in enumerate(re_rounds):
414
+ if rou == "":
415
+ break
416
+
417
+ parts = rou.split(sep)
418
+ if len(parts) != 2:
419
+ break
420
+ parts[0] += sep
421
+ round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
422
+ instruction_len = len(tokenizer(parts[0]).input_ids)
423
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
424
+
425
+ cur_len += round_len
426
+ target[cur_len:] = IGNORE_INDEX
427
+
428
+ if cur_len < tokenizer.model_max_length:
429
+ if cur_len != total_len:
430
+ target[:] = IGNORE_INDEX
431
+ print(
432
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
433
+ f" (ignored)"
434
+ )
435
+
436
+ return dict(
437
+ input_ids=input_ids,
438
+ labels=targets,
439
+ )
440
+
441
+
442
+ def preprocess(
443
+ sources: Sequence[str],
444
+ tokenizer: transformers.PreTrainedTokenizer,
445
+ ) -> Dict:
446
+ """
447
+ Given a list of sources, each is a conversation list. This transform:
448
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
449
+ 2. Concatenate conversations together;
450
+ 3. Tokenize the concatenated conversation;
451
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
452
+ """
453
+ if conversation_lib.default_conversation.version == "v1":
454
+ return preprocess_v1(sources, tokenizer)
455
+ if conversation_lib.default_conversation.version == "mpt":
456
+ return preprocess_mpt(sources, tokenizer)
457
+ # add end signal and concatenate together
458
+ conversations = []
459
+ for source in sources:
460
+ header = f"{conversation_lib.default_conversation.system}\n\n"
461
+ conversation = _add_speaker_and_signal(header, source)
462
+ conversations.append(conversation)
463
+ # tokenize conversations
464
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
465
+ input_ids = conversations_tokenized["input_ids"]
466
+ targets = copy.deepcopy(input_ids)
467
+ for target, source in zip(targets, sources):
468
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
469
+ tokenizer)["input_ids_lens"]
470
+ speakers = [sentence["from"] for sentence in source]
471
+ _mask_targets(target, tokenized_lens, speakers)
472
+
473
+ return dict(input_ids=input_ids, labels=targets)
474
+
475
+
476
+ class SupervisedDataset(Dataset):
477
+ """Dataset for supervised fine-tuning."""
478
+
479
+ def __init__(self, data_path: str,
480
+ tokenizer: transformers.PreTrainedTokenizer):
481
+ super(SupervisedDataset, self).__init__()
482
+ logging.warning("Loading data...")
483
+ list_data_dict = json.load(open(data_path, "r"))
484
+
485
+ logging.warning("Formatting inputs...")
486
+ sources = [example["conversations"] for example in list_data_dict]
487
+ data_dict = preprocess(sources, tokenizer)
488
+
489
+ self.input_ids = data_dict["input_ids"]
490
+ self.labels = data_dict["labels"]
491
+
492
+ def __len__(self):
493
+ return len(self.input_ids)
494
+
495
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
496
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
497
+
498
+
499
+ class LazySupervisedDataset(Dataset):
500
+
501
+ def __init__(self, data_path: str,
502
+ tokenizer: transformers.PreTrainedTokenizer,
503
+ multimodal_cfg: dict):
504
+ super(LazySupervisedDataset, self).__init__()
505
+
506
+ self.tokenizer, self.multimodal_cfg = tokenizer, multimodal_cfg
507
+
508
+ self.pkl, self.prompt = pickle.load(open('./_data/ipr2pr.pkl', 'rb'))['task'], json.load(open('./_data/ipr2pr_expressive.json', 'r'))
509
+ random.shuffle(self.pkl)
510
+ print('--pkl: %d--'%(len(self.pkl)))
511
+
512
+ def __len__(self):
513
+ return len(self.pkl)
514
+
515
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
516
+ item = self.pkl[i][0]
517
+
518
+ tsv = open('./_data/ipr2pr.tsv', 'r')
519
+ tsv.seek(item['lineidx'])
520
+ b = tsv.readline().strip().split('\t')
521
+ image = resize(b2f(b[0]))
522
+
523
+ processor = self.multimodal_cfg['image_processor']
524
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
525
+
526
+ cur_token_len = (image.shape[1]//14)*(image.shape[2]//14)
527
+ query = "what will this image be like if '%s'\n%s"%(item['instruction'], DEFAULT_IMAGE_TOKEN)
528
+ ans = '%s [IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]'%(self.prompt[item['input']]['expressive'])
529
+ sources = preprocess_multimodal(copy.deepcopy([[{'from': 'human', 'value': query}, {'from': 'gpt', 'value': ans}]]),
530
+ self.multimodal_cfg, cur_token_len)
531
+
532
+ data_dict = preprocess(sources, self.tokenizer)
533
+ if isinstance(i, int): data_dict = dict(input_ids=data_dict['input_ids'][0],
534
+ labels=data_dict['labels'][0])
535
+ data_dict['image'] = image
536
+
537
+ p2p_inp, p2p_ans = img2npy(resize(b2f(b[0])).resize([256, 256])), img2npy(resize(b2f(b[1])).resize([256, 256]))
538
+ data_dict['p2p_inp'], data_dict['p2p_ans'] = p2p_inp, p2p_ans
539
+
540
+ return data_dict
541
+
542
+
543
+ @dataclass
544
+ class DataCollatorForSupervisedDataset(object):
545
+ """Collate examples for supervised fine-tuning."""
546
+
547
+ tokenizer: transformers.PreTrainedTokenizer
548
+
549
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
550
+ input_ids, labels = tuple([instance[key] for instance in instances]
551
+ for key in ("input_ids", "labels"))
552
+ input_ids = torch.nn.utils.rnn.pad_sequence(
553
+ input_ids,
554
+ batch_first=True,
555
+ padding_value=self.tokenizer.pad_token_id)
556
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
557
+ batch_first=True,
558
+ padding_value=IGNORE_INDEX)
559
+ batch = dict(
560
+ input_ids=input_ids,
561
+ labels=labels,
562
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
563
+ )
564
+
565
+ if 'image' in instances[0]:
566
+ images = [instance['image'] for instance in instances]
567
+ if all(x is not None and x.shape == images[0].shape for x in images):
568
+ batch['images'] = torch.stack(images)
569
+ else:
570
+ batch['images'] = images
571
+
572
+ batch['p2p_inp'], batch['p2p_ans'] = [torch.cat([torch.from_numpy(d['p2p_inp']).unsqueeze(dim=0) for d in instances], dim=0),
573
+ torch.cat([torch.from_numpy(d['p2p_ans']).unsqueeze(dim=0) for d in instances], dim=0)]
574
+
575
+ return batch
576
+
577
+
578
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
579
+ data_args) -> Dict:
580
+ """Make dataset and collator for supervised fine-tuning."""
581
+ dataset_cls = (LazySupervisedDataset
582
+ if data_args.lazy_preprocess else SupervisedDataset)
583
+ train_dataset = dataset_cls(tokenizer=tokenizer,
584
+ data_path=data_args.data_path,
585
+ multimodal_cfg=dict(
586
+ is_multimodal=data_args.is_multimodal,
587
+ sep_image_conv_front=data_args.sep_image_conv_front,
588
+ image_token_len=data_args.image_token_len,
589
+ image_folder=data_args.image_folder,
590
+ image_aspect_ratio=data_args.image_aspect_ratio,
591
+ use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),
592
+ image_processor=getattr(data_args, 'image_processor', None)))
593
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
594
+ return dict(train_dataset=train_dataset,
595
+ eval_dataset=None,
596
+ data_collator=data_collator)
597
+
598
+
599
+ def train():
600
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
601
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
602
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
603
+
604
+ bnb_model_from_pretrained_args = {}
605
+ if training_args.bits in [4, 8]:
606
+ from transformers import BitsAndBytesConfig
607
+ from peft import prepare_model_for_int8_training
608
+ bnb_model_from_pretrained_args.update(dict(
609
+ device_map={"": training_args.device},
610
+ load_in_4bit=training_args.bits == 4,
611
+ load_in_8bit=training_args.bits == 8,
612
+ quantization_config=BitsAndBytesConfig(
613
+ load_in_4bit=training_args.bits == 4,
614
+ load_in_8bit=training_args.bits == 8,
615
+ llm_int8_threshold=6.0,
616
+ llm_int8_has_fp16_weight=False,
617
+ bnb_4bit_compute_dtype=compute_dtype,
618
+ bnb_4bit_use_double_quant=training_args.double_quant,
619
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
620
+ )
621
+ ))
622
+
623
+ if model_args.vision_tower is not None:
624
+ if 'mpt' in model_args.model_name_or_path:
625
+ model = LlavaMPTForCausalLM.from_pretrained(
626
+ model_args.model_name_or_path,
627
+ cache_dir=training_args.cache_dir,
628
+ **bnb_model_from_pretrained_args
629
+ )
630
+ else:
631
+ model = LlavaLlamaForCausalLM.from_pretrained(
632
+ model_args.model_name_or_path,
633
+ cache_dir=training_args.cache_dir,
634
+ **bnb_model_from_pretrained_args
635
+ )
636
+ else:
637
+ model = transformers.LlamaForCausalLM.from_pretrained(
638
+ model_args.model_name_or_path,
639
+ cache_dir=training_args.cache_dir,
640
+ **bnb_model_from_pretrained_args
641
+ )
642
+ model.config.use_cache = False
643
+
644
+ if model_args.freeze_backbone:
645
+ model.model.requires_grad_(False)
646
+
647
+ if training_args.bits in [4, 8]:
648
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
649
+ model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
650
+
651
+ if training_args.gradient_checkpointing and model_args.vision_tower is None:
652
+ if hasattr(model, "enable_input_require_grads"):
653
+ model.enable_input_require_grads()
654
+ else:
655
+ def make_inputs_require_grad(module, input, output):
656
+ output.requires_grad_(True)
657
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
658
+
659
+ if training_args.lora_enable:
660
+ from peft import LoraConfig, get_peft_model
661
+ lora_config = LoraConfig(
662
+ r=training_args.lora_r,
663
+ lora_alpha=training_args.lora_alpha,
664
+ target_modules=find_all_linear_names(model),
665
+ lora_dropout=training_args.lora_dropout,
666
+ bias=training_args.lora_bias,
667
+ task_type="CAUSAL_LM",
668
+ )
669
+ if training_args.bits == 16:
670
+ if training_args.bf16:
671
+ model.to(torch.bfloat16)
672
+ if training_args.fp16:
673
+ model.to(torch.float16)
674
+ logging.warning("Adding LoRA adapters...")
675
+ model = get_peft_model(model, lora_config)
676
+
677
+ if 'mpt' in model_args.model_name_or_path:
678
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
679
+ model_args.model_name_or_path,
680
+ cache_dir=training_args.cache_dir,
681
+ model_max_length=training_args.model_max_length,
682
+ padding_side="right"
683
+ )
684
+ else:
685
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
686
+ model_args.model_name_or_path,
687
+ cache_dir=training_args.cache_dir,
688
+ model_max_length=training_args.model_max_length,
689
+ padding_side="right",
690
+ use_fast=False,
691
+ )
692
+
693
+ if model_args.version == "v0":
694
+ if tokenizer.pad_token is None:
695
+ smart_tokenizer_and_embedding_resize(
696
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
697
+ tokenizer=tokenizer,
698
+ model=model,
699
+ )
700
+ if "llama" in model_args.model_name_or_path:
701
+ tokenizer.add_special_tokens({
702
+ "eos_token": DEFAULT_EOS_TOKEN,
703
+ "bos_token": DEFAULT_BOS_TOKEN,
704
+ "unk_token": DEFAULT_UNK_TOKEN,
705
+ })
706
+ else:
707
+ tokenizer.pad_token = tokenizer.unk_token
708
+ if "mpt" in model_args.model_name_or_path:
709
+ conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
710
+ else:
711
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
712
+
713
+ if model_args.vision_tower is not None:
714
+ model_vision_dict = model.get_model().initialize_vision_modules(
715
+ vision_tower=model_args.vision_tower,
716
+ mm_vision_select_layer=model_args.mm_vision_select_layer,
717
+ pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
718
+ fsdp=training_args.fsdp
719
+ )
720
+ model.get_vision_tower().to(dtype=torch.float16, device=training_args.device)
721
+ vision_config = model_vision_dict['vision_config']
722
+
723
+ data_args.image_token_len = model_vision_dict['image_token_len']
724
+ data_args.image_processor = model_vision_dict['image_processor']
725
+ data_args.is_multimodal = True
726
+
727
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
728
+ if model_args.tune_mm_mlp_adapter:
729
+ model.requires_grad_(False)
730
+ for p in model.get_model().mm_projector.parameters():
731
+ p.requires_grad = True
732
+
733
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
734
+ if training_args.freeze_mm_mlp_adapter:
735
+ for p in model.get_model().mm_projector.parameters():
736
+ p.requires_grad = False
737
+
738
+ if training_args.bits in [4, 8]:
739
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
740
+
741
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
742
+ vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end
743
+ model.config.sep_image_conv_front = data_args.sep_image_conv_front
744
+ model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device,
745
+ tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter)
746
+
747
+ params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
748
+ if len(params_no_grad) > 0:
749
+ if training_args.fsdp is not None and len(training_args.fsdp) > 0:
750
+ if len(params_no_grad) < 10:
751
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
752
+ else:
753
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
754
+ print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
755
+ print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
756
+
757
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
758
+ def patch_FSDP_use_orig_params(func):
759
+ def wrap_func(*args, **kwargs):
760
+ use_orig_params = kwargs.pop('use_orig_params', True)
761
+ return func(*args, **kwargs, use_orig_params=use_orig_params)
762
+ return wrap_func
763
+
764
+ FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
765
+
766
+ if training_args.bits in [4, 8]:
767
+ from peft.tuners.lora import LoraLayer
768
+ for name, module in model.named_modules():
769
+ if isinstance(module, LoraLayer):
770
+ if training_args.bf16:
771
+ module = module.to(torch.bfloat16)
772
+ if 'norm' in name:
773
+ module = module.to(torch.float32)
774
+ if 'lm_head' in name or 'embed_tokens' in name:
775
+ if hasattr(module, 'weight'):
776
+ if training_args.bf16 and module.weight.dtype == torch.float32:
777
+ module = module.to(torch.bfloat16)
778
+
779
+ # start for MGIE
780
+ os.makedirs('_log', exist_ok=True)
781
+
782
+ pt = {}
783
+ for i in tqdm(range(2)): pt.update(torch.load('./_ckpt/LLaVA-7B-v1/pytorch_model-0000%d-of-00002.bin'%(i+1), map_location='cpu'))
784
+ miss, unexp = model.load_state_dict(pt, strict=False)
785
+ print('miss:', miss), print('unexp:', unexp)
786
+
787
+ tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
788
+ model.resize_token_embeddings(len(tokenizer))
789
+ print(tokenizer), json.dump(tokenizer.get_vocab(), open('_log/vocabs.json', 'w'), indent=2)
790
+
791
+ for n, p in model.named_parameters():
792
+ if 'embed_tokens' in n or 'lm_head' in n or 'edit_head' in n or 'unet' in n: p.requires_grad = True
793
+ else: p.requires_grad = False
794
+ with open('_log/parameters.txt', 'w') as F:
795
+ for n, p in model.named_parameters(): F.write('%s %s %s\n'%(n, str(p.shape), str(p.requires_grad)))
796
+
797
+ with open('_log/args_train.txt', 'w') as F:
798
+ for key in vars(training_args): F.write('%s: %s\n'%(str(key), str(vars(training_args)[key])))
799
+ # end for MGIE
800
+
801
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
802
+ data_args=data_args)
803
+ trainer = LLaVATrainer(model=model,
804
+ tokenizer=tokenizer,
805
+ args=training_args,
806
+ **data_module)
807
+
808
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
809
+ trainer.train(resume_from_checkpoint=True)
810
+ else:
811
+ trainer.train()
812
+ trainer.save_state()
813
+
814
+ if training_args.lora_enable:
815
+ state_dict = get_peft_state_maybe_zero_3(
816
+ model.named_parameters(), training_args.lora_bias
817
+ )
818
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
819
+ model.named_parameters()
820
+ )
821
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
822
+ model.config.save_pretrained(training_args.output_dir)
823
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
824
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
825
+ else:
826
+ safe_save_model_for_hf_trainer(trainer=trainer,
827
+ output_dir=training_args.output_dir)
828
+
829
+
830
+ if __name__ == "__main__":
831
+ train()
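Finally, the conditioning-dropout step inside LlavaLlamaForCausalLM.forward (llava.py above, DROP = 0.05) is easy to misread, so here is a toy reproduction of just that masking logic with random numbers in place of real latents; nothing below is part of the original training code.

```python
# Hedged toy sketch of the classifier-free-guidance dropout used during MGIE training:
# the edit embedding is nulled when prob < 0.10, the input latent is zeroed when
# 0.05 <= prob < 0.15, and both are dropped together on the 5% overlap (fully unconditional).
import torch

B, DROP = 8, 0.05
prob = torch.rand(B)
drop_edit  = prob < DROP * 2                     # where hid_edit is replaced by hid_null
drop_image = (prob >= DROP) & (prob < DROP * 3)  # where lat_inp is multiplied by 0
print(drop_edit.int().tolist(), drop_image.int().tolist())
```

Training both unconditional branches this way is what lets the Text CFG and Image CFG scales in app.py (guidance_scale and image_guidance_scale) be applied at inference, mirroring InstructPix2Pix.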