izhx committed
Commit 22270be
1 Parent(s): 748be87

Create gme_inference.py

Files changed (1)
  1. gme_inference.py +331 -0
gme_inference.py ADDED
from __future__ import annotations

import logging
import math
import os
from typing import Dict, List, Optional

import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import AutoModelForVision2Seq, AutoProcessor


class GmeQwen2VL:
    """Wrapper around Qwen2-VL that produces L2-normalized text, image, and fused text+image embeddings."""

    def __init__(
        self,
        model_name: str = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
        model_path: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        min_image_tokens=256,
        max_image_tokens=1280,
        max_length=1800,
        **kwargs,
    ) -> None:
        model_name = model_path or model_name
        self.base = AutoModelForVision2Seq.from_pretrained(
            model_name, torch_dtype=torch.float16, **kwargs
        )
        self.base.eval()
        self.normalize = True
        self.device = device
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.max_length = max_length
        self.processor = AutoProcessor.from_pretrained(
            model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
        )
        self.processor.tokenizer.padding_side = 'right'
        self.default_instruction = 'You are a helpful assistant.'
        self.sep = ' '

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        # pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        # video_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.base.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.base.visual.get_dtype())
                image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
                image_mask = input_ids == self.base.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            # if pixel_values_videos is not None:
            #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
            #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
            #     video_mask = input_ids == self.base.config.video_token_id
            #     inputs_embeds[video_mask] = video_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.base.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        # Last-token pooling: with left padding every sequence ends at the final
        # position; with right padding, index each sequence's last real token.
        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[torch.arange(
                batch_size, device=outputs.last_hidden_state.device
            ), sequence_lengths]
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
        self.base.to(self.device)
        # Inputs must be batched
        input_texts, input_images = list(), list()
        for t, i in zip(texts, images):
            if not is_query or instruction is None:
                instruction = self.default_instruction
            input_str = ''
            if i is None:
                input_images = None  # All examples in the same batch are consistent
            else:
                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
                i = fetch_image(i)
                input_images.append(i)
            if t is not None:
                input_str += t
            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
            input_texts.append(msg)

        inputs = self.processor(
            text=input_texts,
            images=input_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
        with torch.no_grad():
            embeddings = self.forward(**inputs)
        return embeddings

    def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)

    def encode_queries(self, queries: List[str], **kwargs):
        embeddings = self.encode(queries, **kwargs)
        return embeddings

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        if type(corpus) is dict:
            sentences = [
                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        else:
            sentences = [
                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus
            ]
        embeddings = self.encode(sentences, is_query=False, **kwargs)
        return embeddings

    def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
        return self.get_fused_embeddings(images=images, **kwargs)

    def get_text_embeddings(self, texts: list[str], **kwargs):
        return self.get_fused_embeddings(texts=texts, **kwargs)

    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
        if isinstance(images, DataLoader):
            image_loader = images
            batch_size = image_loader.batch_size
            image_loader.dataset.transform = None
        else:
            batch_size = kwargs.pop('batch_size', 32)
            if images is None:
                image_loader = None
            else:
                image_loader = DataLoader(
                    images,
                    batch_size=batch_size,
                    shuffle=False,
                    collate_fn=custom_collate_fn,
                    num_workers=min(math.floor(os.cpu_count() / 2), 8),
                )

        if texts is None:
            assert image_loader is not None
            n_batch = len(image_loader)
        else:
            n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
            image_loader = image_loader or [None] * n_batch

        all_embeddings = list()
        none_batch = [None] * batch_size
        show_progress_bar = kwargs.pop('show_progress_bar', True)
        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
        for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
            text_batch = none_batch if texts is None else texts[n: n+batch_size]
            img_batch = none_batch if img_batch is None else img_batch
            embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
            pbar.update(1)
            all_embeddings.append(embeddings.cpu())
        pbar.close()
        all_embeddings = torch.cat(all_embeddings, dim=0)
        return all_embeddings


def custom_collate_fn(batch):
    # Identity collate: keep the PIL images as a plain Python list instead of
    # letting the default collate try to stack them into tensors.
    return batch


### Copied from qwen_vl_utils.vision_process.py
import base64
from io import BytesIO
import requests

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
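

# For example, with factor=28: round_by_factor(100, 28) == 112,
# ceil_by_factor(100, 28) == 112, and floor_by_factor(100, 28) == 84.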


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(
            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
        )
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar
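# Example with the defaults: smart_resize(1080, 1920) snaps each side to a
# multiple of 28 and returns (1092, 1932); the pixel count stays within
# [MIN_PIXELS, MAX_PIXELS], so neither rescaling branch is triggered.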


def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    ## resize
    # if "resized_height" in ele and "resized_width" in ele:
    #     resized_height, resized_width = smart_resize(
    #         ele["resized_height"],
    #         ele["resized_width"],
    #         factor=size_factor,
    #     )
    # else:
    width, height = image.size
    # min_pixels = ele.get("min_pixels", MIN_PIXELS)
    # max_pixels = ele.get("max_pixels", MAX_PIXELS)
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    image = image.resize((resized_width, resized_height))

    return image
###


if __name__ == '__main__':
    texts = [
        "What kind of car is this?",
        "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
    ]
    images = [
        # 'https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg',
        '/nas-alinlp/linzhang.zx/gme_space/assets/Tesla_Cybertruck_damaged_window.jpg',
        # 'https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
        '/nas-alinlp/linzhang.zx/gme_space/assets/2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
    ]

    gme = GmeQwen2VL("/nas-alinlp/linzhang.zx/gme_space/gme-Qwen2-VL-2B-instruct")

    # Single-modal embedding
    e_text = gme.get_text_embeddings(texts=texts)
    e_image = gme.get_image_embeddings(images=images)
    print((e_text * e_image).sum(-1))
    ## tensor([0.2281, 0.6001], dtype=torch.float16)

    # How to set embedding instruction
    e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
    # If is_query=False, we always use the default instruction.
    e_corpus = gme.get_image_embeddings(images=images, is_query=False)
    print((e_query * e_corpus).sum(-1))
    ## tensor([0.2433, 0.7051], dtype=torch.float16)

    # Fused-modal embedding
    e_fused = gme.get_fused_embeddings(texts=texts, images=images)
    print((e_fused[0] * e_fused[1]).sum())
    ## tensor(0.6108, dtype=torch.float16)
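
    # A minimal retrieval sketch over the embeddings computed above: they are
    # L2-normalized, so the dot product is cosine similarity and a single matmul
    # yields the full text-query-to-image score matrix.
    scores = e_query @ e_corpus.T  # shape: (num_texts, num_images)
    print(scores.argmax(dim=-1))  # index of the best-matching image per query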