DongfuJiang commited on
Commit
8fbc209
·
1 Parent(s): b89f04a
app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ from PIL import Image
4
+ from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
5
+ from typing import List
6
+ processor = MLlavaProcessor()
7
+ model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_v2_4096")
8
+
9
+ @spaces.GPU
10
+ def generate(text:str, images:List[Image.Image], history: List[dict]):
11
+ model = model.to("cuda")
12
+
13
+ for text, history in chat_mllava(text, images, model, processor, history=history, stream=True):
14
+ yield text, history
15
+
16
+ def build_demo():
17
+
18
+
19
+
20
+ if __name__ == "__main__":
21
+ processor = MLlavaProcessor()
22
+ model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_v2_4096")
23
+ demo = build_demo()
24
+ demo.launch()
models/conversation.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+ MFuyu = auto()
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class Conversation:
18
+ """A class that keeps all conversation history."""
19
+ system: str
20
+ roles: List[str]
21
+ messages: List[List[str]]
22
+ offset: int
23
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
24
+ sep: str = "###"
25
+ sep2: str = None
26
+ version: str = "Unknown"
27
+
28
+ skip_next: bool = False
29
+
30
+ def get_prompt(self):
31
+ messages = self.messages
32
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
33
+ messages = self.messages.copy()
34
+ init_role, init_msg = messages[0].copy()
35
+ init_msg = init_msg[0].replace("<image>", "").strip()
36
+ if 'mmtag' in self.version:
37
+ messages[0] = (init_role, init_msg)
38
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
39
+ messages.insert(1, (self.roles[1], "Received."))
40
+ else:
41
+ messages[0] = (init_role, "<image>" + init_msg)
42
+
43
+ if self.sep_style == SeparatorStyle.SINGLE:
44
+ ret = self.system + self.sep
45
+ for role, message in messages:
46
+ if message:
47
+ if type(message) is tuple:
48
+ message, _, _ = message
49
+ ret += role + ": " + message + self.sep
50
+ else:
51
+ ret += role + ":"
52
+ elif self.sep_style == SeparatorStyle.TWO:
53
+ seps = [self.sep, self.sep2]
54
+ ret = self.system + seps[0]
55
+ for i, (role, message) in enumerate(messages):
56
+ if message:
57
+ if type(message) is tuple:
58
+ message, _, _ = message
59
+ ret += role + ": " + message + seps[i % 2]
60
+ else:
61
+ ret += role + ":"
62
+ elif self.sep_style == SeparatorStyle.MPT:
63
+ ret = self.system + self.sep
64
+ for role, message in messages:
65
+ if message:
66
+ if type(message) is tuple:
67
+ message, _, _ = message
68
+ ret += role + message + self.sep
69
+ else:
70
+ ret += role
71
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
72
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
73
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
74
+ ret = ""
75
+
76
+ for i, (role, message) in enumerate(messages):
77
+ if i == 0:
78
+ assert message, "first message should not be none"
79
+ assert role == self.roles[0], "first message should come from user"
80
+ if message:
81
+ if type(message) is tuple:
82
+ message, _, _ = message
83
+ if i == 0: message = wrap_sys(self.system) + message
84
+ if i % 2 == 0:
85
+ message = wrap_inst(message)
86
+ ret += self.sep + message
87
+ else:
88
+ ret += " " + message + " " + self.sep2
89
+ else:
90
+ ret += ""
91
+ ret = ret.lstrip(self.sep)
92
+ elif self.sep_style == SeparatorStyle.MFuyu:
93
+ seps = [self.sep, self.sep2]
94
+ ret = self.system + "\n"
95
+ for i, (role, message) in enumerate(messages):
96
+ if message:
97
+ if type(message) is tuple:
98
+ message, _, _ = message
99
+ ret += role + ": " + message + seps[i % 2]
100
+ else:
101
+ ret += role + ":"
102
+ elif self.sep_style == SeparatorStyle.PLAIN:
103
+ seps = [self.sep, self.sep2]
104
+ ret = self.system
105
+ for i, (role, message) in enumerate(messages):
106
+ if message:
107
+ if type(message) is tuple:
108
+ message, _, _ = message
109
+ ret += message + seps[i % 2]
110
+ else:
111
+ ret += ""
112
+ else:
113
+ raise ValueError(f"Invalid style: {self.sep_style}")
114
+
115
+ return ret
116
+
117
+ def append_message(self, role, message):
118
+ self.messages.append([role, message])
119
+
120
+ def get_images(self, return_pil=False):
121
+ images = []
122
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
123
+ if i % 2 == 0:
124
+ if type(msg) is tuple:
125
+ import base64
126
+ from io import BytesIO
127
+ from PIL import Image
128
+ msg, image, image_process_mode = msg
129
+ if image_process_mode == "Pad":
130
+ def expand2square(pil_img, background_color=(122, 116, 104)):
131
+ width, height = pil_img.size
132
+ if width == height:
133
+ return pil_img
134
+ elif width > height:
135
+ result = Image.new(pil_img.mode, (width, width), background_color)
136
+ result.paste(pil_img, (0, (width - height) // 2))
137
+ return result
138
+ else:
139
+ result = Image.new(pil_img.mode, (height, height), background_color)
140
+ result.paste(pil_img, ((height - width) // 2, 0))
141
+ return result
142
+ image = expand2square(image)
143
+ elif image_process_mode in ["Default", "Crop"]:
144
+ pass
145
+ elif image_process_mode == "Resize":
146
+ image = image.resize((336, 336))
147
+ else:
148
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
149
+ max_hw, min_hw = max(image.size), min(image.size)
150
+ aspect_ratio = max_hw / min_hw
151
+ max_len, min_len = 800, 400
152
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
153
+ longest_edge = int(shortest_edge * aspect_ratio)
154
+ W, H = image.size
155
+ if longest_edge != max(image.size):
156
+ if H > W:
157
+ H, W = longest_edge, shortest_edge
158
+ else:
159
+ H, W = shortest_edge, longest_edge
160
+ image = image.resize((W, H))
161
+ if return_pil:
162
+ images.append(image)
163
+ else:
164
+ buffered = BytesIO()
165
+ image.save(buffered, format="PNG")
166
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
167
+ images.append(img_b64_str)
168
+ return images
169
+
170
+ def to_gradio_chatbot(self):
171
+ ret = []
172
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
173
+ if i % 2 == 0:
174
+ if type(msg) is tuple:
175
+ import base64
176
+ from io import BytesIO
177
+ msg, image, image_process_mode = msg
178
+ max_hw, min_hw = max(image.size), min(image.size)
179
+ aspect_ratio = max_hw / min_hw
180
+ max_len, min_len = 800, 400
181
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
182
+ longest_edge = int(shortest_edge * aspect_ratio)
183
+ W, H = image.size
184
+ if H > W:
185
+ H, W = longest_edge, shortest_edge
186
+ else:
187
+ H, W = shortest_edge, longest_edge
188
+ image = image.resize((W, H))
189
+ buffered = BytesIO()
190
+ image.save(buffered, format="JPEG")
191
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
192
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
193
+ msg = img_str + msg.replace('<image>', '').strip()
194
+ ret.append([msg, None])
195
+ else:
196
+ ret.append([msg, None])
197
+ else:
198
+ ret[-1][-1] = msg
199
+ return ret
200
+
201
+ def copy(self):
202
+ return Conversation(
203
+ system=self.system,
204
+ roles=self.roles,
205
+ messages=[[x, y] for x, y in self.messages],
206
+ offset=self.offset,
207
+ sep_style=self.sep_style,
208
+ sep=self.sep,
209
+ sep2=self.sep2,
210
+ version=self.version)
211
+
212
+ def dict(self):
213
+ if len(self.get_images()) > 0:
214
+ return {
215
+ "system": self.system,
216
+ "roles": self.roles,
217
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
218
+ "offset": self.offset,
219
+ "sep": self.sep,
220
+ "sep2": self.sep2,
221
+ }
222
+ return {
223
+ "system": self.system,
224
+ "roles": self.roles,
225
+ "messages": self.messages,
226
+ "offset": self.offset,
227
+ "sep": self.sep,
228
+ "sep2": self.sep2,
229
+ }
230
+
231
+
232
+ conv_vicuna_v0 = Conversation(
233
+ system="A chat between a curious human and an artificial intelligence assistant. "
234
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
235
+ roles=("Human", "Assistant"),
236
+ messages=(
237
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
238
+ ("Assistant",
239
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
240
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
241
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
242
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
243
+ "renewable and non-renewable energy sources:\n"
244
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
245
+ "energy sources are finite and will eventually run out.\n"
246
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
247
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
248
+ "and other negative effects.\n"
249
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
250
+ "have lower operational costs than non-renewable sources.\n"
251
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
252
+ "locations than non-renewable sources.\n"
253
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
254
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
255
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
256
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
257
+ ),
258
+ offset=2,
259
+ sep_style=SeparatorStyle.SINGLE,
260
+ sep="###",
261
+ )
262
+
263
+ conv_vicuna_v1 = Conversation(
264
+ system="A chat between a curious user and an artificial intelligence assistant. "
265
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
266
+ roles=("USER", "ASSISTANT"),
267
+ version="v1",
268
+ messages=(),
269
+ offset=0,
270
+ sep_style=SeparatorStyle.TWO,
271
+ sep=" ",
272
+ sep2="</s>",
273
+ )
274
+
275
+ conv_llama_2 = Conversation(
276
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
277
+
278
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
279
+ roles=("USER", "ASSISTANT"),
280
+ version="llama_v2",
281
+ messages=(),
282
+ offset=0,
283
+ sep_style=SeparatorStyle.LLAMA_2,
284
+ sep="<s>",
285
+ sep2="</s>",
286
+ )
287
+
288
+ conv_llava_llama_2 = Conversation(
289
+ system="You are a helpful language and vision assistant. "
290
+ "You are able to understand the visual content that the user provides, "
291
+ "and assist the user with a variety of tasks using natural language.",
292
+ roles=("USER", "ASSISTANT"),
293
+ version="llama_v2",
294
+ messages=(),
295
+ offset=0,
296
+ sep_style=SeparatorStyle.LLAMA_2,
297
+ sep="<s>",
298
+ sep2="</s>",
299
+ )
300
+
301
+ conv_mpt = Conversation(
302
+ system="""<|im_start|>system
303
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
304
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
305
+ version="mpt",
306
+ messages=(),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.MPT,
309
+ sep="<|im_end|>",
310
+ )
311
+
312
+ conv_llava_plain = Conversation(
313
+ system="",
314
+ roles=("", ""),
315
+ messages=(
316
+ ),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.PLAIN,
319
+ sep="\n",
320
+ )
321
+
322
+ conv_llava_v0 = Conversation(
323
+ system="A chat between a curious human and an artificial intelligence assistant. "
324
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
325
+ roles=("Human", "Assistant"),
326
+ messages=(
327
+ ),
328
+ offset=0,
329
+ sep_style=SeparatorStyle.SINGLE,
330
+ sep="###",
331
+ )
332
+
333
+ conv_llava_v0_mmtag = Conversation(
334
+ system="A chat between a curious user and an artificial intelligence assistant. "
335
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
336
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
337
+ roles=("Human", "Assistant"),
338
+ messages=(
339
+ ),
340
+ offset=0,
341
+ sep_style=SeparatorStyle.SINGLE,
342
+ sep="###",
343
+ version="v0_mmtag",
344
+ )
345
+
346
+ conv_llava_v1 = Conversation(
347
+ system="A chat between a curious human and an artificial intelligence assistant. "
348
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
349
+ roles=("USER", "ASSISTANT"),
350
+ version="v1",
351
+ messages=(),
352
+ offset=0,
353
+ sep_style=SeparatorStyle.TWO,
354
+ sep=" ",
355
+ sep2="</s>",
356
+ )
357
+
358
+ conv_llava_v1_mmtag = Conversation(
359
+ system="A chat between a curious user and an artificial intelligence assistant. "
360
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
361
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
362
+ roles=("USER", "ASSISTANT"),
363
+ messages=(),
364
+ offset=0,
365
+ sep_style=SeparatorStyle.TWO,
366
+ sep=" ",
367
+ sep2="</s>",
368
+ version="v1_mmtag",
369
+ )
370
+
371
+ conv_mfuyu_v1 = Conversation(
372
+ system="You are a helpful language and vision assistant. "
373
+ "You are able to understand the visual content that the user provides, "
374
+ "and assist the user with a variety of tasks using natural language.",
375
+ roles=("USER", "ASSISTANT"),
376
+ version="v1",
377
+ messages=(),
378
+ offset=0,
379
+ sep_style=SeparatorStyle.MFuyu,
380
+ sep="<0x04>", # begin of answer token
381
+ sep2="|ENDOFTEXT|",
382
+ ) # copied from conv_vicuna_v1
383
+
384
+ conv_mllava_v1_mmtag = Conversation(
385
+ system="A chat between a curious user and an artificial intelligence assistant. "
386
+ "The assistant is able to understand the multiple visual contents that the user provides, and assist the user with a variety of tasks using natural language."
387
+ "Each visual content will be provided with the following format: <Image>visual content</Image>.",
388
+ roles=("USER", "ASSISTANT"),
389
+ messages=(),
390
+ offset=0,
391
+ sep_style=SeparatorStyle.SINGLE,
392
+ sep="</s>",
393
+ version="v1_mmtag",
394
+ )
395
+
396
+
397
+ default_conversation = conv_mfuyu_v1
398
+ conv_templates = {
399
+ "default": conv_vicuna_v0,
400
+ "v0": conv_vicuna_v0,
401
+ "v1": conv_vicuna_v1,
402
+ "vicuna_v1": conv_vicuna_v1,
403
+ "llama_2": conv_llama_2,
404
+
405
+ "plain": conv_llava_plain,
406
+ "v0_plain": conv_llava_plain,
407
+ "llava_v0": conv_llava_v0,
408
+ "v0_mmtag": conv_llava_v0_mmtag,
409
+ "llava_v1": conv_llava_v1,
410
+ "v1_mmtag": conv_llava_v1_mmtag,
411
+ "llava_llama_2": conv_llava_llama_2,
412
+
413
+ "mpt": conv_mpt,
414
+ }
415
+
416
+
417
+ if __name__ == "__main__":
418
+ print(default_conversation.get_prompt())
models/mllava/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
2
+ from .processing_llava import MLlavaProcessor
3
+ from .utils import chat_mllava
models/mllava/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (380 Bytes). View file
 
models/mllava/__pycache__/configuration_llava.cpython-39.pyc ADDED
Binary file (4.26 kB). View file
 
models/mllava/__pycache__/modeling_llava.cpython-39.pyc ADDED
Binary file (22.8 kB). View file
 
models/mllava/__pycache__/processing_llava.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
models/mllava/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.34 kB). View file
 
models/mllava/configuration_llava.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ Llava model configuration"""
15
+
16
+
17
+ # from ...configuration_utils import PretrainedConfig
18
+ # from ...utils import logging
19
+ # from ..auto import CONFIG_MAPPING
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+ from transformers.models.auto import CONFIG_MAPPING
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ class LlavaConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
35
+ Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
+ with the defaults will yield a similar configuration to that of the Llava-9B.
37
+
38
+ e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ vision_config (`LlavaVisionConfig`, *optional*):
45
+ Custom vision config or dict
46
+ text_config (`Union[AutoConfig, dict]`, *optional*):
47
+ The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
48
+ ignore_index (`int`, *optional*, defaults to -100):
49
+ The ignore index for the loss function.
50
+ image_token_index (`int`, *optional*, defaults to 32000):
51
+ The image token index to encode the image prompt.
52
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
53
+ The activation function used by the multimodal projector.
54
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
55
+ The feature selection strategy used to select the vision feature from the CLIP backbone.
56
+ vision_feature_layer (`int`, *optional*, defaults to -2):
57
+ The index of the layer to select the vision feature.
58
+ vocab_size (`int`, *optional*, defaults to 32000):
59
+ Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
60
+ `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]
61
+
62
+ Example:
63
+
64
+ ```python
65
+ >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
66
+
67
+ >>> # Initializing a CLIP-vision config
68
+ >>> vision_config = CLIPVisionConfig()
69
+
70
+ >>> # Initializing a Llama config
71
+ >>> text_config = LlamaConfig()
72
+
73
+ >>> # Initializing a Llava llava-1.5-7b style configuration
74
+ >>> configuration = LlavaConfig(vision_config, text_config)
75
+
76
+ >>> # Initializing a model from the llava-1.5-7b style configuration
77
+ >>> model = LlavaForConditionalGeneration(configuration)
78
+
79
+ >>> # Accessing the model configuration
80
+ >>> configuration = model.config
81
+ ```"""
82
+
83
+ model_type = "llava"
84
+ is_composition = False
85
+
86
+ def __init__(
87
+ self,
88
+ vision_config=None,
89
+ text_config=None,
90
+ ignore_index=-100,
91
+ image_token_index=32000,
92
+ projector_hidden_act="gelu",
93
+ vision_feature_select_strategy="default",
94
+ vision_feature_layer=-2,
95
+ vocab_size=32000,
96
+ **kwargs,
97
+ ):
98
+ self.ignore_index = ignore_index
99
+ self.image_token_index = image_token_index
100
+ self.projector_hidden_act = projector_hidden_act
101
+ self.vision_feature_select_strategy = vision_feature_select_strategy
102
+ self.vision_feature_layer = vision_feature_layer
103
+ self.vocab_size = vocab_size
104
+
105
+ self.vision_config = vision_config
106
+
107
+ if isinstance(self.vision_config, dict):
108
+ vision_config["model_type"] = (
109
+ vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
110
+ )
111
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
112
+ elif vision_config is None:
113
+ self.vision_config = CONFIG_MAPPING["clip_vision_model"](
114
+ intermediate_size=4096,
115
+ hidden_size=1024,
116
+ patch_size=14,
117
+ image_size=336,
118
+ num_hidden_layers=24,
119
+ num_attention_heads=16,
120
+ vocab_size=32000,
121
+ projection_dim=768,
122
+ )
123
+ self.vocab_size = self.vocab_size
124
+
125
+ self.text_config = text_config
126
+
127
+ if isinstance(self.text_config, dict):
128
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
129
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
130
+ self.vocab_size = self.text_config.vocab_size
131
+ elif text_config is None:
132
+ self.text_config = CONFIG_MAPPING["llama"]()
133
+
134
+ super().__init__(**kwargs)
models/mllava/modeling_llava.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Llava model."""
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ # from ... import PreTrainedModel
24
+ # from ...activations import ACT2FN
25
+ # from ...cache_utils import Cache
26
+ # from ...modeling_outputs import ModelOutput
27
+ # from ...utils import (
28
+ # add_start_docstrings,
29
+ # add_start_docstrings_to_model_forward,
30
+ # logging,
31
+ # replace_return_docstrings,
32
+ # )
33
+ # from ..auto import AutoModel, AutoModelForCausalLM
34
+
35
+ from .configuration_llava import LlavaConfig
36
+
37
+ from transformers import PreTrainedModel
38
+ from transformers.activations import ACT2FN
39
+ from transformers.cache_utils import Cache
40
+ from transformers.modeling_outputs import ModelOutput
41
+ from transformers.utils import (
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
48
+ from .configuration_llava import LlavaConfig
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CONFIG_FOR_DOC = "LlavaConfig"
54
+
55
+ LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
56
+ "llava-hf/llava-1.5-7b-hf",
57
+ "llava-hf/llava-1.5-13b-hf",
58
+ "llava-hf/bakLlava-v1-hf",
59
+ # See all Llava models at https://huggingface.co/models?filter=llava
60
+ ]
61
+
62
+
63
+ @dataclass
64
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
65
+ class LlavaCausalLMOutputWithPast(ModelOutput):
66
+ """
67
+ Base class for Llava causal language model (or autoregressive) outputs.
68
+
69
+ Args:
70
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
71
+ Language modeling loss (for next-token prediction).
72
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
73
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
74
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
75
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
76
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
77
+
78
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
79
+ `past_key_values` input) to speed up sequential decoding.
80
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
81
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
82
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
83
+
84
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
85
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
86
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
87
+ sequence_length)`.
88
+
89
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
90
+ heads.
91
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
92
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
93
+ sequence_length, hidden_size)`.
94
+
95
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
96
+ """
97
+
98
+ loss: Optional[torch.FloatTensor] = None
99
+ logits: torch.FloatTensor = None
100
+ past_key_values: Optional[List[torch.FloatTensor]] = None
101
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
102
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
103
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
104
+
105
+
106
+ class LlavaMultiModalProjector(nn.Module):
107
+ def __init__(self, config: LlavaConfig):
108
+ super().__init__()
109
+
110
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
111
+ self.act = ACT2FN[config.projector_hidden_act]
112
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
113
+
114
+ def forward(self, image_features):
115
+ hidden_states = self.linear_1(image_features)
116
+ hidden_states = self.act(hidden_states)
117
+ hidden_states = self.linear_2(hidden_states)
118
+ return hidden_states
119
+
120
+
121
+ LLAVA_START_DOCSTRING = r"""
122
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
123
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
124
+ etc.)
125
+
126
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
127
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
128
+ and behavior.
129
+
130
+ Parameters:
131
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
132
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
133
+ load the weights associated with the model, only the configuration. Check out the
134
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
135
+ """
136
+
137
+
138
+ @add_start_docstrings(
139
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
140
+ LLAVA_START_DOCSTRING,
141
+ )
142
+ class LlavaPreTrainedModel(PreTrainedModel):
143
+ config_class = LlavaConfig
144
+ base_model_prefix = "model"
145
+ supports_gradient_checkpointing = True
146
+ _no_split_modules = ["LlavaVisionAttention"]
147
+ _skip_keys_device_placement = "past_key_values"
148
+ _supports_flash_attn_2 = True
149
+
150
+ def _init_weights(self, module):
151
+ # important: this ported version of Llava isn't meant for training from scratch - only
152
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
153
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
154
+ std = (
155
+ self.config.initializer_range
156
+ if hasattr(self.config, "initializer_range")
157
+ else self.config.text_config.initializer_range
158
+ )
159
+
160
+ if hasattr(module, "class_embedding"):
161
+ module.class_embedding.data.normal_(mean=0.0, std=std)
162
+
163
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
164
+ module.weight.data.normal_(mean=0.0, std=std)
165
+ if module.bias is not None:
166
+ module.bias.data.zero_()
167
+ elif isinstance(module, nn.Embedding):
168
+ module.weight.data.normal_(mean=0.0, std=std)
169
+ if module.padding_idx is not None:
170
+ module.weight.data[module.padding_idx].zero_()
171
+
172
+ @property
173
+ def _supports_sdpa(self):
174
+ """
175
+ Retrieve language_model's attribute to check whether the model supports
176
+ SDPA or not.
177
+ """
178
+ return self.language_model._supports_sdpa
179
+
180
+
181
+ LLAVA_INPUTS_DOCSTRING = r"""
182
+ Args:
183
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
184
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
185
+ it.
186
+
187
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
188
+ [`PreTrainedTokenizer.__call__`] for details.
189
+
190
+ [What are input IDs?](../glossary#input-ids)
191
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
192
+ The tensors corresponding to the input images. Pixel values can be obtained using
193
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
194
+ [`CLIPImageProcessor`] for processing images).
195
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
196
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
197
+
198
+ - 1 for tokens that are **not masked**,
199
+ - 0 for tokens that are **masked**.
200
+
201
+ [What are attention masks?](../glossary#attention-mask)
202
+
203
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
204
+ [`PreTrainedTokenizer.__call__`] for details.
205
+
206
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
207
+ `past_key_values`).
208
+
209
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
210
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
211
+ information on the default strategy.
212
+
213
+ - 1 indicates the head is **not masked**,
214
+ - 0 indicates the head is **masked**.
215
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
216
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
217
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
218
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
219
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
220
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
221
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
222
+
223
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
224
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
225
+
226
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
227
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
228
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
229
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
230
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
231
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
232
+ model's internal embedding lookup matrix.
233
+ use_cache (`bool`, *optional*):
234
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
235
+ `past_key_values`).
236
+ output_attentions (`bool`, *optional*):
237
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
238
+ tensors for more detail.
239
+ output_hidden_states (`bool`, *optional*):
240
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
241
+ more detail.
242
+ return_dict (`bool`, *optional*):
243
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
244
+ """
245
+
246
+
247
+ @add_start_docstrings(
248
+ """The LLAVA model which consists of a vision backbone and a language model.""",
249
+ LLAVA_START_DOCSTRING,
250
+ )
251
+ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
252
+ def __init__(self, config: LlavaConfig):
253
+ super().__init__(config)
254
+ self.vision_tower = AutoModel.from_config(config.vision_config)
255
+
256
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
257
+ self.vocab_size = config.vocab_size
258
+ self.language_model = AutoModelForCausalLM.from_config(
259
+ config.text_config, attn_implementation=config._attn_implementation
260
+ )
261
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
262
+ self.post_init()
263
+
264
+ def get_input_embeddings(self):
265
+ return self.language_model.get_input_embeddings()
266
+
267
+ def set_input_embeddings(self, value):
268
+ self.language_model.set_input_embeddings(value)
269
+
270
+ def get_output_embeddings(self):
271
+ return self.language_model.get_output_embeddings()
272
+
273
+ def set_output_embeddings(self, new_embeddings):
274
+ self.language_model.set_output_embeddings(new_embeddings)
275
+
276
+ def set_decoder(self, decoder):
277
+ self.language_model.set_decoder(decoder)
278
+
279
+ def get_decoder(self):
280
+ return self.language_model.get_decoder()
281
+
282
+ def tie_weights(self):
283
+ return self.language_model.tie_weights()
284
+
285
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
286
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
287
+ # update vocab size
288
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
289
+ self.config.vocab_size = model_embeds.num_embeddings
290
+ self.vocab_size = model_embeds.num_embeddings
291
+ return model_embeds
292
+
293
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
294
+ num_images, num_image_patches, embed_dim = image_features.shape
295
+ batch_size, sequence_length = input_ids.shape
296
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
297
+ # 1. Create a mask to know where special image tokens are
298
+ special_image_token_mask = input_ids == self.config.image_token_index
299
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
300
+ # Compute the maximum embed dimension
301
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
302
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
303
+
304
+ # 2. Compute the positions where text should be written
305
+ # Calculate new positions for text tokens in merged image-text sequence.
306
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
307
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
308
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
309
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
310
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
311
+ if left_padding:
312
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
313
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
314
+
315
+ # 3. Create the full embedding, already padded to the maximum position
316
+ final_embedding = torch.zeros(
317
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
318
+ )
319
+ final_attention_mask = torch.zeros(
320
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
321
+ )
322
+ if labels is not None:
323
+ final_labels = torch.full(
324
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
325
+ )
326
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
327
+ # set the corresponding tensors into their correct target device.
328
+ target_device = inputs_embeds.device
329
+ batch_indices, non_image_indices, text_to_overwrite = (
330
+ batch_indices.to(target_device),
331
+ non_image_indices.to(target_device),
332
+ text_to_overwrite.to(target_device),
333
+ )
334
+ attention_mask = attention_mask.to(target_device)
335
+
336
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
337
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
338
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
339
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
340
+ if labels is not None:
341
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
342
+
343
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
344
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
345
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
346
+
347
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
348
+ raise ValueError(
349
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
350
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
351
+ )
352
+
353
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
354
+ final_attention_mask |= image_to_overwrite
355
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
356
+
357
+ if labels is None:
358
+ final_labels = None
359
+
360
+ return final_embedding, final_attention_mask, final_labels, position_ids
361
+
362
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
363
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
364
+ def forward(
365
+ self,
366
+ input_ids: torch.LongTensor = None,
367
+ pixel_values: torch.FloatTensor = None,
368
+ attention_mask: Optional[torch.Tensor] = None,
369
+ position_ids: Optional[torch.LongTensor] = None,
370
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
371
+ inputs_embeds: Optional[torch.FloatTensor] = None,
372
+ vision_feature_layer: Optional[int] = None,
373
+ vision_feature_select_strategy: Optional[str] = None,
374
+ labels: Optional[torch.LongTensor] = None,
375
+ use_cache: Optional[bool] = None,
376
+ output_attentions: Optional[bool] = None,
377
+ output_hidden_states: Optional[bool] = None,
378
+ return_dict: Optional[bool] = None,
379
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
380
+ r"""
381
+ Args:
382
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
383
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
384
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
385
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
386
+
387
+ Returns:
388
+
389
+ Example:
390
+
391
+ ```python
392
+ >>> from PIL import Image
393
+ >>> import requests
394
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
395
+
396
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
397
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
398
+
399
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
400
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
401
+ >>> image = Image.open(requests.get(url, stream=True).raw)
402
+
403
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
404
+
405
+ >>> # Generate
406
+ >>> generate_ids = model.generate(**inputs, max_length=30)
407
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
408
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
409
+ ```"""
410
+
411
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
412
+ output_hidden_states = (
413
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
414
+ )
415
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
416
+ vision_feature_layer = (
417
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
418
+ )
419
+ vision_feature_select_strategy = (
420
+ vision_feature_select_strategy
421
+ if vision_feature_select_strategy is not None
422
+ else self.config.vision_feature_select_strategy
423
+ )
424
+
425
+ if inputs_embeds is None:
426
+ # 1. Extra the input embeddings
427
+ inputs_embeds = self.get_input_embeddings()(input_ids)
428
+
429
+ # 2. Merge text and images
430
+ if pixel_values is not None and input_ids.shape[1] != 1:
431
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
432
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
433
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
434
+
435
+ if vision_feature_select_strategy == "default":
436
+ selected_image_feature = selected_image_feature[:, 1:]
437
+ elif vision_feature_select_strategy == "full":
438
+ selected_image_feature = selected_image_feature
439
+ else:
440
+ raise ValueError(
441
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
442
+ )
443
+
444
+ image_features = self.multi_modal_projector(selected_image_feature)
445
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
446
+ image_features, inputs_embeds, input_ids, attention_mask, labels
447
+ )
448
+ if labels is None:
449
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
450
+ else:
451
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
452
+ # generation with cache
453
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
454
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
455
+ # that are set to 0
456
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
457
+
458
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
459
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
460
+
461
+ # Get the target length
462
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
463
+
464
+ extended_attention_mask = torch.ones(
465
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
466
+ dtype=attention_mask.dtype,
467
+ device=attention_mask.device,
468
+ )
469
+
470
+ # Filter out only the tokens that can be un-attended, this can happen
471
+ # if one uses Llava + Fused modules where the cache on the
472
+ # first iteration is already big enough, or if one passes custom cache
473
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
474
+ new_batch_index = batch_index[valid_indices]
475
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
476
+
477
+ # Zero-out the places where we don't need to attend
478
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
479
+
480
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
481
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
482
+
483
+ outputs = self.language_model(
484
+ attention_mask=attention_mask,
485
+ position_ids=position_ids,
486
+ past_key_values=past_key_values,
487
+ inputs_embeds=inputs_embeds,
488
+ use_cache=use_cache,
489
+ output_attentions=output_attentions,
490
+ output_hidden_states=output_hidden_states,
491
+ return_dict=return_dict,
492
+ )
493
+
494
+ logits = outputs[0]
495
+
496
+ loss = None
497
+ if labels is not None:
498
+ # Shift so that tokens < n predict n
499
+ if attention_mask is not None:
500
+ shift_attention_mask = attention_mask[..., 1:]
501
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
502
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
503
+ else:
504
+ shift_logits = logits[..., :-1, :].contiguous()
505
+ shift_labels = labels[..., 1:].contiguous()
506
+ # Flatten the tokens
507
+ loss_fct = nn.CrossEntropyLoss()
508
+ loss = loss_fct(
509
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
510
+ )
511
+
512
+ if not return_dict:
513
+ output = (logits,) + outputs[1:]
514
+ return (loss,) + output if loss is not None else output
515
+
516
+ return LlavaCausalLMOutputWithPast(
517
+ loss=loss,
518
+ logits=logits,
519
+ past_key_values=outputs.past_key_values,
520
+ hidden_states=outputs.hidden_states,
521
+ attentions=outputs.attentions,
522
+ )
523
+
524
+ def prepare_inputs_for_generation(
525
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
526
+ ):
527
+ if past_key_values is not None:
528
+ if isinstance(past_key_values, Cache):
529
+ cache_length = past_key_values.get_seq_length()
530
+ past_length = past_key_values.seen_tokens
531
+ else:
532
+ cache_length = past_length = past_key_values[0][0].shape[2]
533
+
534
+ # Keep only the unprocessed tokens:
535
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
536
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
537
+ # input)
538
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
539
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
540
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
541
+ # input_ids based on the past_length.
542
+ elif past_length < input_ids.shape[1]:
543
+ input_ids = input_ids[:, past_length:]
544
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
545
+ elif self.config.image_token_index in input_ids:
546
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
547
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
548
+ # older attention values, as their corresponding values are not part of the input.
549
+ if cache_length < past_length and attention_mask is not None:
550
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
551
+
552
+ position_ids = kwargs.get("position_ids", None)
553
+ if attention_mask is not None and position_ids is None:
554
+ # create position_ids on the fly for batch generation
555
+ position_ids = attention_mask.long().cumsum(-1) - 1
556
+ position_ids.masked_fill_(attention_mask == 0, 1)
557
+ if past_key_values:
558
+ position_ids = position_ids[:, -input_ids.shape[1] :]
559
+
560
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
561
+ if inputs_embeds is not None and past_key_values is None:
562
+ model_inputs = {"inputs_embeds": inputs_embeds}
563
+ else:
564
+ model_inputs = {"input_ids": input_ids}
565
+
566
+ model_inputs.update(
567
+ {
568
+ "position_ids": position_ids,
569
+ "past_key_values": past_key_values,
570
+ "use_cache": kwargs.get("use_cache"),
571
+ "attention_mask": attention_mask,
572
+ "pixel_values": pixel_values,
573
+ }
574
+ )
575
+ return model_inputs
576
+
577
+ def _reorder_cache(self, *args, **kwargs):
578
+ return self.language_model._reorder_cache(*args, **kwargs)
579
+
580
+
581
+
582
+
583
+ from transformers.models.clip.modeling_clip import CLIPEncoderLayer, CLIPEncoder
584
+ @add_start_docstrings(
585
+ """The MLLAVA model which consists of a vision backbone and a language model.""",
586
+ LLAVA_START_DOCSTRING,
587
+ )
588
+ class MLlavaForConditionalGeneration(LlavaForConditionalGeneration):
589
+ def __init__(self, config: LlavaConfig):
590
+ super().__init__(config)
591
+ config.vision_config.type_vocab_size = 144
592
+ self.image_type_embeddings = nn.Embedding(config.vision_config.type_vocab_size, config.vision_config.hidden_size)
593
+ # self.vision_xatten_layers = nn.ModuleList([CLIPEncoderLayer(config.vision_config) for _ in range(config.vision_config.num_hidden_layers)])
594
+ self.vision_xatten_layers = CLIPEncoder(config.vision_config)
595
+
596
+
597
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
598
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
599
+ def forward(
600
+ self,
601
+ input_ids: torch.LongTensor = None,
602
+ pixel_values: torch.FloatTensor = None,
603
+ attention_mask: Optional[torch.Tensor] = None,
604
+ position_ids: Optional[torch.LongTensor] = None,
605
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
606
+ inputs_embeds: Optional[torch.FloatTensor] = None,
607
+ vision_feature_layer: Optional[int] = None,
608
+ vision_feature_select_strategy: Optional[str] = None,
609
+ labels: Optional[torch.LongTensor] = None,
610
+ use_cache: Optional[bool] = None,
611
+ output_attentions: Optional[bool] = None,
612
+ output_hidden_states: Optional[bool] = None,
613
+ return_dict: Optional[bool] = None,
614
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
615
+ r"""
616
+ Args:
617
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
618
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
619
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
620
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
621
+
622
+ Returns:
623
+
624
+ Example:
625
+
626
+ ```python
627
+ >>> from PIL import Image
628
+ >>> import requests
629
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
630
+
631
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
632
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
633
+
634
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
635
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
636
+ >>> image = Image.open(requests.get(url, stream=True).raw)
637
+
638
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
639
+
640
+ >>> # Generate
641
+ >>> generate_ids = model.generate(**inputs, max_length=30)
642
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
643
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
644
+ ```"""
645
+
646
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
647
+ output_hidden_states = (
648
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
649
+ )
650
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
651
+ vision_feature_layer = (
652
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
653
+ )
654
+ vision_feature_select_strategy = (
655
+ vision_feature_select_strategy
656
+ if vision_feature_select_strategy is not None
657
+ else self.config.vision_feature_select_strategy
658
+ )
659
+
660
+ if inputs_embeds is None:
661
+ # 1. Extra the input embeddings
662
+ inputs_embeds = self.get_input_embeddings()(input_ids)
663
+
664
+ # 2. Merge text and images
665
+ if pixel_values is not None and input_ids.shape[1] != 1:
666
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
667
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
668
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
669
+
670
+ if vision_feature_select_strategy == "default":
671
+ selected_image_feature = selected_image_feature[:, 1:]
672
+ elif vision_feature_select_strategy == "full":
673
+ selected_image_feature = selected_image_feature
674
+ else:
675
+ raise ValueError(
676
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
677
+ )
678
+
679
+ # added by Dongfu
680
+ num_images, num_image_patches, embed_dim = selected_image_feature.shape
681
+ image_type_embeddings = self.image_type_embeddings(torch.arange(num_images, device=selected_image_feature.device))
682
+ selected_image_feature += image_type_embeddings.unsqueeze(1)
683
+ xatten_output = self.vision_xatten_layers(selected_image_feature, attention_mask=None, causal_attention_mask=None)
684
+ selected_image_feature = xatten_output[0]
685
+ # end of added by Dongfu
686
+
687
+ image_features = self.multi_modal_projector(selected_image_feature)
688
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
689
+ image_features, inputs_embeds, input_ids, attention_mask, labels
690
+ )
691
+ if labels is None:
692
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
693
+ else:
694
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
695
+ # generation with cache
696
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
697
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
698
+ # that are set to 0
699
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
700
+
701
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
702
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
703
+
704
+ # Get the target length
705
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
706
+
707
+ extended_attention_mask = torch.ones(
708
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
709
+ dtype=attention_mask.dtype,
710
+ device=attention_mask.device,
711
+ )
712
+
713
+ # Filter out only the tokens that can be un-attended, this can happen
714
+ # if one uses Llava + Fused modules where the cache on the
715
+ # first iteration is already big enough, or if one passes custom cache
716
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
717
+ new_batch_index = batch_index[valid_indices]
718
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
719
+
720
+ # Zero-out the places where we don't need to attend
721
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
722
+
723
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
724
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
725
+
726
+ outputs = self.language_model(
727
+ attention_mask=attention_mask,
728
+ position_ids=position_ids,
729
+ past_key_values=past_key_values,
730
+ inputs_embeds=inputs_embeds,
731
+ use_cache=use_cache,
732
+ output_attentions=output_attentions,
733
+ output_hidden_states=output_hidden_states,
734
+ return_dict=return_dict,
735
+ )
736
+
737
+ logits = outputs[0]
738
+
739
+ loss = None
740
+ if labels is not None:
741
+ # Shift so that tokens < n predict n
742
+ if attention_mask is not None:
743
+ shift_attention_mask = attention_mask[..., 1:]
744
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
745
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
746
+ else:
747
+ shift_logits = logits[..., :-1, :].contiguous()
748
+ shift_labels = labels[..., 1:].contiguous()
749
+ # Flatten the tokens
750
+ loss_fct = nn.CrossEntropyLoss()
751
+ loss = loss_fct(
752
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
753
+ )
754
+
755
+ if not return_dict:
756
+ output = (logits,) + outputs[1:]
757
+ return (loss,) + output if loss is not None else output
758
+
759
+ return LlavaCausalLMOutputWithPast(
760
+ loss=loss,
761
+ logits=logits,
762
+ past_key_values=outputs.past_key_values,
763
+ hidden_states=outputs.hidden_states,
764
+ attentions=outputs.attentions,
765
+ )
models/mllava/processing_llava.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Llava.
17
+ """
18
+
19
+
20
+ from typing import List, Optional, Union, Dict
21
+
22
+ # from ...feature_extraction_utils import BatchFeature
23
+ # from ...image_utils import ImageInput
24
+ # from ...processing_utils import ProcessorMixin
25
+ # from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
26
+ # from ...utils import TensorType
27
+
28
+ from transformers.feature_extraction_sequence_utils import BatchFeature
29
+ from transformers.image_utils import ImageInput
30
+ from transformers.processing_utils import ProcessorMixin
31
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
32
+ from transformers.utils import TensorType
33
+
34
+ from PIL import Image
35
+ import logging
36
+ import torch
37
+ import numpy as np
38
+ logger = logging.getLogger(__name__)
39
+
40
+ class MLlavaProcessor(ProcessorMixin):
41
+ r"""
42
+ Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
43
+
44
+ [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
45
+ [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
46
+
47
+ Args:
48
+ image_processor ([`CLIPImageProcessor`], *optional*):
49
+ The image processor is a required input.
50
+ tokenizer ([`LlamaTokenizerFast`], *optional*):
51
+ The tokenizer is a required input.
52
+ """
53
+
54
+ attributes = ["image_processor", "tokenizer"]
55
+ image_processor_class = "CLIPImageProcessor"
56
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
57
+
58
+ def __init__(self, image_processor=None, tokenizer=None):
59
+ super().__init__(image_processor, tokenizer)
60
+
61
+ def preprocess_interleaved_images_and_text(
62
+ self,
63
+ text,
64
+ images=None,
65
+ ):
66
+ """
67
+ Args:
68
+ text (`str`, `List[str]`):
69
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
70
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
71
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
72
+ text can contain <image> tokens as the placeholder for the image(s) to be inserted.
73
+ images (`PIL.Image.Image`, `List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`):
74
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
75
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
76
+ number of channels, H and W are image height and width.
77
+ the number of the images should match the number of <image> tokens in the text.
78
+
79
+ """
80
+ assert text is not None, "text cannot be None."
81
+
82
+ if images is not None:
83
+ if isinstance(images, Image.Image):
84
+ images = [images]
85
+ if isinstance(images, list) and isinstance(images[0], Image.Image):
86
+ if isinstance(text, str):
87
+ images = [images]
88
+ elif isinstance(text, list):
89
+ if len(text) != len(images):
90
+ raise ValueError("Invalid input text. Number of texts does not match number of images.")
91
+ images = [[image] for image in images]
92
+ if isinstance(text, str):
93
+ num_images = len(images[0])
94
+ num_image_tokens = text.count("<image>")
95
+ if num_image_tokens < num_images:
96
+ # prepend empty image tokens to text
97
+ if "USER:" in text:
98
+ text = text.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
99
+ elif "Human:" in text:
100
+ text = text.replace("Human:", "Human:" + "<image>" * (num_images - num_image_tokens), 1)
101
+ elif "HUMAN:" in text:
102
+ text = text.replace("HUMAN:", "HUMAN:" + "<image>" * (num_images - num_image_tokens), 1)
103
+ else:
104
+ text = "<image>" * (num_images - num_image_tokens) + text
105
+ # logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
106
+ elif num_image_tokens > num_images:
107
+ text = text.split("<image>")
108
+ for i, t in enumerate(text):
109
+ if i < num_images:
110
+ text[i] = t + "<image>"
111
+ text = "".join(text)
112
+ logger.warning("Number of <image> tokens exceeds number of images. Automatically removing extra tokens at the end of the text.")
113
+ # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
114
+ texts = [text]
115
+ elif isinstance(text, list):
116
+ if not isinstance(text[0], str):
117
+ raise ValueError("Invalid input text. Each element of text must be a string.")
118
+ for i, t in enumerate(text):
119
+ num_image_tokens = t.count("<image>")
120
+ num_images = len(images[i])
121
+ if num_image_tokens < num_images:
122
+ # prepend empty image tokens to text
123
+ if "USER:" in t:
124
+ t = t.replace("USER:", "USER:" + "<image>" * (num_images - num_image_tokens), 1)
125
+ else:
126
+ t = "<image>" * (num_images - num_image_tokens) + t
127
+ # logger.warning("Image Tokens <image> are not provided in the text. Automatically prepending them before the text. This might cause model to behave unexpectedly.")
128
+ elif num_image_tokens > num_images:
129
+ t = t.split("<image>")
130
+ for j, s in enumerate(t):
131
+ if j < num_images:
132
+ t[j] = s + "<image>"
133
+ t = "".join(t)
134
+ logger.warning("Number of <image> tokens exceeds number of images. Automatically removing extra tokens at the end of the text.")
135
+ # raise ValueError("Invalid input text. Number of <image> tokens exceeds number of images.")
136
+ texts = text
137
+ else:
138
+ raise ValueError("Invalid input text. text must be a string or a list of strings.")
139
+ assert all([t.count("<image>") == len(images_per_text) for t, images_per_text in zip(texts, images)]), "Number of <image> tokens in text does not match number of images."
140
+ # add image denotation in text before each <image> as "(image {i}: <image>)"
141
+ for i, t in enumerate(texts):
142
+ for j in range(len(images[i])):
143
+ t = t.replace("<image>", f"(image {j+1}: <Image><IMAGE></Image>)", 1)
144
+ t = t.replace("<IMAGE>", "<image>")
145
+ texts[i] = t
146
+
147
+ # flatten images
148
+ images = [image for images_per_text in images for image in images_per_text]
149
+ else:
150
+ if isinstance(text, str):
151
+ texts = [text]
152
+ elif isinstance(text, list):
153
+ if not isinstance(text[0], str):
154
+ raise ValueError("Invalid input text. Each element of text must be a string.")
155
+ texts = text
156
+ else:
157
+ raise ValueError("Invalid input text. text must be a string or a list of strings.")
158
+
159
+ return texts, images
160
+
161
+ def __call__(
162
+ self,
163
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
164
+ images: ImageInput = None,
165
+ padding: Union[bool, str, PaddingStrategy] = False,
166
+ truncation: Union[bool, str, TruncationStrategy] = None,
167
+ max_length=None,
168
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
169
+ ) -> BatchFeature:
170
+ """
171
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
172
+ and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
173
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
174
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
175
+ of the above two methods for more information.
176
+
177
+ Args:
178
+ text (`str`, `List[str]`, `List[List[str]]`):
179
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
180
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
181
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
182
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
183
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
184
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
185
+ number of channels, H and W are image height and width.
186
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
187
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
188
+ index) among:
189
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
190
+ sequence if provided).
191
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
192
+ acceptable input length for the model if that argument is not provided.
193
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
194
+ lengths).
195
+ max_length (`int`, *optional*):
196
+ Maximum length of the returned list and optionally padding length (see above).
197
+ truncation (`bool`, *optional*):
198
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
199
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
200
+ If set, will return tensors of a particular framework. Acceptable values are:
201
+
202
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
203
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
204
+ - `'np'`: Return NumPy `np.ndarray` objects.
205
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
206
+
207
+ Returns:
208
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
209
+
210
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
211
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
212
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
213
+ `None`).
214
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
215
+ """
216
+ texts, images = self.preprocess_interleaved_images_and_text(text, images)
217
+ if images is not None:
218
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] # [batch_size, num_channels, height, width], e.g. [1, 3, 336, 336]
219
+ else:
220
+ pixel_values = None
221
+ text_inputs = self.tokenizer(
222
+ texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
223
+ )
224
+ # text_inputs:
225
+ # 1. input_ids: [batch_size, sequence_length], e.g. [1, 6]
226
+ # 2. attention_mask: [batch_size, sequence_length], e.g. [1, 6]
227
+
228
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
229
+
230
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
231
+ def batch_decode(self, *args, **kwargs):
232
+ """
233
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
234
+ refer to the docstring of this method for more information.
235
+ """
236
+ return self.tokenizer.batch_decode(*args, **kwargs)
237
+
238
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
239
+ def decode(self, *args, **kwargs):
240
+ """
241
+ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
242
+ the docstring of this method for more information.
243
+ """
244
+ return self.tokenizer.decode(*args, **kwargs)
245
+
246
+ @property
247
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
248
+ def model_input_names(self):
249
+ tokenizer_input_names = self.tokenizer.model_input_names
250
+ image_processor_input_names = self.image_processor.model_input_names
251
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
252
+
253
+ def _right_pad_inputs_with_attention_mask(self, model_inputs: List[Dict]):
254
+ results = {}
255
+ assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
256
+ for k in model_inputs[0].keys():
257
+ if model_inputs[0][k] is not None:
258
+ results[k] = torch.cat([inputs[k] for inputs in model_inputs], dim=0)
259
+ else:
260
+ results[k] = None
261
+ return results
262
+
models/mllava/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import torch
3
+ from .modeling_llava import LlavaForConditionalGeneration
4
+ from .processing_llava import MLlavaProcessor
5
+ from ..conversation import conv_mllava_v1_mmtag as default_conv
6
+ from typing import List, Tuple
7
+
8
+ def chat_mllava(
9
+ text:str,
10
+ images: List[PIL.Image.Image],
11
+ model:LlavaForConditionalGeneration,
12
+ processor:MLlavaProcessor,
13
+ max_input_length:int=None,
14
+ history:List[dict]=None,
15
+ stream:bool=False,
16
+ **kwargs) -> Tuple[str, List[dict]]:
17
+ """
18
+ Chat with the Mllava model
19
+ Args:
20
+ text: str, the text to be sent to the model, where <image> will be the placeholder for the image
21
+ images: List[PIL.Image.Image], the images to be sent to the model, or None
22
+ model: LlavaForConditionalGeneration, the model to be used
23
+ processor: MLlavaProcessor, the processor to be used
24
+ max_input_length: int, the maximum input length
25
+ history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None, the conversation will start from scratch
26
+ kwargs: dict, the generation kwargs
27
+ Returns:
28
+ Tuple[str, List[dict]], the generated text and the history of the conversation
29
+
30
+
31
+ """
32
+ conv = default_conv.copy()
33
+ conv.messages = []
34
+ if history is not None:
35
+ for message in history:
36
+ message["role"] = message["role"].upper()
37
+ assert message["role"] in conv.roles
38
+ conv.append_message(message["role"], message["text"])
39
+ else:
40
+ history = []
41
+ conv.append_message(conv.roles[0], text)
42
+ conv.append_message(conv.roles[1], "")
43
+
44
+ prompt = conv.get_prompt()
45
+
46
+ inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
47
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
48
+
49
+ if stream:
50
+ from transformers import TextIteratorStreamer
51
+ from threading import Thread
52
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
53
+ kwargs["streamer"] = streamer
54
+ inputs.update(kwargs)
55
+ thread = Thread(target=model.generate, kwargs=inputs)
56
+ thread.start()
57
+ history.append({"role": conv.roles[0], "text": text})
58
+ history.append({"role": conv.roles[1], "text": ""})
59
+ for _output in streamer:
60
+ history[-1]["text"] += _output
61
+ yield history[-1]["text"], history
62
+ else:
63
+ output_ids = model.generate(**inputs, **kwargs)
64
+ output_ids = output_ids[0]
65
+
66
+ # remove the input tokens
67
+ generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
68
+ generated_text = processor.decode(generated_ids, skip_special_tokens=True)
69
+
70
+ history.append({"role": conv.roles[0], "text": text})
71
+ history.append({"role": conv.roles[1], "text": generated_text})
72
+
73
+ return generated_text, history
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git