aria-dev commited on
Commit
9710af0
·
1 Parent(s): 0531a03

support spliting image

Browse files
Files changed (2) hide show
  1. processing_aria.py +25 -10
  2. vision_processor.py +119 -6
processing_aria.py CHANGED
@@ -18,6 +18,7 @@
18
  # under the License.
19
 
20
  import inspect
 
21
  from typing import List, Optional, Union
22
 
23
  from transformers import AutoTokenizer, BatchFeature
@@ -61,7 +62,7 @@ class AriaProcessor(ProcessorMixin):
61
  super().__init__(chat_template=chat_template)
62
 
63
  if image_processor is None:
64
- self.image_processor = AriaVisionProcessor(image_max_size=patch_size)
65
  else:
66
  self.image_processor = image_processor
67
 
@@ -87,6 +88,7 @@ class AriaProcessor(ProcessorMixin):
87
  truncation: Union[bool, str, TruncationStrategy] = None,
88
  max_length: Optional[int] = None,
89
  max_image_size: Optional[int] = 980,
 
90
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
91
  ) -> BatchFeature:
92
  """
@@ -114,6 +116,8 @@ class AriaProcessor(ProcessorMixin):
114
  Maximum length of the returned list and optionally padding length (see above).
115
  max_image_size (`int`, *optional*):
116
  Maximum size of the image to be processed.
 
 
117
  truncation (`bool`, *optional*):
118
  Activates truncation to cut input sequences longer than `max_length` to `max_length`.
119
  return_tensors (`str` or [`~utils.TensorType`], *optional*):
@@ -134,24 +138,35 @@ class AriaProcessor(ProcessorMixin):
134
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
135
  - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
136
  """
 
 
 
 
 
 
 
137
  if images is not None:
138
  image_inputs = self.image_processor(
139
  images,
140
  return_tensors=return_tensors,
141
  max_image_size=max_image_size,
 
142
  )
 
 
 
 
 
 
 
 
 
 
 
 
143
  else:
144
  image_inputs = {}
145
 
146
- if isinstance(text, str):
147
- text = [text]
148
- elif not isinstance(text, list) and not isinstance(text[0], str):
149
- raise ValueError(
150
- "Invalid input text. Please provide a string, or a list of strings"
151
- )
152
-
153
- prompt_strings = text
154
-
155
  text_inputs = self.tokenizer(
156
  prompt_strings,
157
  return_tensors=return_tensors,
 
18
  # under the License.
19
 
20
  import inspect
21
+ import re
22
  from typing import List, Optional, Union
23
 
24
  from transformers import AutoTokenizer, BatchFeature
 
62
  super().__init__(chat_template=chat_template)
63
 
64
  if image_processor is None:
65
+ self.image_processor = AriaVisionProcessor(max_image_size=patch_size)
66
  else:
67
  self.image_processor = image_processor
68
 
 
88
  truncation: Union[bool, str, TruncationStrategy] = None,
89
  max_length: Optional[int] = None,
90
  max_image_size: Optional[int] = 980,
91
+ split_image: Optional[bool] = False,
92
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
93
  ) -> BatchFeature:
94
  """
 
116
  Maximum length of the returned list and optionally padding length (see above).
117
  max_image_size (`int`, *optional*):
118
  Maximum size of the image to be processed.
119
+ split_image (`bool`, *optional*):
120
+ Whether to split the image into patches before processing.
121
  truncation (`bool`, *optional*):
122
  Activates truncation to cut input sequences longer than `max_length` to `max_length`.
123
  return_tensors (`str` or [`~utils.TensorType`], *optional*):
 
138
  - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
139
  - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
140
  """
141
+ if isinstance(text, str):
142
+ text = [text]
143
+ elif not isinstance(text, list) and not isinstance(text[0], str):
144
+ raise ValueError(
145
+ "Invalid input text. Please provide a string, or a list of strings"
146
+ )
147
+
148
  if images is not None:
149
  image_inputs = self.image_processor(
150
  images,
151
  return_tensors=return_tensors,
152
  max_image_size=max_image_size,
153
+ split_image=split_image,
154
  )
155
+ # expand the image_token according to the num_crops of image
156
+ prompt_strings = []
157
+ crop_iter = iter(image_inputs.pop("num_crops"))
158
+ for prompt in text:
159
+ prompt_strings.append(
160
+ re.sub(
161
+ re.escape(self.image_token),
162
+ lambda _: next(crop_iter) * self.image_token,
163
+ prompt,
164
+ )
165
+ )
166
+
167
  else:
168
  image_inputs = {}
169
 
 
 
 
 
 
 
 
 
 
170
  text_inputs = self.tokenizer(
171
  prompt_strings,
172
  return_tensors=return_tensors,
vision_processor.py CHANGED
@@ -19,12 +19,93 @@
19
 
20
  from typing import List, Optional, Union
21
 
 
22
  import torch
23
  from PIL import Image, ImageOps
24
  from torchvision import transforms
25
  from transformers import BaseImageProcessor, BatchFeature, TensorType
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def keep_ratio_resize_and_pixel_mask(
29
  img: Image.Image, max_size, min_size=336, padding_value=0
30
  ):
@@ -127,6 +208,17 @@ class AriaVisionProcessor(BaseImageProcessor):
127
  max_image_size: Optional[int] = 980,
128
  min_image_size: Optional[int] = 336,
129
  return_tensors: Optional[Union[str, TensorType]] = "pt",
 
 
 
 
 
 
 
 
 
 
 
130
  ):
131
  """
132
  Process a list of images.
@@ -135,6 +227,8 @@ class AriaVisionProcessor(BaseImageProcessor):
135
  images (list): List of PIL.Image objects.
136
  max_image_size (int, optional): Override the default max image size. Defaults to None.
137
  return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt".
 
 
138
  Returns:
139
  BatchFeature: A BatchFeature object containing:
140
  - 'pixel_values': Tensor of processed image pixel values.
@@ -142,6 +236,7 @@ class AriaVisionProcessor(BaseImageProcessor):
142
  - True (1) values indicate pixels that belong to the original resized image.
143
  - False (0) values indicate pixels that are part of the padding.
144
  The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
 
145
  """
146
  max_size = self.max_image_size if max_image_size is None else max_image_size
147
  min_size = self.min_image_size if min_image_size is None else min_image_size
@@ -154,19 +249,24 @@ class AriaVisionProcessor(BaseImageProcessor):
154
 
155
  pixel_values = []
156
  pixel_masks = []
 
157
 
158
  for image in images:
159
- img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
160
- image, max_size, min_size
161
- )
162
- img_padded = self.transform(img_padded)
163
- pixel_values.append(img_padded)
164
- pixel_masks.append(pixel_mask)
 
 
 
165
 
166
  return BatchFeature(
167
  data={
168
  "pixel_values": torch.stack(pixel_values),
169
  "pixel_mask": torch.stack(pixel_masks),
 
170
  },
171
  tensor_type=return_tensors,
172
  )
@@ -177,10 +277,23 @@ class AriaVisionProcessor(BaseImageProcessor):
177
  max_image_size=None,
178
  min_image_size=None,
179
  return_tensors: Optional[Union[str, TensorType]] = None,
 
 
 
 
 
 
 
 
 
 
 
180
  ):
181
  return self.__call__(
182
  images,
183
  max_image_size=max_image_size,
184
  min_image_size=min_image_size,
185
  return_tensors=return_tensors,
 
 
186
  )
 
19
 
20
  from typing import List, Optional, Union
21
 
22
+ import numpy as np
23
  import torch
24
  from PIL import Image, ImageOps
25
  from torchvision import transforms
26
  from transformers import BaseImageProcessor, BatchFeature, TensorType
27
 
28
 
29
+ def _select_best_resolution(
30
+ img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int
31
+ ):
32
+ """
33
+ Selects the best resolution from a list of possible resolutions based on the original size.
34
+
35
+ Args:
36
+ img_width: the original widths of images.
37
+ img_height: the original heights of images.
38
+ target_ratios (2d numpy array): dimension size (M,2)
39
+ patch_size (int): image patch size
40
+
41
+ Returns:
42
+ tuple: The best fit resolution in the format (width, height).
43
+ """
44
+
45
+ aspect_ratio = img_width / img_height
46
+ best_ratio_diff = float("inf")
47
+ best_ratio_w, best_ratio_h = 1, 1
48
+ area = np.int32(img_height) * np.int32(img_height)
49
+ for ratio in target_ratios:
50
+ target_aspect_ratio = ratio[0] / ratio[1]
51
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
52
+ if ratio_diff < best_ratio_diff:
53
+ best_ratio_diff = ratio_diff
54
+ best_ratio_w, best_ratio_h = ratio[0], ratio[1]
55
+ elif (
56
+ ratio_diff == best_ratio_diff
57
+ and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1]
58
+ ):
59
+ best_ratio_w, best_ratio_h = ratio[0], ratio[1]
60
+
61
+ return best_ratio_w, best_ratio_h
62
+
63
+
64
+ def _split_image(
65
+ image: Image.Image,
66
+ split_image: bool,
67
+ split_ratio: List[List[int]],
68
+ patch_size: int,
69
+ ) -> List[Image.Image]:
70
+ """
71
+ Split image into multiple patches
72
+
73
+ Args:
74
+ image (PIL.Image): Input image.
75
+ split_image (bool): Whether to split the image into patches.
76
+ split_ratio (2d numpy array): dimension size (M,2)
77
+ patch_size (int): image patch size
78
+
79
+ Returns:
80
+ List[PIL.Image]: List of splitted images.
81
+ """
82
+ if split_image:
83
+ ratio_width, ratio_height = _select_best_resolution(
84
+ image.width, image.height, split_ratio, patch_size
85
+ )
86
+ resize_width = patch_size * ratio_width
87
+ resize_height = patch_size * ratio_height
88
+ blocks = ratio_width * ratio_height
89
+ resized_img = image.resize((resize_width, resize_height))
90
+ processed_images = []
91
+ for i in range(blocks):
92
+ box = (
93
+ (i % (resize_width // patch_size)) * patch_size,
94
+ (i // (resize_width // patch_size)) * patch_size,
95
+ ((i % (resize_width // patch_size)) + 1) * patch_size,
96
+ ((i // (resize_width // patch_size)) + 1) * patch_size,
97
+ )
98
+ # split the image
99
+ split_img = resized_img.crop(box)
100
+ processed_images.append(split_img)
101
+ assert len(processed_images) == blocks
102
+ if len(processed_images) != 1:
103
+ processed_images.insert(0, image)
104
+ return processed_images
105
+ else:
106
+ return [image]
107
+
108
+
109
  def keep_ratio_resize_and_pixel_mask(
110
  img: Image.Image, max_size, min_size=336, padding_value=0
111
  ):
 
208
  max_image_size: Optional[int] = 980,
209
  min_image_size: Optional[int] = 336,
210
  return_tensors: Optional[Union[str, TensorType]] = "pt",
211
+ split_image: Optional[bool] = False,
212
+ split_ratio: Optional[List[List[int]]] = [
213
+ [1, 1],
214
+ [1, 2],
215
+ [1, 3],
216
+ [1, 4],
217
+ [2, 2],
218
+ [2, 1],
219
+ [3, 1],
220
+ [4, 1],
221
+ ],
222
  ):
223
  """
224
  Process a list of images.
 
227
  images (list): List of PIL.Image objects.
228
  max_image_size (int, optional): Override the default max image size. Defaults to None.
229
  return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt".
230
+ split_image (bool, optional): Whether to split the image. Defaults to False.
231
+ split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios.
232
  Returns:
233
  BatchFeature: A BatchFeature object containing:
234
  - 'pixel_values': Tensor of processed image pixel values.
 
236
  - True (1) values indicate pixels that belong to the original resized image.
237
  - False (0) values indicate pixels that are part of the padding.
238
  The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
239
+ - 'num_crops': Tensor of the number of crops for each image.
240
  """
241
  max_size = self.max_image_size if max_image_size is None else max_image_size
242
  min_size = self.min_image_size if min_image_size is None else min_image_size
 
249
 
250
  pixel_values = []
251
  pixel_masks = []
252
+ num_crops = []
253
 
254
  for image in images:
255
+ crop_images = _split_image(image, split_image, split_ratio, max_size)
256
+ num_crops.append(torch.tensor(len(crop_images)))
257
+ for crop_image in crop_images:
258
+ img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
259
+ crop_image, max_size, min_size
260
+ )
261
+ img_padded = self.transform(img_padded)
262
+ pixel_values.append(img_padded)
263
+ pixel_masks.append(pixel_mask)
264
 
265
  return BatchFeature(
266
  data={
267
  "pixel_values": torch.stack(pixel_values),
268
  "pixel_mask": torch.stack(pixel_masks),
269
+ "num_crops": torch.stack(num_crops),
270
  },
271
  tensor_type=return_tensors,
272
  )
 
277
  max_image_size=None,
278
  min_image_size=None,
279
  return_tensors: Optional[Union[str, TensorType]] = None,
280
+ split_image: Optional[bool] = False,
281
+ split_ratio: Optional[List[List[int]]] = [
282
+ [1, 1],
283
+ [1, 2],
284
+ [1, 3],
285
+ [1, 4],
286
+ [2, 2],
287
+ [2, 1],
288
+ [3, 1],
289
+ [4, 1],
290
+ ],
291
  ):
292
  return self.__call__(
293
  images,
294
  max_image_size=max_image_size,
295
  min_image_size=min_image_size,
296
  return_tensors=return_tensors,
297
+ split_image=split_image,
298
+ split_ratio=split_ratio,
299
  )