Singularity666 committed on
Commit 7f6295c · verified · 1 Parent(s): 55553ad

Update app.py

Files changed (1)
  1. app.py +829 -47
app.py CHANGED
@@ -1,44 +1,826 @@
1
- import os
2
 
3
- import huggingface_hub, spaces
4
- huggingface_hub.snapshot_download(repo_id='tsujuifu/ml-mgie', repo_type='model', local_dir='_ckpt', local_dir_use_symlinks=False)
5
- os.system('ls _ckpt')
6
 
7
- from PIL import Image
8
 
9
  import numpy as np
10
  import torch as T
11
- import transformers, diffusers
12
 
13
- from conversation import conv_templates
14
- from mgie_llava import *
15
 
16
- import gradio as gr
17
 
18
  def crop_resize(f, sz=512):
19
  w, h = f.size
20
- if w>h:
21
- p = (w-h)//2
22
- f = f.crop([p, 0, p+h, h])
23
- elif h>w:
24
- p = (h-w)//2
25
- f = f.crop([0, p, w, p+w])
26
  f = f.resize([sz, sz])
27
  return f
28
- def remove_alter(s): # hack expressive instruction
29
- if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:')+10:].strip()
30
  if '</s>' in s: s = s[:s.index('</s>')].strip()
31
  if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]
32
  if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
33
  s = '.'.join([s.strip() for s in s.split('.')[:2]])
34
- if s[-1]!='.': s += '.'
35
  return s.strip()
36
 
37
  DEFAULT_IMAGE_TOKEN = '<image>'
38
  DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
39
  DEFAULT_IM_START_TOKEN = '<im_start>'
40
  DEFAULT_IM_END_TOKEN = '<im_end>'
41
- PATH_LLAVA = '_ckpt/LLaVA-7B-v1'
42
 
43
  tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
44
  model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
@@ -47,7 +829,7 @@ image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.m
47
  tokenizer.padding_side = 'left'
48
  tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
49
  model.resize_token_embeddings(len(tokenizer))
50
- ckpt = T.load('_ckpt/mgie_7b/mllm.pt', map_location='cpu')
51
  model.load_state_dict(ckpt, strict=False)
52
 
53
  mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
@@ -61,26 +843,25 @@ vision_config = vision_tower.config
61
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
62
  vision_config.use_im_start_end = mm_use_im_start_end
63
  if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
64
- image_token_len = (vision_config.image_size//vision_config.patch_size)**2
65
 
66
  _ = model.eval()
67
 
68
  pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
69
  pipe.set_progress_bar_config(disable=True)
70
- pipe.unet.load_state_dict(T.load('_ckpt/mgie_7b/unet.pt', map_location='cpu'))
71
  print('--init MGIE--')
72
 
73
- @spaces.GPU(enable_queue=True)
74
  def go_mgie(img, txt, seed, cfg_txt, cfg_img):
75
  EMB = ckpt['emb'].cuda()
76
  with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
77
-
78
  img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
79
  inp = img
80
 
81
  img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
82
- txt = "what will this image be like if '%s'"%(txt)
83
- txt = txt+'\n'+DEFAULT_IM_START_TOKEN+DEFAULT_IMAGE_PATCH_TOKEN*image_token_len+DEFAULT_IM_END_TOKEN
84
  conv = conv_templates['vicuna_v1_1'].copy()
85
  conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
86
  txt = conv.get_prompt()
@@ -89,41 +870,42 @@ def go_mgie(img, txt, seed, cfg_txt, cfg_img):
89
 
90
  with T.inference_mode():
91
  _ = model.cuda()
92
- out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
93
- do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
94
  return_dict_in_generate=True, output_hidden_states=True)
95
  out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]
96
-
97
- if 32003 in out: p = out.index(32003)-1
98
- else: p = len(hid)-9
99
- p = min(p, len(hid)-9)
100
- hid = hid[p:p+8]
101
 
102
  out = remove_alter(tokenizer.decode(out))
103
  _ = model.cuda()
104
  emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
105
- res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
106
  generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
107
 
108
  return res, out
109
 
110
- go_mgie(np.array(Image.open('./_input/0.jpg').convert('RGB')), 'make the frame red', 13331, 7.5, 1.5)
111
- print('--init GO--')
112
-
113
  with gr.Blocks() as app:
114
  gr.Markdown(
115
  """
116
- # MagiX: Edit Personalized Images using Gen AI
117
  """
118
  )
119
- with gr.Row(): inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
120
- gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
121
- with gr.Row(): txt, out = [gr.Textbox(label='Instruction', interactive=True),
122
- gr.Textbox(label='Expressive Instruction', interactive=False)]
123
- with gr.Row(): seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
124
- gr.Number(value=7.5, label='Text CFG', interactive=True),
125
- gr.Number(value=1.5, label='Image CFG', interactive=True)]
126
- with gr.Row(): btn_sub = gr.Button('Submit')
127
  btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
128
-
129
  app.launch()
 
1
+ !pip install sentencepiece
2
+ !pip install git+https://github.com/huggingface/transformers.git@cae78c46
3
+ !pip install diffusers
4
+ !pip install tokenizers==0.12.1
5
+ !pip install datasets
6
+ !pip install accelerate
7
+ !pip install evaluate
8
+ !pip install gradio==4.12.0
9
+ !pip install gradio_client==0.8.0
10
+ !pip install -i https://download.pytorch.org/whl/cu118 torch==2.0 torchvision==0.15 torchaudio==2.0
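+ # note: the '!pip install ...' lines above are notebook cell magics (Colab/Jupyter); outside a
+ # notebook, run the same commands as plain `pip install ...` in a shell before launching the app.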
11
+ #conversation.py:
12
+ import dataclasses
13
+ from enum import auto, Enum
14
+ from typing import List, Tuple
15
 
16
 
17
+ class SeparatorStyle(Enum):
18
+ """Different separator style."""
19
+ SINGLE = auto()
20
+ TWO = auto()
21
+ MPT = auto()
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class Conversation:
26
+ """A class that keeps all conversation history."""
27
+ system: str
28
+ roles: List[str]
29
+ messages: List[List[str]]
30
+ offset: int
31
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32
+ sep: str = "###"
33
+ sep2: str = None
34
+ version: str = "Unknown"
35
+
36
+ skip_next: bool = False
37
+
38
+ def get_prompt(self):
39
+ if self.sep_style == SeparatorStyle.SINGLE:
40
+ ret = self.system + self.sep
41
+ for role, message in self.messages:
42
+ if message:
43
+ if type(message) is tuple:
44
+ message, _, _ = message
45
+ ret += role + ": " + message + self.sep
46
+ else:
47
+ ret += role + ":"
48
+ return ret
49
+ elif self.sep_style == SeparatorStyle.TWO:
50
+ seps = [self.sep, self.sep2]
51
+ ret = self.system + seps[0]
52
+ for i, (role, message) in enumerate(self.messages):
53
+ if message:
54
+ if type(message) is tuple:
55
+ message, _, _ = message
56
+ ret += role + ": " + message + seps[i % 2]
57
+ else:
58
+ ret += role + ":"
59
+ return ret
60
+ if self.sep_style == SeparatorStyle.MPT:
61
+ ret = self.system + self.sep
62
+ for role, message in self.messages:
63
+ if message:
64
+ if type(message) is tuple:
65
+ message, _, _ = message
66
+ ret += role + message + self.sep
67
+ else:
68
+ ret += role
69
+ return ret
70
+ else:
71
+ raise ValueError(f"Invalid style: {self.sep_style}")
72
+
73
+ def append_message(self, role, message):
74
+ self.messages.append([role, message])
75
+
76
+ def get_images(self, return_pil=False):
77
+ images = []
78
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
79
+ if i % 2 == 0:
80
+ if type(msg) is tuple:
81
+ import base64
82
+ from io import BytesIO
83
+ from PIL import Image
84
+ msg, image, image_process_mode = msg
85
+ if image_process_mode == "Pad":
86
+ def expand2square(pil_img, background_color=(122, 116, 104)):
87
+ width, height = pil_img.size
88
+ if width == height:
89
+ return pil_img
90
+ elif width > height:
91
+ result = Image.new(pil_img.mode, (width, width), background_color)
92
+ result.paste(pil_img, (0, (width - height) // 2))
93
+ return result
94
+ else:
95
+ result = Image.new(pil_img.mode, (height, height), background_color)
96
+ result.paste(pil_img, ((height - width) // 2, 0))
97
+ return result
98
+ image = expand2square(image)
99
+ elif image_process_mode == "Crop":
100
+ pass
101
+ elif image_process_mode == "Resize":
102
+ image = image.resize((224, 224))
103
+ else:
104
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
105
+ max_hw, min_hw = max(image.size), min(image.size)
106
+ aspect_ratio = max_hw / min_hw
107
+ max_len, min_len = 800, 400
108
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
109
+ longest_edge = int(shortest_edge * aspect_ratio)
110
+ W, H = image.size
111
+ if H > W:
112
+ H, W = longest_edge, shortest_edge
113
+ else:
114
+ H, W = shortest_edge, longest_edge
115
+ image = image.resize((W, H))
116
+ if return_pil:
117
+ images.append(image)
118
+ else:
119
+ buffered = BytesIO()
120
+ image.save(buffered, format="JPEG")
121
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
122
+ images.append(img_b64_str)
123
+ return images
124
+
125
+ def to_gradio_chatbot(self):
126
+ ret = []
127
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
128
+ if i % 2 == 0:
129
+ if type(msg) is tuple:
130
+ import base64
131
+ from io import BytesIO
132
+ msg, image, image_process_mode = msg
133
+ max_hw, min_hw = max(image.size), min(image.size)
134
+ aspect_ratio = max_hw / min_hw
135
+ max_len, min_len = 800, 400
136
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
137
+ longest_edge = int(shortest_edge * aspect_ratio)
138
+ W, H = image.size
139
+ if H > W:
140
+ H, W = longest_edge, shortest_edge
141
+ else:
142
+ H, W = shortest_edge, longest_edge
143
+ image = image.resize((W, H))
144
+ # image = image.resize((224, 224))
145
+ buffered = BytesIO()
146
+ image.save(buffered, format="JPEG")
147
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
148
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
149
+ msg = msg.replace('<image>', img_str)
150
+ ret.append([msg, None])
151
+ else:
152
+ ret[-1][-1] = msg
153
+ return ret
154
+
155
+ def copy(self):
156
+ return Conversation(
157
+ system=self.system,
158
+ roles=self.roles,
159
+ messages=[[x, y] for x, y in self.messages],
160
+ offset=self.offset,
161
+ sep_style=self.sep_style,
162
+ sep=self.sep,
163
+ sep2=self.sep2)
164
+
165
+ def dict(self):
166
+ if len(self.get_images()) > 0:
167
+ return {
168
+ "system": self.system,
169
+ "roles": self.roles,
170
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
171
+ "offset": self.offset,
172
+ "sep": self.sep,
173
+ "sep2": self.sep2,
174
+ }
175
+ return {
176
+ "system": self.system,
177
+ "roles": self.roles,
178
+ "messages": self.messages,
179
+ "offset": self.offset,
180
+ "sep": self.sep,
181
+ "sep2": self.sep2,
182
+ }
183
+
184
+
185
+ conv_v1 = Conversation(
186
+ system="A chat between a curious human and an artificial intelligence assistant. "
187
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
188
+ roles=("Human", "Assistant"),
189
+ messages=(
190
+ ("Human", "Give three tips for staying healthy."),
191
+ ("Assistant",
192
+ "Sure, here are three tips for staying healthy:\n"
193
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
194
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
195
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
196
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
197
+ "activities at least two days per week.\n"
198
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
199
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
200
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
201
+ "and aim to drink plenty of water throughout the day.\n"
202
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
203
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
204
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
205
+ "help improve the quality of your sleep.")
206
+ ),
207
+ offset=2,
208
+ sep_style=SeparatorStyle.SINGLE,
209
+ sep="###",
210
+ )
211
+
212
+ conv_v1_2 = Conversation(
213
+ system="A chat between a curious human and an artificial intelligence assistant. "
214
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
215
+ roles=("Human", "Assistant"),
216
+ messages=(
217
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
218
+ ("Assistant",
219
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
220
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
221
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
222
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
223
+ "renewable and non-renewable energy sources:\n"
224
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
225
+ "energy sources are finite and will eventually run out.\n"
226
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
227
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
228
+ "and other negative effects.\n"
229
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
230
+ "have lower operational costs than non-renewable sources.\n"
231
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
232
+ "locations than non-renewable sources.\n"
233
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
234
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
235
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
236
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
237
+ ),
238
+ offset=2,
239
+ sep_style=SeparatorStyle.SINGLE,
240
+ sep="###",
241
+ )
242
+
243
+ conv_vicuna_v1_1 = Conversation(
244
+ system="A chat between a curious user and an artificial intelligence assistant. "
245
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
246
+ roles=("USER", "ASSISTANT"),
247
+ version="v1",
248
+ messages=(),
249
+ offset=0,
250
+ sep_style=SeparatorStyle.TWO,
251
+ sep=" ",
252
+ sep2="</s>",
253
+ )
254
+
255
+ conv_mpt = Conversation(
256
+ system="""system
257
+ - You are a helpful language and vision assistant.
258
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
259
+ - You should follow the instructions carefully and explain your answers in detail.""",
260
+ roles=("user\n", "assistant\n"),
261
+ version="mpt",
262
+ messages=(),
263
+ offset=0,
264
+ sep_style=SeparatorStyle.MPT,
265
+ sep="",
266
+ )
267
+
268
+ conv_mpt_text = Conversation(
269
+ system="""system
270
+ - You are a helpful assistant chatbot trained by MosaicML.
271
+ - You answer questions.
272
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
273
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
274
+ roles=("user\n", "assistant\n"),
275
+ version="mpt",
276
+ messages=(),
277
+ offset=0,
278
+ sep_style=SeparatorStyle.MPT,
279
+ sep="",
280
+ )
281
+
282
+ conv_bair_v1 = Conversation(
283
+ system="BEGINNING OF CONVERSATION:",
284
+ roles=("USER", "GPT"),
285
+ messages=(),
286
+ offset=0,
287
+ sep_style=SeparatorStyle.TWO,
288
+ sep=" ",
289
+ sep2="</s>",
290
+ )
291
+
292
+ simple_conv = Conversation(
293
+ system="A chat between a curious human and an artificial intelligence assistant. "
294
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
295
+ roles=("Human", "Assistant"),
296
+ messages=(
297
+ ("Human", "Hi!"),
298
+ ("Assistant", "Hi there! How can I help you today?")
299
+ ),
300
+ offset=2,
301
+ sep_style=SeparatorStyle.SINGLE,
302
+ sep="###",
303
+ )
304
+
305
+ simple_conv_multimodal = Conversation(
306
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
307
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
308
+ "Follow the instructions carefully and explain your answers in detail.",
309
+ roles=("Human", "Assistant"),
310
+ messages=(
311
+ ("Human", "Hi!"),
312
+ ("Assistant", "Hi there! How can I help you today?\n")
313
+ ),
314
+ offset=2,
315
+ sep_style=SeparatorStyle.SINGLE,
316
+ sep="###",
317
+ )
318
+
319
+ simple_conv_mpt_multimodal = Conversation(
320
+ system="""system
321
+ - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
322
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
323
+ - You should follow the instructions carefully and explain your answers in detail.""",
324
+ roles=("user\n", "assistant\n"),
325
+ version="mpt",
326
+ messages=(),
327
+ offset=0,
328
+ sep_style=SeparatorStyle.MPT,
329
+ sep="",
330
+ )
331
+
332
+ simple_conv_legacy = Conversation(
333
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
334
+ "You are designed to assist human with a variety of tasks using natural language."
335
+ "Follow the instructions carefully.",
336
+ roles=("Human", "Assistant"),
337
+ messages=(
338
+ ("Human", "Hi!\n\n### Response:"),
339
+ ("Assistant", "Hi there! How can I help you today?\n")
340
+ ),
341
+ offset=2,
342
+ sep_style=SeparatorStyle.SINGLE,
343
+ sep="###",
344
+ )
345
+
346
+ conv_llava_v1 = Conversation(
347
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
348
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
349
+ "Follow the instructions carefully and explain your answers in detail.",
350
+ roles=("USER", "ASSISTANT"),
351
+ version="v1",
352
+ messages=(),
353
+ offset=0,
354
+ sep_style=SeparatorStyle.TWO,
355
+ sep=" ",
356
+ sep2="</s>",
357
+ )
358
+
359
+ default_conversation = conv_v1_2
360
+ conv_templates = {
361
+ "default": conv_v1_2,
362
+ "simple": simple_conv,
363
+ "simple_legacy": simple_conv_legacy,
364
+ "multimodal": simple_conv_multimodal,
365
+ "mpt_multimodal": simple_conv_mpt_multimodal,
366
+ "llava_v1": conv_llava_v1,
367
+
368
+ # fastchat
369
+ "v1": conv_v1_2,
370
+ "bair_v1": conv_bair_v1,
371
+ "vicuna_v1_1": conv_vicuna_v1_1,
372
+ "mpt": conv_mpt,
373
+ "mpt_text": conv_mpt_text,
374
+ }
375
+
376
+
377
+ if __name__ == "__main__":
378
+ print(default_conversation.get_prompt())
379
+ #mgie_llava.py:
380
+ from typing import List, Optional, Tuple, Union
381
+
382
+ import torch
383
+ import torch.nn as nn
384
+ import torch.nn.functional as F
385
+ from torch.nn import CrossEntropyLoss
386
+
387
+ from transformers import AutoConfig, AutoModelForCausalLM, \
388
+ LlamaConfig, LlamaModel, LlamaForCausalLM, \
389
+ CLIPVisionModel, CLIPImageProcessor
390
+
391
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
392
+
393
+ import os, diffusers
394
+
395
+ DEFAULT_IMAGE_TOKEN = "<image>"
396
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
397
+ DEFAULT_IM_START_TOKEN = "<im_start>"
398
+ DEFAULT_IM_END_TOKEN = "<im_end>"
399
+
400
+
401
+ class LlavaConfig(LlamaConfig):
402
+ model_type = "llava"
403
+
404
+
405
+ class LlavaLlamaModel(LlamaModel):
406
+ config_class = LlavaConfig
407
+
408
+ def __init__(self, config: LlamaConfig):
409
+ super(LlavaLlamaModel, self).__init__(config)
410
+
411
+ if hasattr(config, "mm_vision_tower"):
412
+ # HACK: for FSDP
413
+ self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
414
+ # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
415
+
416
+ if hasattr(config, "use_mm_proj"):
417
+ self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
418
+
419
+ def get_vision_tower(self):
420
+ vision_tower = getattr(self, 'vision_tower', None)
421
+ if type(vision_tower) is list:
422
+ vision_tower = vision_tower[0]
423
+ return vision_tower
424
+
425
+ def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
426
+ pretrain_mm_mlp_adapter=None, fsdp=None):
427
+ self.config.mm_vision_tower = vision_tower
428
+
429
+ image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
430
+
431
+ if not hasattr(self, 'vision_tower'):
432
+ vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
433
+ else:
434
+ vision_tower = self.vision_tower[0]
435
+ vision_tower.requires_grad_(False)
436
+
437
+ if fsdp is not None and len(fsdp) > 0:
438
+ self.vision_tower = [vision_tower]
439
+ else:
440
+ self.vision_tower = vision_tower
441
+
442
+ vision_config = vision_tower.config
443
+ num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
444
+
445
+ self.config.use_mm_proj = True
446
+ self.config.mm_hidden_size = vision_config.hidden_size
447
+ self.config.mm_vision_select_layer = mm_vision_select_layer
448
+
449
+ if not hasattr(self, 'mm_projector'):
450
+ self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
451
+
452
+ if pretrain_mm_mlp_adapter is not None:
453
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
454
+ self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
455
+
456
+ return dict(
457
+ image_processor=image_processor,
458
+ image_token_len=num_patches,
459
+ vision_config=vision_config
460
+ )
461
+
462
+ def forward(
463
+ self,
464
+ input_ids: torch.LongTensor = None,
465
+ attention_mask: Optional[torch.Tensor] = None,
466
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
467
+ inputs_embeds: Optional[torch.FloatTensor] = None,
468
+ use_cache: Optional[bool] = None,
469
+ output_attentions: Optional[bool] = None,
470
+ output_hidden_states: Optional[bool] = None,
471
+ images: Optional[torch.FloatTensor] = None,
472
+ return_dict: Optional[bool] = None,
473
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
474
+
475
+ # HACK: replace back original embeddings for LLaVA pretraining
476
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
477
+ # if orig_embeds_params is not None:
478
+ # orig_embeds_params = orig_embeds_params[0]
479
+ # with torch.no_grad():
480
+ # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
481
+
482
+ if inputs_embeds is None:
483
+ inputs_embeds = self.embed_tokens(input_ids)
484
+
485
+ vision_tower = self.get_vision_tower()
486
+ if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
487
+ # TODO: this is a modified multimodal LLM -- Haotian Liu
488
+ with torch.no_grad():
489
+ if type(images) is list:
490
+ # variable length images
491
+ image_features = []
492
+ for image in images:
493
+ image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
494
+ select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
495
+ select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
496
+ image_feature = select_hidden_state[:, 1:]
497
+ image_features.append(image_feature)
498
+ else:
499
+ image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
500
+ select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
501
+ select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
502
+ image_features = select_hidden_state[:, 1:].to(images.dtype)
503
+ if type(images) is list:
504
+ image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
505
+ else:
506
+ image_features = self.mm_projector(image_features)
507
+ dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
508
+ dummy_image_features = self.mm_projector(dummy_image_features)
509
 
510
+ new_input_embeds = []
511
+ cur_image_idx = 0
512
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
513
+ if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
514
+ # multimodal LLM, but the current sample is not multimodal
515
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
516
+ new_input_embeds.append(cur_input_embeds)
517
+ cur_image_idx += 1
518
+ continue
519
+ if vision_tower.config.use_im_start_end:
520
+ cur_image_features = image_features[cur_image_idx]
521
+ num_patches = cur_image_features.shape[0]
522
+ if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
523
+ raise ValueError("The number of image start tokens and image end tokens should be the same.")
524
+ image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
525
+ for image_start_token_pos in image_start_tokens:
526
+ cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
527
+ num_patches = cur_image_features.shape[0]
528
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
529
+ raise ValueError("The image end token should follow the image start token.")
530
+ if orig_embeds_params is not None:
531
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
532
+ else:
533
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
534
+ cur_image_idx += 1
535
+ new_input_embeds.append(cur_new_input_embeds)
536
+ else:
537
+ cur_image_features = image_features[cur_image_idx]
538
+ num_patches = cur_image_features.shape[0]
539
+ if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
540
+ raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
541
+ masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
542
+ mask_index_start = masked_indices[0]
543
+ if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
544
+ raise ValueError("The image patch tokens should be consecutive.")
545
+ if orig_embeds_params is not None:
546
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
547
+ else:
548
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
549
+ new_input_embeds.append(cur_new_input_embeds)
550
+ cur_image_idx += 1
551
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
552
+
553
+ return super(LlavaLlamaModel, self).forward(
554
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
555
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
556
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
557
+ return_dict=return_dict
558
+ )
559
+
560
+ class EditMapper(nn.Module):
561
+ def __init__(self):
562
+ super().__init__()
563
+
564
+ self.llm2hid = nn.Linear(4096, 512)
565
+ self.query = nn.Parameter(torch.randn(1, 77, 512))
566
+ self.mapper = nn.Transformer(batch_first=True, norm_first=True,
567
+ d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
568
+ dim_feedforward=2048, dropout=0.0)
569
+ self.hid2feat = nn.Linear(512, 768)
570
+
571
+ def forward(self, llm, emb):
572
+ hid = self.llm2hid(llm+emb)
573
+ hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
574
+ feat = self.hid2feat(hid)
575
+
576
+ return feat
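+ # descriptive note: EditMapper projects the eight 4096-d [IMG] hidden states (plus their token
+ # embeddings) through a small Transformer onto 77 learned queries and maps them to 77x768
+ # features, the shape Stable Diffusion expects from its CLIP text encoder.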
577
+
578
+ class LlavaLlamaForCausalLM(LlamaForCausalLM):
579
+ config_class = LlavaConfig
580
+
581
+ def __init__(self, config):
582
+ super(LlamaForCausalLM, self).__init__(config)
583
+ self.model = LlavaLlamaModel(config)
584
+
585
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
586
+
587
+ self.edit_head = EditMapper()
588
+
589
+ '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
590
+ diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
591
+ diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
592
+ self.vae.requires_grad_(False)
593
+ self.unet.register_to_config(in_channels=8)
594
+ with torch.no_grad():
595
+ conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
596
+ conv.weight.zero_()
597
+ conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
598
+ self.unet.conv_in = conv'''
599
+
600
+ # Initialize weights and apply final processing
601
+ self.post_init()
602
+
603
+ def get_model(self):
604
+ return self.model
605
+
606
+ def get_vision_tower(self):
607
+ return self.get_model().get_vision_tower()
608
+
609
+ def get_vision_tower(self):
610
+ model = self.get_model()
611
+ vision_tower = model.vision_tower
612
+ if type(vision_tower) is list:
613
+ vision_tower = vision_tower[0]
614
+ return vision_tower
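+ # note: this second get_vision_tower definition shadows the one above; both unwrap the FSDP list wrapper.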
615
+
616
+ def forward(
617
+ self,
618
+ input_ids: torch.LongTensor = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
621
+ inputs_embeds: Optional[torch.FloatTensor] = None,
622
+ labels: Optional[torch.LongTensor] = None,
623
+ use_cache: Optional[bool] = None,
624
+ output_attentions: Optional[bool] = None,
625
+ output_hidden_states: Optional[bool] = None,
626
+ images: Optional[torch.FloatTensor] = None,
627
+ return_dict: Optional[bool] = None,
628
+ p2p_inp=None, p2p_ans=None
629
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
630
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
631
+ output_hidden_states = (
632
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
633
+ )
634
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
635
+
636
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
637
+ outputs = self.model(
638
+ input_ids=input_ids,
639
+ attention_mask=attention_mask,
640
+ past_key_values=past_key_values,
641
+ inputs_embeds=inputs_embeds,
642
+ use_cache=use_cache,
643
+ output_attentions=output_attentions,
644
+ output_hidden_states=output_hidden_states,
645
+ return_dict=return_dict,
646
+ images=images
647
+ )
648
+
649
+ hidden_states = outputs[0]
650
+ logits = self.lm_head(hidden_states)
651
+
652
+ loss = None
653
+ if labels is not None:
654
+ # Shift so that tokens < n predict n
655
+ shift_logits = logits[..., :-1, :].contiguous()
656
+ shift_labels = labels[..., 1:].contiguous()
657
+ # Flatten the tokens
658
+ loss_fct = CrossEntropyLoss()
659
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
660
+ shift_labels = shift_labels.view(-1)
661
+ # Enable model/pipeline parallelism
662
+ shift_labels = shift_labels.to(shift_logits.device)
663
+ loss = loss_fct(shift_logits, shift_labels)
664
+
665
+ if labels is not None:
666
+ llm = []
667
+ for i in range(labels.shape[0]):
668
+ try: p = labels[i].data.cpu().tolist().index(32003)-1
669
+ except: p = len(labels[i])-9
670
+ p = min(len(hidden_states[i])-9, p)
671
+ llm.append(hidden_states[i][p:p+8].unsqueeze(0))
672
+ llm = torch.cat(llm, dim=0)
673
+ hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
674
+
675
+ B, DROP = labels.shape[0], 0.05
676
+
677
+ hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
678
+ self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
679
+
680
+ with torch.no_grad():
681
+ lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
682
+ lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
683
+ torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
684
+
685
+ noise = torch.randn_like(lat_ans)
686
+ ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
687
+ lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
688
+
689
+ prob = torch.rand(B, device=lat_ans.device)
690
+ mask = (prob<(DROP*2)).reshape(B, 1, 1)
691
+ hid_edit = torch.where(mask, hid_null, hid_edit)
692
+ mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
693
+ lat_inp *= mask
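+ # descriptive note: this is classifier-free-guidance conditioning dropout with DROP=0.05, so per
+ # batch roughly 5% of samples drop only the text condition, 5% drop both, and 5% drop only the
+ # image latents, mirroring the scheme used by InstructPix2Pix.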
694
+
695
+ out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
696
+
697
+ loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
698
+ if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
699
+ loss = loss_ce+loss_edit*0.5
700
+
701
+ if not return_dict:
702
+ output = (logits,) + outputs[1:]
703
+ return (loss,) + output if loss is not None else output
704
+
705
+ return CausalLMOutputWithPast(
706
+ loss=loss,
707
+ logits=logits,
708
+ past_key_values=outputs.past_key_values,
709
+ hidden_states=outputs.hidden_states,
710
+ attentions=outputs.attentions,
711
+ )
712
+
713
+ def prepare_inputs_for_generation(
714
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
715
+ ):
716
+ if past_key_values:
717
+ input_ids = input_ids[:, -1:]
718
+
719
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
720
+ if inputs_embeds is not None and past_key_values is None:
721
+ model_inputs = {"inputs_embeds": inputs_embeds}
722
+ else:
723
+ model_inputs = {"input_ids": input_ids}
724
+
725
+ model_inputs.update(
726
+ {
727
+ "past_key_values": past_key_values,
728
+ "use_cache": kwargs.get("use_cache"),
729
+ "attention_mask": attention_mask,
730
+ "images": kwargs.get("images", None),
731
+ }
732
+ )
733
+ return model_inputs
734
+
735
+ def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
736
+ tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
737
+ vision_config = self.get_vision_tower().config
738
+ vision_config.use_im_start_end = mm_use_im_start_end
739
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
740
+ self.resize_token_embeddings(len(tokenizer))
741
+
742
+ if mm_use_im_start_end:
743
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
744
+ self.resize_token_embeddings(len(tokenizer))
745
+ vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
746
+
747
+ if num_new_tokens > 0:
748
+ input_embeddings = self.get_input_embeddings().weight.data
749
+ output_embeddings = self.get_output_embeddings().weight.data
750
+
751
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
752
+ dim=0, keepdim=True)
753
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
754
+ dim=0, keepdim=True)
755
+
756
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
757
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
758
+
759
+ if tune_mm_mlp_adapter:
760
+ self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
761
+ for p in self.get_input_embeddings().parameters():
762
+ p.requires_grad = True
763
+ for p in self.get_output_embeddings().parameters():
764
+ p.requires_grad = False
765
+
766
+ if pretrain_mm_mlp_adapter:
767
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
768
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
769
+ assert num_new_tokens == 2
770
+ if input_embeddings.shape == embed_tokens_weight.shape:
771
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
772
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
773
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
774
+ else:
775
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
776
+
777
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
778
+
779
+ AutoConfig.register("llava", LlavaConfig)
780
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
781
+ #main.py:
782
+ from google.colab import drive
783
+ drive.mount('/content/drive')
784
+
785
+ import os
786
+ from PIL import Image
787
  import numpy as np
788
  import torch as T
789
+ import transformers
790
+ import diffusers
791
+ import gradio as gr
792
+ import huggingface_hub
793
+
794
+ CKPT_DIR = '/content/drive/My Drive/_ckpt'
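+ # assumption: the MGIE checkpoints (LLaVA-7B-v1/, mgie_7b/mllm.pt, mgie_7b/unet.pt) have already
+ # been placed in this Google Drive folder; change CKPT_DIR if they live elsewhere.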
795
+
796
797
 
 
798
 
799
  def crop_resize(f, sz=512):
800
  w, h = f.size
801
+ if w > h:
802
+ p = (w - h) // 2
803
+ f = f.crop([p, 0, p + h, h])
804
+ elif h > w:
805
+ p = (h - w) // 2
806
+ f = f.crop([0, p, w, p + w])
807
  f = f.resize([sz, sz])
808
  return f
809
+
810
+ def remove_alter(s):
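+ # strip the chat scaffolding and keep only the first two sentences of the assistant reply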
811
+ if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:') + 10:].strip()
812
  if '</s>' in s: s = s[:s.index('</s>')].strip()
813
  if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]
814
  if '[IMG0]' in s: s = s[:s.index('[IMG0]')]
815
  s = '.'.join([s.strip() for s in s.split('.')[:2]])
816
+ if s[-1] != '.': s += '.'
817
  return s.strip()
818
 
819
  DEFAULT_IMAGE_TOKEN = '<image>'
820
  DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
821
  DEFAULT_IM_START_TOKEN = '<im_start>'
822
  DEFAULT_IM_END_TOKEN = '<im_end>'
823
+ PATH_LLAVA = f'{CKPT_DIR}/LLaVA-7B-v1'
824
 
825
  tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
826
  model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
 
829
  tokenizer.padding_side = 'left'
830
  tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
831
  model.resize_token_embeddings(len(tokenizer))
832
+ ckpt = T.load(f'{CKPT_DIR}/mgie_7b/mllm.pt', map_location='cpu')
833
  model.load_state_dict(ckpt, strict=False)
834
 
835
  mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
 
843
  vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
844
  vision_config.use_im_start_end = mm_use_im_start_end
845
  if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
846
+ image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
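+ # number of <im_patch> placeholders per image; with the default CLIP ViT-L/14 tower this is (224 // 14) ** 2 = 256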
847
 
848
  _ = model.eval()
849
 
850
  pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
851
  pipe.set_progress_bar_config(disable=True)
852
+ pipe.unet.load_state_dict(T.load(f'{CKPT_DIR}/mgie_7b/unet.pt', map_location='cpu'))
853
  print('--init MGIE--')
854
 
 
855
  def go_mgie(img, txt, seed, cfg_txt, cfg_img):
856
  EMB = ckpt['emb'].cuda()
857
  with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
858
+
859
  img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
860
  inp = img
861
 
862
  img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
863
+ txt = "what will this image be like if '%s'" % (txt)
864
+ txt = txt + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
865
  conv = conv_templates['vicuna_v1_1'].copy()
866
  conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
867
  txt = conv.get_prompt()
 
870
 
871
  with T.inference_mode():
872
  _ = model.cuda()
873
+ out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
874
+ do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
875
  return_dict_in_generate=True, output_hidden_states=True)
876
  out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]
877
+
878
+ if 32003 in out: p = out.index(32003) - 1
879
+ else: p = len(hid) - 9
880
+ p = min(p, len(hid) - 9)
881
+ hid = hid[p:p + 8]
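+ # 32003 is assumed to be the token id of '[IMG0]' added above; the eight hidden states spanning the
+ # [IMG0]-[IMG7] positions are what edit_head turns into prompt embeddings (last eight as a fallback).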
882
 
883
  out = remove_alter(tokenizer.decode(out))
884
  _ = model.cuda()
885
  emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
886
+ res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
887
  generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
888
 
889
  return res, out
890
 
 
891
  with gr.Blocks() as app:
892
  gr.Markdown(
893
  """
894
+ # MagiX: Edit Personalized Images using Gen AI by Ateeb Taser
895
  """
896
  )
897
+ with gr.Row():
898
+ inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
899
+ gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
900
+ with gr.Row():
901
+ txt, out = [gr.Textbox(label='Instruction', interactive=True),
902
+ gr.Textbox(label='Expressive Instruction', interactive=False)]
903
+ with gr.Row():
904
+ seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
905
+ gr.Number(value=7.5, label='Text CFG', interactive=True),
906
+ gr.Number(value=1.5, label='Image CFG', interactive=True)]
907
+ with gr.Row():
908
+ btn_sub = gr.Button('Submit')
909
  btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
910
+
911
  app.launch()