Abhaykoul commited on
Commit
33d3d7d
·
verified ·
1 Parent(s): 14e1728

Update processing_llava.py

Browse files
Files changed (1) hide show
  1. processing_llava.py +97 -47
processing_llava.py CHANGED
@@ -13,12 +13,16 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """
16
- Processor class for Llava.
17
  """
18
 
19
 
 
20
  from typing import List, Optional, Union
21
 
 
 
 
22
  from transformers.feature_extraction_utils import BatchFeature
23
  from transformers.image_utils import ImageInput
24
  from transformers.tokenization_utils_base import (
@@ -28,52 +32,73 @@ from transformers.tokenization_utils_base import (
28
  TruncationStrategy,
29
  )
30
  from transformers.utils import TensorType
31
- import torch
32
- from open_clip.transform import PreprocessCfg, image_transform_v2
33
- from modeling_llava import LlavaForConditionalGeneration
34
- from PIL import Image
35
- import math
36
 
37
 
38
- class OpenCLIPImageProcessor:
39
- def __init__(self, config, crop_size=384, max_tokens=100):
40
- cfg = PreprocessCfg(**config)
41
- transform = image_transform_v2(cfg=cfg, is_train=False)
42
- self.transform = transform
43
- self.crop_size = crop_size
44
- self.max_tokens = max_tokens
45
 
46
- def __call__(self, image: Image.Image):
47
- output = self.transform_func(image)
48
- return {
49
- "pixel_values": output,
 
 
 
 
50
  }
 
 
 
 
 
 
 
51
 
52
- def transform_func(self, image: Image.Image):
 
 
 
 
53
  outputs = []
54
- outputs.append(self.transform(image))
 
 
55
  width, height = image.size
56
  crop_size = self.crop_size
57
- if width <= crop_size and height <= crop_size:
58
- outputs = torch.stack(outputs, dim=0)
59
- return outputs
 
 
 
 
 
 
60
  total_tokens = math.inf
61
- while total_tokens > self.max_tokens:
62
- total_tokens = math.floor(
63
- (2 * width - crop_size)
64
- / crop_size
65
- * (2 * height - crop_size)
66
- / crop_size
67
  )
68
- if total_tokens > self.max_tokens:
69
  crop_size += 10
70
- stride = crop_size // 2
71
- x_steps = int(round((2 * width - crop_size) / crop_size))
 
72
  if x_steps < 1:
73
  x_steps = 1
74
- y_steps = int(round((2 * height - crop_size) / crop_size))
75
  if y_steps < 1:
76
  y_steps = 1
 
 
 
 
77
  x_coords = []
78
  y_coords = []
79
  for i in range(x_steps):
@@ -85,6 +110,7 @@ class OpenCLIPImageProcessor:
85
  if y_coords[-1][1] != height:
86
  y_coords[-1][1] = height
87
  image_parts = []
 
88
  for i in range(len(x_coords)):
89
  for j in range(len(y_coords)):
90
  image_parts.append(
@@ -92,20 +118,38 @@ class OpenCLIPImageProcessor:
92
  (x_coords[i][0], y_coords[j][0], x_coords[i][1], y_coords[j][1])
93
  )
94
  )
 
 
 
 
 
 
 
 
95
  for image_part in image_parts:
96
- outputs.append(self.transform(image_part))
97
- outputs = torch.stack(outputs, dim=0)
98
- return outputs
 
 
 
99
 
100
- @property
101
- def model_input_names(self):
102
- return ["pixel_values"]
103
 
 
 
 
 
104
 
105
- class LlavaProcessor:
106
- def __init__(self, image_processor: OpenCLIPImageProcessor, tokenizer):
107
  self.image_processor = image_processor
108
  self.tokenizer = tokenizer
 
 
 
 
 
 
 
109
 
110
  def __call__(
111
  self,
@@ -113,20 +157,24 @@ class LlavaProcessor:
113
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
114
  ] = None,
115
  images: ImageInput = None,
116
- model: LlavaForConditionalGeneration = None,
 
 
117
  padding: Union[bool, str, PaddingStrategy] = False,
118
  truncation: Union[bool, str, TruncationStrategy] = None,
119
  max_length=None,
120
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
121
  ) -> BatchFeature:
122
  if images is not None:
123
- pixel_values = self.image_processor(images)[
124
- "pixel_values"
 
 
125
  ]
126
- pixel_values = pixel_values.to(model.device).to(model.dtype)
127
- image_outputs = model.vision_model(pixel_values)
 
128
  image_features = model.multi_modal_projector(image_outputs)
129
- image_features = image_features.unsqueeze(0)
130
  else:
131
  image_features = None
132
  text_inputs = self.tokenizer(
@@ -136,7 +184,8 @@ class LlavaProcessor:
136
  truncation=truncation,
137
  max_length=max_length,
138
  )
139
-
 
140
  return BatchFeature(data={**text_inputs, "image_features": image_features})
141
 
142
  def batch_decode(self, *args, **kwargs):
@@ -150,3 +199,4 @@ class LlavaProcessor:
150
  tokenizer_input_names = self.tokenizer.model_input_names
151
  image_processor_input_names = self.image_processor.model_input_names
152
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """
16
+ Processor class for HelpingAI-V.
17
  """
18
 
19
 
20
+ import math
21
  from typing import List, Optional, Union
22
 
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import ImageProcessingMixin, ProcessorMixin, SiglipImageProcessor, AutoTokenizer, AutoImageProcessor
26
  from transformers.feature_extraction_utils import BatchFeature
27
  from transformers.image_utils import ImageInput
28
  from transformers.tokenization_utils_base import (
 
32
  TruncationStrategy,
33
  )
34
  from transformers.utils import TensorType
 
 
 
 
 
35
 
36
 
37
+ class MultiCropImageProcessor(ImageProcessingMixin):
38
+ def __init__(self, model_name, max_crops=0, **kwargs):
39
+ self.processor = SiglipImageProcessor.from_pretrained(model_name)
40
+ self.crop_size = 384
41
+ self.max_crops = max_crops
42
+ self.stride_ratio = 2
 
43
 
44
+ def __call__(
45
+ self,
46
+ images: List[Image.Image],
47
+ max_crops: int = -1,
48
+ ):
49
+ res = {
50
+ "pixel_values": [],
51
+ "coords": [],
52
  }
53
+ if max_crops < 0:
54
+ max_crops = self.max_crops
55
+ for image in images:
56
+ outputs, output_coords = self.process_image(image, max_crops)
57
+ res["pixel_values"].append(outputs)
58
+ res["coords"].append(output_coords)
59
+ return res
60
 
61
+ def process_image(
62
+ self,
63
+ image: Image.Image,
64
+ max_crops: int
65
+ ):
66
  outputs = []
67
+ output_coords = []
68
+ outputs.append(self.processor(image, return_tensors="pt").pixel_values)
69
+ output_coords.append(torch.tensor([0.5, 0.5]))
70
  width, height = image.size
71
  crop_size = self.crop_size
72
+ stride = crop_size // self.stride_ratio
73
+ if (
74
+ max_crops == 0
75
+ or width <= (crop_size + stride)
76
+ and height <= (crop_size + stride)
77
+ ):
78
+ outputs = torch.cat(outputs, dim=0)
79
+ output_coords = torch.cat(output_coords, dim=0)
80
+ return outputs, output_coords
81
  total_tokens = math.inf
82
+ while total_tokens > max_crops:
83
+ total_tokens = (
84
+ math.floor((width - crop_size) / stride) + 1
85
+ ) * (
86
+ math.floor((height - crop_size) / stride) + 1
 
87
  )
88
+ if total_tokens > max_crops:
89
  crop_size += 10
90
+ stride = crop_size // self.stride_ratio
91
+ stride = crop_size // self.stride_ratio
92
+ x_steps = int(math.floor((width - crop_size) / stride) + 1)
93
  if x_steps < 1:
94
  x_steps = 1
95
+ y_steps = int(math.floor((height - crop_size) / stride) + 1)
96
  if y_steps < 1:
97
  y_steps = 1
98
+ if x_steps == 1 and y_steps == 1:
99
+ outputs = torch.cat(outputs, dim=0)
100
+ output_coords = torch.cat(output_coords, dim=0)
101
+ return outputs, output_coords
102
  x_coords = []
103
  y_coords = []
104
  for i in range(x_steps):
 
110
  if y_coords[-1][1] != height:
111
  y_coords[-1][1] = height
112
  image_parts = []
113
+ part_coords = []
114
  for i in range(len(x_coords)):
115
  for j in range(len(y_coords)):
116
  image_parts.append(
 
118
  (x_coords[i][0], y_coords[j][0], x_coords[i][1], y_coords[j][1])
119
  )
120
  )
121
+ part_coords.append(
122
+ torch.tensor(
123
+ [
124
+ (x_coords[i][0] + x_coords[i][1]) / 2 / width,
125
+ (y_coords[j][0] + y_coords[j][1]) / 2 / height,
126
+ ]
127
+ )
128
+ )
129
  for image_part in image_parts:
130
+ outputs.append(self.processor(image_part, return_tensors="pt").pixel_values)
131
+ for part_coord in part_coords:
132
+ output_coords.append(part_coord)
133
+ outputs = torch.cat(outputs, dim=0)
134
+ output_coords = torch.stack(output_coords, dim=0)
135
+ return outputs, output_coords
136
 
 
 
 
137
 
138
+ class LlavaProcessor(ProcessorMixin):
139
+ attributes = ["image_processor", "tokenizer"]
140
+ image_processor_class = MultiCropImageProcessor
141
+ tokenizer_class = "SiglipTokenizer"
142
 
143
+ def __init__(self, image_processor: MultiCropImageProcessor, tokenizer):
 
144
  self.image_processor = image_processor
145
  self.tokenizer = tokenizer
146
+ self.search_model = None
147
+
148
+ @classmethod
149
+ def from_pretrained(cls, path, trust_remote_code=True, **kwargs):
150
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code)
151
+ image_processor = MultiCropImageProcessor(path, trust_remote_code=trust_remote_code)
152
+ return LlavaProcessor(image_processor, tokenizer)
153
 
154
  def __call__(
155
  self,
 
157
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
158
  ] = None,
159
  images: ImageInput = None,
160
+ model = None,
161
+ max_crops: int = 0,
162
+ num_tokens = None,
163
  padding: Union[bool, str, PaddingStrategy] = False,
164
  truncation: Union[bool, str, TruncationStrategy] = None,
165
  max_length=None,
166
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
167
  ) -> BatchFeature:
168
  if images is not None:
169
+ processor_outputs = self.image_processor(images, max_crops)
170
+ pixel_values = processor_outputs["pixel_values"]
171
+ pixel_values = [
172
+ value.to(model.device).to(model.dtype) for value in pixel_values
173
  ]
174
+ coords = processor_outputs["coords"]
175
+ coords = [value.to(model.device).to(model.dtype) for value in coords]
176
+ image_outputs = model.vision_model(pixel_values, coords, num_tokens)
177
  image_features = model.multi_modal_projector(image_outputs)
 
178
  else:
179
  image_features = None
180
  text_inputs = self.tokenizer(
 
184
  truncation=truncation,
185
  max_length=max_length,
186
  )
187
+ text_inputs['input_ids'] = text_inputs['input_ids'].to(model.device)
188
+ text_inputs['attention_mask'] = text_inputs['attention_mask'].to(model.device)
189
  return BatchFeature(data={**text_inputs, "image_features": image_features})
190
 
191
  def batch_decode(self, *args, **kwargs):
 
199
  tokenizer_input_names = self.tokenizer.model_input_names
200
  image_processor_input_names = self.image_processor.model_input_names
201
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
202
+