jadechoghari committed
Commit 6d56da9
1 parent: a5dd8d6

mega update

Files changed (1)
  1. inference.py (+236 −21)
inference.py CHANGED
@@ -3,13 +3,198 @@ from PIL import Image
 from conversation import conv_templates
 from builder import load_pretrained_model  # Assuming this is your custom model loader
 from functools import partial
+from typing import Optional, Callable
+import ast
+import math
 import numpy as np
+DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+VOCAB_IMAGE_W = 1000  # 224
+VOCAB_IMAGE_H = 1000  # 224
+IMAGE_TOKEN_INDEX = -200
+
 
 # define the task categories
 box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
 box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
 no_box_tasks = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']
 
+def get_bbox_coor(box, ratio_w, ratio_h):
+    return box[0] * ratio_w, box[1] * ratio_h, box[2] * ratio_w, box[3] * ratio_h
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    if '<image>' in prompt:
+        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+        input_ids = []
+        for i, chunk in enumerate(prompt_chunks):
+            input_ids.extend(chunk)
+            if i < len(prompt_chunks) - 1:
+                input_ids.append(image_token_index)
+    else:
+        input_ids = tokenizer(prompt).input_ids
+    # if return_tensors == 'pt':
+    #     input_ids = torch.tensor(input_ids).unsqueeze(0)
+
+    return input_ids
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+def select_best_resolution(original_size, possible_resolutions):
+    """
+    Selects the best resolution from a list of possible resolutions based on the original size.
+
+    Args:
+        original_size (tuple): The original size of the image in the format (width, height).
+        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+    Returns:
+        tuple: The best fit resolution in the format (width, height).
+    """
+    original_width, original_height = original_size
+    best_fit = None
+    max_effective_resolution = 0
+    min_wasted_resolution = float('inf')
+
+    for width, height in possible_resolutions:
+        scale = min(width / original_width, height / original_height)
+        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        wasted_resolution = (width * height) - effective_resolution
+
+        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+            max_effective_resolution = effective_resolution
+            min_wasted_resolution = wasted_resolution
+            best_fit = (width, height)
+
+    return best_fit
+
+def divide_to_patches(image, patch_size):
+    """
+    Divides an image into patches of a specified size.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        patch_size (int): The size of each patch.
+
+    Returns:
+        list: A list of PIL.Image.Image objects representing the patches.
+    """
+    patches = []
+    width, height = image.size
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            box = (j, i, j + patch_size, i + patch_size)
+            patch = image.crop(box)
+            patches.append(patch)
+
+    return patches
+
+def resize_and_pad_image(image, target_resolution, is_pad=False):
+    """
+    Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        target_resolution (tuple): The target resolution (width, height) of the image.
+
+    Returns:
+        PIL.Image.Image: The resized and padded image.
+    """
+    original_width, original_height = image.size
+    target_width, target_height = target_resolution
+
+    if is_pad:
+        scale_w = target_width / original_width
+        scale_h = target_height / original_height
+
+        if scale_w < scale_h:
+            new_width = target_width
+            new_height = min(math.ceil(original_height * scale_w), target_height)
+        else:
+            new_height = target_height
+            new_width = min(math.ceil(original_width * scale_h), target_width)
+
+        # Resize the image and paste it centered on a black canvas of the target size
+        resized_image = image.resize((new_width, new_height))
+
+        new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+        paste_x = (target_width - new_width) // 2
+        paste_y = (target_height - new_height) // 2
+        new_image.paste(resized_image, (paste_x, paste_y))
+    else:
+        new_image = image.resize((target_width, target_height))
+
+    return new_image
+
+def process_anyres_image(image, processor, grid_pinpoints, image_process_func: Optional[Callable] = None):
+    """
+    Process an image with variable resolutions.
+
+    Args:
+        image (PIL.Image.Image): The input image to be processed.
+        processor: The image processor object.
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+    Returns:
+        torch.Tensor: A tensor containing the processed image patches.
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+
+    # FIXME: not sure if do_pad or undo_pad may affect the referring side
+    image_padded = resize_and_pad_image(image, best_resolution, is_pad=False)
+
+    patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+    if image_process_func:
+        resized_image_h, resized_image_w = image_process_func.keywords['size']
+        image_original_resize = image.resize((resized_image_w, resized_image_h))
+        image_patches = [image_original_resize] + patches
+        image_patches = [image_process_func(image_patch)['pixel_values'][0]
+                         for image_patch in image_patches]
+    else:
+        image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+        image_patches = [image_original_resize] + patches
+        image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                         for image_patch in image_patches]
+
+    return torch.stack(image_patches, dim=0)
+
+
+def process_images(images, image_processor, model_cfg, image_process_func: Optional[Callable] = None):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    elif image_aspect_ratio == "anyres":
+        # image_processor(images, return_tensors='pt', do_resize=True, do_center_crop=False, size=[image_h, image_w])['pixel_values']
+        for image in images:
+            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints, image_process_func=image_process_func)
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
 # function to generate the mask
 def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
     """
@@ -36,32 +221,29 @@ def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
     if mask is not None:
         coor_mask = coor_mask * mask
 
-    # Convert to torch tensor and ensure it contains non-zero values
+    # convert to torch tensor and ensure it contains non-zero values
     coor_mask = torch.from_numpy(coor_mask)
     assert len(coor_mask.nonzero()) != 0, "Generated mask is empty :("
 
+
     return coor_mask
 
 
-def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_gemma", conv_mode="ferret_gemma_instruct"):
+def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_gemma", conv_mode="ferret_gemma_instruct", add_region_feature=False):
     img = Image.open(image_path).convert('RGB')
 
     # this loads the model, image processor and tokenizer
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
-
-    # define the image size (e.g., 224x224 or 336x336)
+    # define the image size required by CLIP
    image_size = {"height": 336, "width": 336}
 
-    # process the image
-    image_tensor = image_processor.preprocess(
-        img,
-        return_tensors='pt',
-        do_resize=True,
-        do_center_crop=False,
-        size=(image_size['height'], image_size['width'])
-    )['pixel_values'][0].unsqueeze(0)
+    if "<image>" in prompt:
+        prompt = prompt.split('\n')[1]
 
-    image_tensor = image_tensor.half().cuda()
+    if model.config.mm_use_im_start_end:
+        prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
+    else:
+        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
 
     # generate the prompt per template requirement
     conv = conv_templates[conv_mode].copy()
@@ -69,16 +251,45 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
     conv.append_message(conv.roles[1], None)
     prompt_input = conv.get_prompt()
 
-    # tokenize prompt
     input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
 
+    # raw_w, raw_h = img.size  # check whether this should instead use the raw image width and height
+    raw_w = image_size["width"]
+    raw_h = image_size["height"]
+    if model.config.image_aspect_ratio == "square_nocrop":
+        image_tensor = image_processor.preprocess(img, return_tensors='pt', do_resize=True,
+                                                  do_center_crop=False, size=[raw_h, raw_w])['pixel_values'][0]
+    elif model.config.image_aspect_ratio == "anyres":
+        image_process_func = partial(image_processor.preprocess, return_tensors='pt', do_resize=True, do_center_crop=False, size=[raw_h, raw_w])
+        image_tensor = process_images([img], image_processor, model.config, image_process_func=image_process_func)[0]
+    else:
+        image_tensor = process_images([img], image_processor, model.config)[0]
+
+    images = image_tensor.unsqueeze(0).to(torch.float16).cuda()
+
+
+
     # region mask logic (if region is provided)
     region_masks = None
-    if region is not None:
+    if add_region_feature and region is not None:
+        # box_in is true
         raw_w, raw_h = img.size
-        region_masks = generate_mask_for_feature(region, raw_w, raw_h).unsqueeze(0).cuda().half()
-        region_masks = [[region_masks]]  # Wrap the mask in lists as expected by the model
+        ratio_w = VOCAB_IMAGE_W * 1.0 / raw_w
+        ratio_h = VOCAB_IMAGE_H * 1.0 / raw_h
+        # preprocess the region
+        box_x1, box_y1, box_x2, box_y2 = region
+        box_x1_textvocab, box_y1_textvocab, box_x2_textvocab, box_y2_textvocab = get_bbox_coor(box=region, ratio_h=ratio_h, ratio_w=ratio_w)
+        region_coordinate_raw = [box_x1, box_y1, box_x2, box_y2]
 
+        region_masks = generate_mask_for_feature(region_coordinate_raw, raw_w, raw_h).unsqueeze(0).cuda().half()
+        region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
+        prompt_input = prompt_input.replace("<bbox_location0>", f"[{box_x1_textvocab}, {box_y1_textvocab}, {box_x2_textvocab}, {box_y2_textvocab}] {DEFAULT_REGION_FEA_TOKEN}")
+
+    # tokenize prompt
+    # input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
+
+
+
     # generate model output
     with torch.inference_mode():
         # Use region_masks in model's forward call
@@ -87,9 +298,10 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
             model.orig_forward,
             region_masks=region_masks
         )
+        # explicit add of the attention mask
        output_ids = model.generate(
            input_ids,
-            images=image_tensor,
+            images=images,
            max_new_tokens=1024,
            num_beams=1,
            region_masks=region_masks,  # pass the region mask to the model
@@ -102,16 +314,19 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
     return output_text.strip()
 
 # We also define a task-specific inference function
-def infer_ui_task(image_path, prompt, model_path, task, region=None):
+def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_feature=False):
+    # region = torch.tensor(region).cuda()
     """
     Handles task types: box_in_tasks, box_out_tasks, no_box_tasks.
     """
+    if region is not None:
+        add_region_feature = True
     if task in box_in_tasks and region is None:
         raise ValueError(f"Task {task} requires a bounding box region.")
 
     if task in box_in_tasks:
         print(f"Processing {task} with bounding box region.")
-        return infer_single_prompt(image_path, prompt, model_path, region)
+        return infer_single_prompt(image_path, prompt, model_path, region, add_region_feature=add_region_feature)
 
     elif task in box_out_tasks:
         print(f"Processing {task} without bounding box region.")
@@ -122,4 +337,4 @@ def infer_ui_task(image_path, prompt, model_path, task, region=None):
         return infer_single_prompt(image_path, prompt, model_path)
 
     else:
-        raise ValueError(f"Unknown task type: {task}")
+        raise ValueError(f"Unknown task type: {task}")
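
For reference, a minimal usage sketch of the updated entry point follows. It is not part of the commit: the file path, checkpoint path, prompt text, and box coordinates are made-up placeholders; only the infer_ui_task signature, the task names, and the <bbox_location0> slot that infer_single_prompt rewrites come from the code above.

# usage_sketch.py -- hypothetical example, not part of this commit
from inference import infer_ui_task

# Box-in task: `region` is a raw-pixel (x1, y1, x2, y2) box on the screenshot;
# infer_ui_task switches add_region_feature on automatically when a region is given.
caption = infer_ui_task(
    image_path="screenshot.png",                   # placeholder path
    prompt="Describe the widget at <bbox_location0>.",
    model_path="path/to/ferret_gemma_checkpoint",  # placeholder checkpoint
    task="widgetcaptions",
    region=(40, 120, 220, 180),
)
print(caption)

# No-box task: no region is needed, so only the full screenshot is passed.
summary = infer_ui_task(
    image_path="screenshot.png",
    prompt="Summarize the screen.",
    model_path="path/to/ferret_gemma_checkpoint",
    task="screen2words",
)
print(summary)

Note that the region is given in raw image pixels; inside infer_single_prompt it is rescaled to the 0-1000 text-vocabulary grid via get_bbox_coor (VOCAB_IMAGE_W/H) before being written into the <bbox_location0> placeholder alongside the region-feature token.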