Mageia commited on
Commit
71819c7
·
unverified ·
1 Parent(s): e5bb317

fix: format .py

Browse files
config.json CHANGED
@@ -35,4 +35,4 @@
35
  "use_im_start_end": true,
36
  "use_sliding_window": false,
37
  "vocab_size": 151860
38
- }
 
35
  "use_im_start_end": true,
36
  "use_sliding_window": false,
37
  "vocab_size": 151860
38
+ }
generation_config.json CHANGED
@@ -3,4 +3,4 @@
3
  "eos_token_id": 151643,
4
  "max_new_tokens": 2048,
5
  "transformers_version": "4.37.2"
6
- }
 
3
  "eos_token_id": 151643,
4
  "max_new_tokens": 2048,
5
  "transformers_version": "4.37.2"
6
+ }
got_vision_b.py CHANGED
@@ -1,10 +1,9 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from typing import Optional, Tuple, Type
4
  from functools import partial
5
- import torch.nn as nn
6
- from typing import Type
7
 
 
 
 
8
 
9
 
10
  class MLPBlock(nn.Module):
@@ -23,7 +22,6 @@ class MLPBlock(nn.Module):
23
  return self.lin2(self.act(self.lin1(x)))
24
 
25
 
26
-
27
  class LayerNorm2d(nn.Module):
28
  def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
29
  super().__init__()
@@ -39,7 +37,6 @@ class LayerNorm2d(nn.Module):
39
  return x
40
 
41
 
42
-
43
  class ImageEncoderViT(nn.Module):
44
  def __init__(
45
  self,
@@ -91,9 +88,7 @@ class ImageEncoderViT(nn.Module):
91
  self.pos_embed: Optional[nn.Parameter] = None
92
  if use_abs_pos:
93
  # Initialize absolute positional embedding with pretrain image size.
94
- self.pos_embed = nn.Parameter(
95
- torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
96
- )
97
 
98
  self.blocks = nn.ModuleList()
99
  for i in range(depth):
@@ -129,7 +124,6 @@ class ImageEncoderViT(nn.Module):
129
  LayerNorm2d(out_chans),
130
  )
131
 
132
-
133
  self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
134
  self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
135
 
@@ -145,7 +139,6 @@ class ImageEncoderViT(nn.Module):
145
  x = self.net_2(x)
146
  x = self.net_3(x)
147
 
148
-
149
  return x
150
 
151
 
@@ -247,9 +240,7 @@ class Attention(nn.Module):
247
 
248
  self.use_rel_pos = use_rel_pos
249
  if self.use_rel_pos:
250
- assert (
251
- input_size is not None
252
- ), "Input size must be provided if using relative positional encoding."
253
  # initialize relative positional embeddings
254
  self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
255
  self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
@@ -297,9 +288,7 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
297
  return windows, (Hp, Wp)
298
 
299
 
300
- def window_unpartition(
301
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
302
- ) -> torch.Tensor:
303
  """
304
  Window unpartition into original sequences and removing padding.
305
  Args:
@@ -385,9 +374,7 @@ def add_decomposed_rel_pos(
385
  rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
386
  rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
387
 
388
- attn = (
389
- attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
390
- ).view(B, q_h * q_w, k_h * k_w)
391
 
392
  return attn
393
 
@@ -415,9 +402,7 @@ class PatchEmbed(nn.Module):
415
  """
416
  super().__init__()
417
 
418
- self.proj = nn.Conv2d(
419
- in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
420
- )
421
 
422
  def forward(self, x: torch.Tensor) -> torch.Tensor:
423
  x = self.proj(x)
@@ -426,7 +411,6 @@ class PatchEmbed(nn.Module):
426
  return x
427
 
428
 
429
-
430
  def build_GOT_vit_b(checkpoint=None):
431
  return _build_GOT_vision(
432
  encoder_embed_dim=768,
@@ -448,21 +432,19 @@ def _build_GOT_vision(
448
  image_size = 1024
449
  vit_patch_size = 16
450
  image_embedding_size = image_size // vit_patch_size
451
- image_encoder=ImageEncoderViT(
452
- depth=encoder_depth,
453
- embed_dim=encoder_embed_dim,
454
- img_size=image_size,
455
- mlp_ratio=4,
456
- norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
457
- num_heads=encoder_num_heads,
458
- patch_size=vit_patch_size,
459
- qkv_bias=True,
460
- use_rel_pos=True,
461
- global_attn_indexes=encoder_global_attn_indexes,
462
- window_size=14,
463
- out_chans=prompt_embed_dim,
464
- )
465
-
466
 
467
  return image_encoder
468
-
 
 
 
 
1
  from functools import partial
2
+ from typing import Optional, Tuple, Type
 
3
 
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
 
8
 
9
  class MLPBlock(nn.Module):
 
22
  return self.lin2(self.act(self.lin1(x)))
23
 
24
 
 
25
  class LayerNorm2d(nn.Module):
26
  def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
27
  super().__init__()
 
37
  return x
38
 
39
 
 
40
  class ImageEncoderViT(nn.Module):
41
  def __init__(
42
  self,
 
88
  self.pos_embed: Optional[nn.Parameter] = None
89
  if use_abs_pos:
90
  # Initialize absolute positional embedding with pretrain image size.
91
+ self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
 
 
92
 
93
  self.blocks = nn.ModuleList()
94
  for i in range(depth):
 
124
  LayerNorm2d(out_chans),
125
  )
126
 
 
127
  self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
128
  self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
129
 
 
139
  x = self.net_2(x)
140
  x = self.net_3(x)
141
 
 
142
  return x
143
 
144
 
 
240
 
241
  self.use_rel_pos = use_rel_pos
242
  if self.use_rel_pos:
243
+ assert input_size is not None, "Input size must be provided if using relative positional encoding."
 
 
244
  # initialize relative positional embeddings
245
  self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
246
  self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
288
  return windows, (Hp, Wp)
289
 
290
 
291
+ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> torch.Tensor:
 
 
292
  """
293
  Window unpartition into original sequences and removing padding.
294
  Args:
 
374
  rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
375
  rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
376
 
377
+ attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
 
 
378
 
379
  return attn
380
 
 
402
  """
403
  super().__init__()
404
 
405
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
 
406
 
407
  def forward(self, x: torch.Tensor) -> torch.Tensor:
408
  x = self.proj(x)
 
411
  return x
412
 
413
 
 
414
  def build_GOT_vit_b(checkpoint=None):
415
  return _build_GOT_vision(
416
  encoder_embed_dim=768,
 
432
  image_size = 1024
433
  vit_patch_size = 16
434
  image_embedding_size = image_size // vit_patch_size
435
+ image_encoder = ImageEncoderViT(
436
+ depth=encoder_depth,
437
+ embed_dim=encoder_embed_dim,
438
+ img_size=image_size,
439
+ mlp_ratio=4,
440
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
441
+ num_heads=encoder_num_heads,
442
+ patch_size=vit_patch_size,
443
+ qkv_bias=True,
444
+ use_rel_pos=True,
445
+ global_attn_indexes=encoder_global_attn_indexes,
446
+ window_size=14,
447
+ out_chans=prompt_embed_dim,
448
+ )
 
449
 
450
  return image_encoder
 
modeling_GOT.py CHANGED
@@ -1,27 +1,32 @@
1
- from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
2
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
3
  from typing import List, Optional, Tuple, Union
4
- from transformers.cache_utils import Cache
5
  import requests
6
- from PIL import Image
7
- from io import BytesIO
8
  import torch
9
  import torch.nn as nn
 
10
  from torch.nn import CrossEntropyLoss
11
- from .got_vision_b import build_GOT_vit_b
12
  from torchvision import transforms
13
  from torchvision.transforms.functional import InterpolationMode
14
- import dataclasses
 
 
 
 
 
15
  ###
16
 
17
  DEFAULT_IMAGE_TOKEN = "<image>"
18
- DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
- DEFAULT_IM_START_TOKEN = '<img>'
20
- DEFAULT_IM_END_TOKEN = '</img>'
 
21
 
22
- from enum import auto, Enum
23
  class SeparatorStyle(Enum):
24
  """Different separator style."""
 
25
  SINGLE = auto()
26
  TWO = auto()
27
  MPT = auto()
@@ -30,6 +35,7 @@ class SeparatorStyle(Enum):
30
  @dataclasses.dataclass
31
  class Conversation:
32
  """A class that keeps all conversation history."""
 
33
  system: str
34
  roles: List[str]
35
  messages: List[List[str]]
@@ -43,7 +49,7 @@ class Conversation:
43
 
44
  def get_prompt(self):
45
  if self.sep_style == SeparatorStyle.SINGLE:
46
- ret = self.system + self.sep + '\n'
47
  for role, message in self.messages:
48
  if message:
49
  if type(message) is tuple:
@@ -65,9 +71,9 @@ class Conversation:
65
  return ret
66
  if self.sep_style == SeparatorStyle.MPT:
67
  if self.system:
68
- ret = self.system + self.sep
69
  else:
70
- ret = ''
71
  for role, message in self.messages:
72
  if message:
73
  if type(message) is tuple:
@@ -79,7 +85,6 @@ class Conversation:
79
  else:
80
  raise ValueError(f"Invalid style: {self.sep_style}")
81
 
82
-
83
  def append_message(self, role, message):
84
  self.messages.append([role, message])
85
 
@@ -91,8 +96,8 @@ class Conversation:
91
  offset=self.offset,
92
  sep_style=self.sep_style,
93
  sep=self.sep,
94
- sep2=self.sep2)
95
-
96
 
97
 
98
  class KeywordsStoppingCriteria(StoppingCriteria):
@@ -111,12 +116,12 @@ class KeywordsStoppingCriteria(StoppingCriteria):
111
  for keyword_id in self.keyword_ids:
112
  if output_ids[0, -1] == keyword_id:
113
  return True
114
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
115
  for keyword in self.keywords:
116
  if keyword in outputs:
117
  return True
118
  return False
119
-
120
 
121
  class GOTImageEvalProcessor:
122
  def __init__(self, image_size=384, mean=None, std=None):
@@ -129,18 +134,16 @@ class GOTImageEvalProcessor:
129
 
130
  self.transform = transforms.Compose(
131
  [
132
- transforms.Resize(
133
- (image_size, image_size), interpolation=InterpolationMode.BICUBIC
134
- ),
135
  transforms.ToTensor(),
136
  self.normalize,
137
  ]
138
  )
 
139
  def __call__(self, item):
140
  return self.transform(item)
141
 
142
 
143
-
144
  class GOTConfig(Qwen2Config):
145
  model_type = "GOT"
146
 
@@ -153,28 +156,24 @@ class GOTQwenModel(Qwen2Model):
153
 
154
  self.vision_tower_high = build_GOT_vit_b()
155
 
156
- self.mm_projector_vary = nn.Linear(1024, 1024)
157
-
158
 
159
  def initialize_vision_modules(
160
- self,
161
  vision_tower,
162
  pretrained_stage1_model=None,
163
  freeze_vision_tower=False,
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
167
- device="cuda"
168
  ):
169
-
170
-
171
  image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
-
173
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
 
175
  self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
176
 
177
-
178
  image_token_len = 256
179
 
180
  self.config.vision_tower = vision_tower
@@ -184,13 +183,12 @@ class GOTQwenModel(Qwen2Model):
184
 
185
  self.config.vision_select_layer = vision_select_layer
186
  self.config.freeze_vision_tower = freeze_vision_tower
187
-
188
  return dict(
189
  image_processor_high=image_processor_high,
190
  image_token_len=image_token_len,
191
  )
192
-
193
-
194
  def forward(
195
  self,
196
  input_ids: torch.LongTensor = None,
@@ -204,19 +202,16 @@ class GOTQwenModel(Qwen2Model):
204
  images: Optional[torch.FloatTensor] = None,
205
  return_dict: Optional[bool] = None,
206
  ) -> Union[Tuple, BaseModelOutputWithPast]:
207
-
208
  # HACK: replace back original embeddings for LLaVA pretraining
209
- orig_embeds_params = getattr(self, 'orig_embeds_params', None)
210
  if orig_embeds_params is not None:
211
  with torch.no_grad():
212
- self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
213
 
214
  if inputs_embeds is None:
215
  inputs_embeds = self.embed_tokens(input_ids)
216
 
217
-
218
- vision_tower_high = getattr(self, 'vision_tower_high', None)
219
-
220
 
221
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
222
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
@@ -232,15 +227,15 @@ class GOTQwenModel(Qwen2Model):
232
  im_start_token = 151857
233
 
234
  im_end_token = 151858
235
-
236
  image_features = []
237
-
238
  for image in images:
239
  P, C, H, W = image.shape
240
  if P == 1:
241
  with torch.set_grad_enabled(False):
242
  cnn_feature = vision_tower_high(image)
243
- cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
244
  image_feature = self.mm_projector_vary(cnn_feature)
245
  image_features.append(image_feature)
246
 
@@ -249,7 +244,7 @@ class GOTQwenModel(Qwen2Model):
249
  image_patches_features = []
250
  for image_patch in image_patches:
251
  image_p = torch.stack([image_patch])
252
-
253
  with torch.set_grad_enabled(False):
254
  cnn_feature_p = vision_tower_high(image_p)
255
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -258,21 +253,20 @@ class GOTQwenModel(Qwen2Model):
258
  image_feature = torch.cat(image_patches_features, dim=1)
259
  image_features.append(image_feature)
260
 
261
-
262
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
263
  dummy_image_features = dummy_image_features_2
264
  use_im_start_end = True
265
  new_input_embeds = []
266
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
267
  if (cur_input_ids == im_patch_token).sum() == 0:
268
- cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
269
  new_input_embeds.append(cur_input_embeds)
270
  continue
271
 
272
  if use_im_start_end:
273
  if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
274
  raise ValueError("The number of image start tokens and image end tokens should be the same.")
275
-
276
  image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
277
  for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
278
  per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
@@ -280,17 +274,16 @@ class GOTQwenModel(Qwen2Model):
280
 
281
  if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
282
  raise ValueError("The image end token should follow the image start token.")
283
-
284
  cur_input_embeds = torch.cat(
285
  (
286
- cur_input_embeds[:image_start_token_pos+1],
287
- per_cur_image_features,
288
- cur_input_embeds[image_start_token_pos + num_patches + 1:]
289
- ),
290
- dim=0
291
  )
292
 
293
-
294
  new_input_embeds.append(cur_input_embeds)
295
  else:
296
  raise NotImplementedError
@@ -298,14 +291,18 @@ class GOTQwenModel(Qwen2Model):
298
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
299
 
300
  return super(GOTQwenModel, self).forward(
301
- input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
302
- inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
303
- output_attentions=output_attentions, output_hidden_states=output_hidden_states,
304
- return_dict=return_dict
 
 
 
 
 
305
  )
306
 
307
 
308
-
309
  class GOTQwenForCausalLM(Qwen2ForCausalLM):
310
  config_class = GOTConfig
311
  # supports_gradient_checkpointing = True
@@ -336,15 +333,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
336
  output_hidden_states: Optional[bool] = None,
337
  images: Optional[torch.FloatTensor] = None,
338
  return_dict: Optional[bool] = None,
339
-
340
  ) -> Union[Tuple, CausalLMOutputWithPast]:
341
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342
- output_hidden_states = (
343
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344
- )
345
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
 
347
- outputs = self.model(
348
  input_ids=input_ids,
349
  past_key_values=past_key_values,
350
  attention_mask=attention_mask,
@@ -354,8 +348,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
354
  output_attentions=output_attentions,
355
  output_hidden_states=output_hidden_states,
356
  images=images,
357
- return_dict=return_dict
358
-
359
  )
360
 
361
  hidden_states = outputs[0]
@@ -389,10 +382,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
389
  attentions=outputs.attentions,
390
  )
391
 
392
-
393
- def prepare_inputs_for_generation(
394
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
395
- ):
396
  # Omit tokens covered by past_key_values
397
  if past_key_values is not None:
398
  if isinstance(past_key_values, Cache):
@@ -416,11 +406,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
416
  # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
417
 
418
  # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
419
- if (
420
- max_cache_length is not None
421
- and attention_mask is not None
422
- and cache_length + input_ids.shape[1] > max_cache_length
423
- ):
424
  attention_mask = attention_mask[:, -max_cache_length:]
425
 
426
  position_ids = kwargs.get("position_ids", None)
@@ -448,16 +434,9 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
448
  )
449
  return model_inputs
450
 
451
- def initialize_vision_tokenizer(
452
- self,
453
- tokenizer,
454
- freeze_lm_model=False,
455
- pretrained_stage1_model=None,
456
- device="cuda"
457
- ):
458
  config = self.get_model().config
459
 
460
-
461
  self.resize_token_embeddings(len(tokenizer))
462
 
463
  config.im_patch_token = 151859
@@ -469,11 +448,11 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
469
  config.im_start_token, config.im_end_token = 151857, 151858
470
 
471
  def load_image(self, image_file):
472
- if image_file.startswith('http') or image_file.startswith('https'):
473
  response = requests.get(image_file)
474
- image = Image.open(BytesIO(response.content)).convert('RGB')
475
  else:
476
- image = Image.open(image_file).convert('RGB')
477
  return image
478
 
479
  def disable_torch_init(self):
@@ -481,15 +460,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
481
  Disable the redundant torch default initialization to accelerate model creation.
482
  """
483
  import torch
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
488
-
 
 
 
 
 
 
 
 
 
 
 
489
  self.disable_torch_init()
490
 
491
-
492
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
493
 
494
  use_im_start_end = True
495
 
@@ -501,38 +491,37 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
501
  image = self.load_image(image_file)
502
 
503
  w, h = image.size
504
-
505
- if ocr_type == 'format':
506
- qs = 'OCR with format: '
507
  else:
508
- qs = 'OCR: '
509
 
510
  if ocr_box:
511
  bbox = eval(ocr_box)
512
  if len(bbox) == 2:
513
- bbox[0] = int(bbox[0]/w*1000)
514
- bbox[1] = int(bbox[1]/h*1000)
515
  if len(bbox) == 4:
516
- bbox[0] = int(bbox[0]/w*1000)
517
- bbox[1] = int(bbox[1]/h*1000)
518
- bbox[2] = int(bbox[2]/w*1000)
519
- bbox[3] = int(bbox[3]/h*1000)
520
- if ocr_type == 'format':
521
- qs = str(bbox) + ' ' + 'OCR with format: '
522
  else:
523
- qs = str(bbox) + ' ' + 'OCR: '
524
 
525
  if ocr_color:
526
- if ocr_type == 'format':
527
- qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
528
  else:
529
- qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
530
 
531
  if use_im_start_end:
532
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
533
  else:
534
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
535
-
536
 
537
  conv_mpt = Conversation(
538
  system="""<|im_start|>system
@@ -571,109 +560,113 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
571
  input_ids,
572
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
573
  do_sample=False,
574
- num_beams = 1,
575
- no_repeat_ngram_size = 20,
576
  streamer=streamer,
577
  max_new_tokens=4096,
578
- stopping_criteria=[stopping_criteria]
579
- )
580
  else:
581
  with torch.autocast("cuda", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
585
  do_sample=False,
586
- num_beams = 1,
587
- no_repeat_ngram_size = 20,
588
  # streamer=streamer,
589
  max_new_tokens=4096,
590
- stopping_criteria=[stopping_criteria]
591
- )
592
-
593
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
594
-
595
  if outputs.endswith(stop_str):
596
- outputs = outputs[:-len(stop_str)]
597
  outputs = outputs.strip()
598
  response_str = outputs
599
 
600
  if render:
601
- print('==============rendering===============')
602
- from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
603
 
604
- if '**kern' in outputs:
605
  import verovio
 
606
  tk = verovio.toolkit()
607
  tk.loadData(outputs)
608
- tk.setOptions({"pageWidth": 2100, "footer": 'none',
609
- 'barLineWidth': 0.5, 'beamMaxSlope': 15,
610
- 'staffLineWidth': 0.2, 'spacingStaff': 6})
611
  tk.getPageCount()
612
  svg = tk.renderToSVG()
613
- svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
614
 
615
  svg_to_html(svg, save_render_file)
616
 
617
- if ocr_type == 'format' and '**kern' not in outputs:
618
-
619
-
620
- if '\\begin{tikzpicture}' not in outputs:
621
  html_path_2 = save_render_file
622
- right_num = outputs.count('\\right')
623
- left_num = outputs.count('\left')
624
 
625
  if right_num != left_num:
626
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
627
-
 
 
 
 
 
 
 
 
 
 
628
 
629
- outputs = outputs.replace('"', '``').replace('$', '')
630
 
631
- outputs_list = outputs.split('\n')
632
- gt= ''
633
  for out in outputs_list:
634
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
635
-
636
- gt = gt[:-2]
637
 
 
638
 
639
  lines = content_mmd_to_html
640
  lines = lines.split("const text =")
641
- new_web = lines[0] + 'const text =' + gt + lines[1]
642
 
643
  else:
644
  html_path_2 = save_render_file
645
  outputs = outputs.translate(translation_table)
646
- outputs_list = outputs.split('\n')
647
- gt= ''
648
  for out in outputs_list:
649
  if out:
650
- if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
651
- while out[-1] == ' ':
652
  out = out[:-1]
653
  if out is None:
654
  break
655
-
656
  if out:
657
- if out[-1] != ';':
658
- gt += out[:-1] + ';\n'
659
  else:
660
- gt += out + '\n'
661
  else:
662
- gt += out + '\n'
663
-
664
 
665
  lines = tik_html
666
  lines = lines.split("const text =")
667
  new_web = lines[0] + gt + lines[1]
668
 
669
- with open(html_path_2, 'w') as web_f_new:
670
  web_f_new.write(new_web)
671
  return response_str
672
 
673
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
674
-
675
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
676
- best_ratio_diff = float('inf')
677
  best_ratio = (1, 1)
678
  area = width * height
679
  for ratio in target_ratios:
@@ -687,20 +680,19 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
687
  best_ratio = ratio
688
  # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
689
  return best_ratio
690
-
691
  orig_width, orig_height = image.size
692
  aspect_ratio = orig_width / orig_height
693
 
694
  # calculate the existing image aspect ratio
695
  target_ratios = set(
696
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
697
- i * j <= max_num and i * j >= min_num)
698
  # print(target_ratios)
699
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
700
 
701
  # find the closest aspect ratio to the target
702
- target_aspect_ratio = find_closest_aspect_ratio(
703
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
704
 
705
  # print(target_aspect_ratio)
706
  # calculate the target width and height
@@ -716,7 +708,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
716
  (i % (target_width // image_size)) * image_size,
717
  (i // (target_width // image_size)) * image_size,
718
  ((i % (target_width // image_size)) + 1) * image_size,
719
- ((i // (target_width // image_size)) + 1) * image_size
720
  )
721
  # split the image
722
  split_img = resized_img.crop(box)
@@ -727,18 +719,15 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
727
  processed_images.append(thumbnail_img)
728
  return processed_images
729
 
730
-
731
- def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
732
  # Model
733
  self.disable_torch_init()
734
- multi_page=False
735
-
736
 
737
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
738
 
739
  use_im_start_end = True
740
 
741
-
742
  image_token_len = 256
743
 
744
  image_list = []
@@ -747,7 +736,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
747
  # multi_page = True
748
 
749
  if multi_page:
750
- qs = 'OCR with format across multi pages: '
751
  # only for png files
752
  # import glob
753
  # from natsort import natsorted
@@ -763,10 +752,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
763
  # print("len ll: ", ll)
764
 
765
  else:
766
- if ocr_type == 'format':
767
- qs = 'OCR with format upon the patch reference: '
768
  else:
769
- qs = 'OCR upon the patch reference: '
770
  if gradio_input:
771
  img = image_file.copy()
772
  else:
@@ -778,17 +767,14 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
778
  image_tensor_1 = image_processor_high(image)
779
  image_list.append(image_tensor_1)
780
 
781
-
782
  image_list = torch.stack(image_list)
783
 
784
- print('====new images batch size======: \n',image_list.shape)
785
-
786
 
787
  if use_im_start_end:
788
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
789
  else:
790
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
791
-
792
 
793
  conv_mpt = Conversation(
794
  system="""<|im_start|>system
@@ -825,57 +811,68 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
825
  input_ids,
826
  images=[image_list.half().cuda()],
827
  do_sample=False,
828
- num_beams = 1,
829
  # no_repeat_ngram_size = 20,
830
  streamer=streamer,
831
  max_new_tokens=4096,
832
- stopping_criteria=[stopping_criteria]
833
- )
834
  else:
835
  with torch.autocast("cuda", dtype=torch.bfloat16):
836
  output_ids = self.generate(
837
  input_ids,
838
  images=[image_list.half().cuda()],
839
  do_sample=False,
840
- num_beams = 1,
841
  # no_repeat_ngram_size = 20,
842
  # streamer=streamer,
843
  max_new_tokens=4096,
844
- stopping_criteria=[stopping_criteria]
845
- )
 
 
846
 
847
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
848
-
849
  if outputs.endswith(stop_str):
850
- outputs = outputs[:-len(stop_str)]
851
- outputs = outputs.strip()
852
  response_str = outputs
853
 
854
  if render:
855
- print('==============rendering===============')
856
  from .render_tools import content_mmd_to_html
 
857
  html_path_2 = save_render_file
858
- right_num = outputs.count('\\right')
859
- left_num = outputs.count('\left')
860
 
861
  if right_num != left_num:
862
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
863
-
864
-
865
- outputs = outputs.replace('"', '``').replace('$', '')
866
-
867
- outputs_list = outputs.split('\n')
868
- gt= ''
 
 
 
 
 
 
 
 
 
 
869
  for out in outputs_list:
870
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
871
-
872
  gt = gt[:-2]
873
 
874
  lines = content_mmd_to_html
875
  lines = lines.split("const text =")
876
- new_web = lines[0] + 'const text =' + gt + lines[1]
877
-
878
- with open(html_path_2, 'w') as web_f_new:
879
  web_f_new.write(new_web)
880
 
881
- return response_str
 
1
+ import dataclasses
2
+ from enum import Enum, auto
3
+ from io import BytesIO
4
  from typing import List, Optional, Tuple, Union
5
+
6
  import requests
 
 
7
  import torch
8
  import torch.nn as nn
9
+ from PIL import Image
10
  from torch.nn import CrossEntropyLoss
 
11
  from torchvision import transforms
12
  from torchvision.transforms.functional import InterpolationMode
13
+ from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model, StoppingCriteria, TextStreamer
14
+ from transformers.cache_utils import Cache
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .got_vision_b import build_GOT_vit_b
18
+
19
  ###
20
 
21
  DEFAULT_IMAGE_TOKEN = "<image>"
22
+ DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
23
+ DEFAULT_IM_START_TOKEN = "<img>"
24
+ DEFAULT_IM_END_TOKEN = "</img>"
25
+
26
 
 
27
  class SeparatorStyle(Enum):
28
  """Different separator style."""
29
+
30
  SINGLE = auto()
31
  TWO = auto()
32
  MPT = auto()
 
35
  @dataclasses.dataclass
36
  class Conversation:
37
  """A class that keeps all conversation history."""
38
+
39
  system: str
40
  roles: List[str]
41
  messages: List[List[str]]
 
49
 
50
  def get_prompt(self):
51
  if self.sep_style == SeparatorStyle.SINGLE:
52
+ ret = self.system + self.sep + "\n"
53
  for role, message in self.messages:
54
  if message:
55
  if type(message) is tuple:
 
71
  return ret
72
  if self.sep_style == SeparatorStyle.MPT:
73
  if self.system:
74
+ ret = self.system + self.sep
75
  else:
76
+ ret = ""
77
  for role, message in self.messages:
78
  if message:
79
  if type(message) is tuple:
 
85
  else:
86
  raise ValueError(f"Invalid style: {self.sep_style}")
87
 
 
88
  def append_message(self, role, message):
89
  self.messages.append([role, message])
90
 
 
96
  offset=self.offset,
97
  sep_style=self.sep_style,
98
  sep=self.sep,
99
+ sep2=self.sep2,
100
+ )
101
 
102
 
103
  class KeywordsStoppingCriteria(StoppingCriteria):
 
116
  for keyword_id in self.keyword_ids:
117
  if output_ids[0, -1] == keyword_id:
118
  return True
119
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len :], skip_special_tokens=True)[0]
120
  for keyword in self.keywords:
121
  if keyword in outputs:
122
  return True
123
  return False
124
+
125
 
126
  class GOTImageEvalProcessor:
127
  def __init__(self, image_size=384, mean=None, std=None):
 
134
 
135
  self.transform = transforms.Compose(
136
  [
137
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
 
 
138
  transforms.ToTensor(),
139
  self.normalize,
140
  ]
141
  )
142
+
143
  def __call__(self, item):
144
  return self.transform(item)
145
 
146
 
 
147
  class GOTConfig(Qwen2Config):
148
  model_type = "GOT"
149
 
 
156
 
157
  self.vision_tower_high = build_GOT_vit_b()
158
 
159
+ self.mm_projector_vary = nn.Linear(1024, 1024)
 
160
 
161
  def initialize_vision_modules(
162
+ self,
163
  vision_tower,
164
  pretrained_stage1_model=None,
165
  freeze_vision_tower=False,
166
  use_im_start_end=False,
167
  vision_select_layer=-1,
168
  dtype=torch.float16,
169
+ device="cuda",
170
  ):
 
 
171
  image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
+
173
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
 
175
  self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
176
 
 
177
  image_token_len = 256
178
 
179
  self.config.vision_tower = vision_tower
 
183
 
184
  self.config.vision_select_layer = vision_select_layer
185
  self.config.freeze_vision_tower = freeze_vision_tower
186
+
187
  return dict(
188
  image_processor_high=image_processor_high,
189
  image_token_len=image_token_len,
190
  )
191
+
 
192
  def forward(
193
  self,
194
  input_ids: torch.LongTensor = None,
 
202
  images: Optional[torch.FloatTensor] = None,
203
  return_dict: Optional[bool] = None,
204
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
205
  # HACK: replace back original embeddings for LLaVA pretraining
206
+ orig_embeds_params = getattr(self, "orig_embeds_params", None)
207
  if orig_embeds_params is not None:
208
  with torch.no_grad():
209
+ self.get_input_embeddings().weight[: -self.num_new_tokens] = orig_embeds_params[: -self.num_new_tokens].data
210
 
211
  if inputs_embeds is None:
212
  inputs_embeds = self.embed_tokens(input_ids)
213
 
214
+ vision_tower_high = getattr(self, "vision_tower_high", None)
 
 
215
 
216
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
217
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
 
227
  im_start_token = 151857
228
 
229
  im_end_token = 151858
230
+
231
  image_features = []
232
+
233
  for image in images:
234
  P, C, H, W = image.shape
235
  if P == 1:
236
  with torch.set_grad_enabled(False):
237
  cnn_feature = vision_tower_high(image)
238
+ cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
239
  image_feature = self.mm_projector_vary(cnn_feature)
240
  image_features.append(image_feature)
241
 
 
244
  image_patches_features = []
245
  for image_patch in image_patches:
246
  image_p = torch.stack([image_patch])
247
+
248
  with torch.set_grad_enabled(False):
249
  cnn_feature_p = vision_tower_high(image_p)
250
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
 
253
  image_feature = torch.cat(image_patches_features, dim=1)
254
  image_features.append(image_feature)
255
 
 
256
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
257
  dummy_image_features = dummy_image_features_2
258
  use_im_start_end = True
259
  new_input_embeds = []
260
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
261
  if (cur_input_ids == im_patch_token).sum() == 0:
262
+ cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum()
263
  new_input_embeds.append(cur_input_embeds)
264
  continue
265
 
266
  if use_im_start_end:
267
  if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
268
  raise ValueError("The number of image start tokens and image end tokens should be the same.")
269
+
270
  image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
271
  for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
272
  per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
 
274
 
275
  if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
276
  raise ValueError("The image end token should follow the image start token.")
277
+
278
  cur_input_embeds = torch.cat(
279
  (
280
+ cur_input_embeds[: image_start_token_pos + 1],
281
+ per_cur_image_features,
282
+ cur_input_embeds[image_start_token_pos + num_patches + 1 :],
283
+ ),
284
+ dim=0,
285
  )
286
 
 
287
  new_input_embeds.append(cur_input_embeds)
288
  else:
289
  raise NotImplementedError
 
291
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
292
 
293
  return super(GOTQwenModel, self).forward(
294
+ input_ids=None,
295
+ attention_mask=attention_mask,
296
+ past_key_values=past_key_values,
297
+ inputs_embeds=inputs_embeds,
298
+ use_cache=use_cache,
299
+ position_ids=position_ids,
300
+ output_attentions=output_attentions,
301
+ output_hidden_states=output_hidden_states,
302
+ return_dict=return_dict,
303
  )
304
 
305
 
 
306
  class GOTQwenForCausalLM(Qwen2ForCausalLM):
307
  config_class = GOTConfig
308
  # supports_gradient_checkpointing = True
 
333
  output_hidden_states: Optional[bool] = None,
334
  images: Optional[torch.FloatTensor] = None,
335
  return_dict: Optional[bool] = None,
 
336
  ) -> Union[Tuple, CausalLMOutputWithPast]:
337
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
338
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
339
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
340
 
341
+ outputs = self.model(
342
  input_ids=input_ids,
343
  past_key_values=past_key_values,
344
  attention_mask=attention_mask,
 
348
  output_attentions=output_attentions,
349
  output_hidden_states=output_hidden_states,
350
  images=images,
351
+ return_dict=return_dict,
 
352
  )
353
 
354
  hidden_states = outputs[0]
 
382
  attentions=outputs.attentions,
383
  )
384
 
385
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
 
 
 
386
  # Omit tokens covered by past_key_values
387
  if past_key_values is not None:
388
  if isinstance(past_key_values, Cache):
 
406
  # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
407
 
408
  # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
409
+ if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
 
 
 
 
410
  attention_mask = attention_mask[:, -max_cache_length:]
411
 
412
  position_ids = kwargs.get("position_ids", None)
 
434
  )
435
  return model_inputs
436
 
437
+ def initialize_vision_tokenizer(self, tokenizer, freeze_lm_model=False, pretrained_stage1_model=None, device="cuda"):
 
 
 
 
 
 
438
  config = self.get_model().config
439
 
 
440
  self.resize_token_embeddings(len(tokenizer))
441
 
442
  config.im_patch_token = 151859
 
448
  config.im_start_token, config.im_end_token = 151857, 151858
449
 
450
  def load_image(self, image_file):
451
+ if image_file.startswith("http") or image_file.startswith("https"):
452
  response = requests.get(image_file)
453
+ image = Image.open(BytesIO(response.content)).convert("RGB")
454
  else:
455
+ image = Image.open(image_file).convert("RGB")
456
  return image
457
 
458
  def disable_torch_init(self):
 
460
  Disable the redundant torch default initialization to accelerate model creation.
461
  """
462
  import torch
463
+
464
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
465
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
466
 
467
+ def chat(
468
+ self,
469
+ tokenizer,
470
+ image_file,
471
+ ocr_type,
472
+ ocr_box="",
473
+ ocr_color="",
474
+ render=False,
475
+ save_render_file=None,
476
+ print_prompt=False,
477
+ gradio_input=False,
478
+ stream_flag=False,
479
+ ):
480
  self.disable_torch_init()
481
 
482
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
483
 
484
  use_im_start_end = True
485
 
 
491
  image = self.load_image(image_file)
492
 
493
  w, h = image.size
494
+
495
+ if ocr_type == "format":
496
+ qs = "OCR with format: "
497
  else:
498
+ qs = "OCR: "
499
 
500
  if ocr_box:
501
  bbox = eval(ocr_box)
502
  if len(bbox) == 2:
503
+ bbox[0] = int(bbox[0] / w * 1000)
504
+ bbox[1] = int(bbox[1] / h * 1000)
505
  if len(bbox) == 4:
506
+ bbox[0] = int(bbox[0] / w * 1000)
507
+ bbox[1] = int(bbox[1] / h * 1000)
508
+ bbox[2] = int(bbox[2] / w * 1000)
509
+ bbox[3] = int(bbox[3] / h * 1000)
510
+ if ocr_type == "format":
511
+ qs = str(bbox) + " " + "OCR with format: "
512
  else:
513
+ qs = str(bbox) + " " + "OCR: "
514
 
515
  if ocr_color:
516
+ if ocr_type == "format":
517
+ qs = "[" + ocr_color + "]" + " " + "OCR with format: "
518
  else:
519
+ qs = "[" + ocr_color + "]" + " " + "OCR: "
520
 
521
  if use_im_start_end:
522
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + "\n" + qs
523
  else:
524
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
 
525
 
526
  conv_mpt = Conversation(
527
  system="""<|im_start|>system
 
560
  input_ids,
561
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
562
  do_sample=False,
563
+ num_beams=1,
564
+ no_repeat_ngram_size=20,
565
  streamer=streamer,
566
  max_new_tokens=4096,
567
+ stopping_criteria=[stopping_criteria],
568
+ )
569
  else:
570
  with torch.autocast("cuda", dtype=torch.bfloat16):
571
  output_ids = self.generate(
572
  input_ids,
573
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
574
  do_sample=False,
575
+ num_beams=1,
576
+ no_repeat_ngram_size=20,
577
  # streamer=streamer,
578
  max_new_tokens=4096,
579
+ stopping_criteria=[stopping_criteria],
580
+ )
581
+
582
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
583
+
584
  if outputs.endswith(stop_str):
585
+ outputs = outputs[: -len(stop_str)]
586
  outputs = outputs.strip()
587
  response_str = outputs
588
 
589
  if render:
590
+ print("==============rendering===============")
591
+ from .render_tools import content_mmd_to_html, svg_to_html, tik_html, translation_table
592
 
593
+ if "**kern" in outputs:
594
  import verovio
595
+
596
  tk = verovio.toolkit()
597
  tk.loadData(outputs)
598
+ tk.setOptions({"pageWidth": 2100, "footer": "none", "barLineWidth": 0.5, "beamMaxSlope": 15, "staffLineWidth": 0.2, "spacingStaff": 6})
 
 
599
  tk.getPageCount()
600
  svg = tk.renderToSVG()
601
+ svg = svg.replace('overflow="inherit"', 'overflow="visible"')
602
 
603
  svg_to_html(svg, save_render_file)
604
 
605
+ if ocr_type == "format" and "**kern" not in outputs:
606
+ if "\\begin{tikzpicture}" not in outputs:
 
 
607
  html_path_2 = save_render_file
608
+ right_num = outputs.count("\\right")
609
+ left_num = outputs.count("\left")
610
 
611
  if right_num != left_num:
612
+ outputs = (
613
+ outputs.replace("\left(", "(")
614
+ .replace("\\right)", ")")
615
+ .replace("\left[", "[")
616
+ .replace("\\right]", "]")
617
+ .replace("\left{", "{")
618
+ .replace("\\right}", "}")
619
+ .replace("\left|", "|")
620
+ .replace("\\right|", "|")
621
+ .replace("\left.", ".")
622
+ .replace("\\right.", ".")
623
+ )
624
 
625
+ outputs = outputs.replace('"', "``").replace("$", "")
626
 
627
+ outputs_list = outputs.split("\n")
628
+ gt = ""
629
  for out in outputs_list:
630
+ gt += '"' + out.replace("\\", "\\\\") + r"\n" + '"' + "+" + "\n"
 
 
631
 
632
+ gt = gt[:-2]
633
 
634
  lines = content_mmd_to_html
635
  lines = lines.split("const text =")
636
+ new_web = lines[0] + "const text =" + gt + lines[1]
637
 
638
  else:
639
  html_path_2 = save_render_file
640
  outputs = outputs.translate(translation_table)
641
+ outputs_list = outputs.split("\n")
642
+ gt = ""
643
  for out in outputs_list:
644
  if out:
645
+ if "\\begin{tikzpicture}" not in out and "\\end{tikzpicture}" not in out:
646
+ while out[-1] == " ":
647
  out = out[:-1]
648
  if out is None:
649
  break
650
+
651
  if out:
652
+ if out[-1] != ";":
653
+ gt += out[:-1] + ";\n"
654
  else:
655
+ gt += out + "\n"
656
  else:
657
+ gt += out + "\n"
 
658
 
659
  lines = tik_html
660
  lines = lines.split("const text =")
661
  new_web = lines[0] + gt + lines[1]
662
 
663
+ with open(html_path_2, "w") as web_f_new:
664
  web_f_new.write(new_web)
665
  return response_str
666
 
667
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
 
668
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
669
+ best_ratio_diff = float("inf")
670
  best_ratio = (1, 1)
671
  area = width * height
672
  for ratio in target_ratios:
 
680
  best_ratio = ratio
681
  # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
682
  return best_ratio
683
+
684
  orig_width, orig_height = image.size
685
  aspect_ratio = orig_width / orig_height
686
 
687
  # calculate the existing image aspect ratio
688
  target_ratios = set(
689
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num
690
+ )
691
  # print(target_ratios)
692
  target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
693
 
694
  # find the closest aspect ratio to the target
695
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
 
696
 
697
  # print(target_aspect_ratio)
698
  # calculate the target width and height
 
708
  (i % (target_width // image_size)) * image_size,
709
  (i // (target_width // image_size)) * image_size,
710
  ((i % (target_width // image_size)) + 1) * image_size,
711
+ ((i // (target_width // image_size)) + 1) * image_size,
712
  )
713
  # split the image
714
  split_img = resized_img.crop(box)
 
719
  processed_images.append(thumbnail_img)
720
  return processed_images
721
 
722
+ def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
 
723
  # Model
724
  self.disable_torch_init()
725
+ multi_page = False
 
726
 
727
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
728
 
729
  use_im_start_end = True
730
 
 
731
  image_token_len = 256
732
 
733
  image_list = []
 
736
  # multi_page = True
737
 
738
  if multi_page:
739
+ qs = "OCR with format across multi pages: "
740
  # only for png files
741
  # import glob
742
  # from natsort import natsorted
 
752
  # print("len ll: ", ll)
753
 
754
  else:
755
+ if ocr_type == "format":
756
+ qs = "OCR with format upon the patch reference: "
757
  else:
758
+ qs = "OCR upon the patch reference: "
759
  if gradio_input:
760
  img = image_file.copy()
761
  else:
 
767
  image_tensor_1 = image_processor_high(image)
768
  image_list.append(image_tensor_1)
769
 
 
770
  image_list = torch.stack(image_list)
771
 
772
+ print("====new images batch size======: \n", image_list.shape)
 
773
 
774
  if use_im_start_end:
775
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len * ll + DEFAULT_IM_END_TOKEN + "\n" + qs
776
  else:
777
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
 
778
 
779
  conv_mpt = Conversation(
780
  system="""<|im_start|>system
 
811
  input_ids,
812
  images=[image_list.half().cuda()],
813
  do_sample=False,
814
+ num_beams=1,
815
  # no_repeat_ngram_size = 20,
816
  streamer=streamer,
817
  max_new_tokens=4096,
818
+ stopping_criteria=[stopping_criteria],
819
+ )
820
  else:
821
  with torch.autocast("cuda", dtype=torch.bfloat16):
822
  output_ids = self.generate(
823
  input_ids,
824
  images=[image_list.half().cuda()],
825
  do_sample=False,
826
+ num_beams=1,
827
  # no_repeat_ngram_size = 20,
828
  # streamer=streamer,
829
  max_new_tokens=4096,
830
+ stopping_criteria=[stopping_criteria],
831
+ )
832
+
833
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
834
 
 
 
835
  if outputs.endswith(stop_str):
836
+ outputs = outputs[: -len(stop_str)]
837
+ outputs = outputs.strip()
838
  response_str = outputs
839
 
840
  if render:
841
+ print("==============rendering===============")
842
  from .render_tools import content_mmd_to_html
843
+
844
  html_path_2 = save_render_file
845
+ right_num = outputs.count("\\right")
846
+ left_num = outputs.count("\left")
847
 
848
  if right_num != left_num:
849
+ outputs = (
850
+ outputs.replace("\left(", "(")
851
+ .replace("\\right)", ")")
852
+ .replace("\left[", "[")
853
+ .replace("\\right]", "]")
854
+ .replace("\left{", "{")
855
+ .replace("\\right}", "}")
856
+ .replace("\left|", "|")
857
+ .replace("\\right|", "|")
858
+ .replace("\left.", ".")
859
+ .replace("\\right.", ".")
860
+ )
861
+
862
+ outputs = outputs.replace('"', "``").replace("$", "")
863
+
864
+ outputs_list = outputs.split("\n")
865
+ gt = ""
866
  for out in outputs_list:
867
+ gt += '"' + out.replace("\\", "\\\\") + r"\n" + '"' + "+" + "\n"
868
+
869
  gt = gt[:-2]
870
 
871
  lines = content_mmd_to_html
872
  lines = lines.split("const text =")
873
+ new_web = lines[0] + "const text =" + gt + lines[1]
874
+
875
+ with open(html_path_2, "w") as web_f_new:
876
  web_f_new.write(new_web)
877
 
878
+ return response_str
render_tools.py CHANGED
@@ -1,13 +1,9 @@
 
1
 
2
- punctuation_dict = {
3
- ",": ",",
4
- "。": ".",
5
-
6
- }
7
  translation_table = str.maketrans(punctuation_dict)
8
-
9
- def svg_to_html(svg_content, output_filename):
10
 
 
 
11
  html_content = f"""
12
  <!DOCTYPE html>
13
  <html lang="en">
@@ -24,9 +20,8 @@ def svg_to_html(svg_content, output_filename):
24
  </html>
25
  """
26
 
27
- with open(output_filename, 'w') as file:
28
  file.write(html_content)
29
-
30
 
31
 
32
  content_mmd_to_html = """<!DOCTYPE html>
@@ -34,7 +29,7 @@ content_mmd_to_html = """<!DOCTYPE html>
34
  <meta charset="UTF-8">
35
  <title>Title</title>
36
  <script>
37
- const text =
38
  </script>
39
  <style>
40
  #content {
@@ -71,7 +66,6 @@ content_mmd_to_html = """<!DOCTYPE html>
71
  """
72
 
73
 
74
-
75
  tik_html = """
76
  <!DOCTYPE html>
77
 
@@ -92,5 +86,4 @@ const text =
92
  </html>"""
93
 
94
 
95
-
96
- # print(tik_html)
 
1
+ punctuation_dict = {",": ",", "。": "."}
2
 
 
 
 
 
 
3
  translation_table = str.maketrans(punctuation_dict)
 
 
4
 
5
+
6
+ def svg_to_html(svg_content, output_filename):
7
  html_content = f"""
8
  <!DOCTYPE html>
9
  <html lang="en">
 
20
  </html>
21
  """
22
 
23
+ with open(output_filename, "w") as file:
24
  file.write(html_content)
 
25
 
26
 
27
  content_mmd_to_html = """<!DOCTYPE html>
 
29
  <meta charset="UTF-8">
30
  <title>Title</title>
31
  <script>
32
+ const text =
33
  </script>
34
  <style>
35
  #content {
 
66
  """
67
 
68
 
 
69
  tik_html = """
70
  <!DOCTYPE html>
71
 
 
86
  </html>"""
87
 
88
 
89
+ # print(tik_html)
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tiktoken
2
+ transformers
3
+ torch
4
+ torchvision
5
+ requests
6
+ verovio
special_tokens_map.json CHANGED
@@ -6,4 +6,4 @@
6
  "rstrip": false,
7
  "single_word": false
8
  }
9
- }
 
6
  "rstrip": false,
7
  "single_word": false
8
  }
9
+ }
tokenization_qwen.py CHANGED
@@ -12,7 +12,7 @@ import unicodedata
12
  from typing import Collection, Dict, List, Set, Tuple, Union
13
 
14
  import tiktoken
15
- from transformers import PreTrainedTokenizer, AddedToken
16
 
17
  logger = logging.getLogger(__name__)
18
 
@@ -37,10 +37,8 @@ SPECIAL_TOKENS = (
37
  def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
  with open(tiktoken_bpe_file, "rb") as f:
39
  contents = f.read()
40
- return {
41
- base64.b64decode(token): int(rank)
42
- for token, rank in (line.split() for line in contents.splitlines() if line)
43
- }
44
 
45
  class QWenTokenizer(PreTrainedTokenizer):
46
  """QWen tokenizer."""
@@ -51,19 +49,19 @@ class QWenTokenizer(PreTrainedTokenizer):
51
  self,
52
  vocab_file,
53
  errors="replace",
54
- image_start_tag='<img>',
55
- image_end_tag='</img>',
56
- image_pad_tag='<imgpad>',
57
- ref_start_tag='<ref>',
58
- ref_end_tag='</ref>',
59
- box_start_tag='<box>',
60
- box_end_tag='</box>',
61
- quad_start_tag='<quad>',
62
- quad_end_tag='</quad>',
63
  **kwargs,
64
  ):
65
  super().__init__(**kwargs)
66
-
67
  self.image_start_tag = image_start_tag
68
  self.image_end_tag = image_end_tag
69
  self.image_pad_tag = image_pad_tag
@@ -73,24 +71,13 @@ class QWenTokenizer(PreTrainedTokenizer):
73
  self.box_end_tag = box_end_tag
74
  self.quad_start_tag = quad_start_tag
75
  self.quad_end_tag = quad_end_tag
76
- self.IMAGE_ST = (
77
- ref_start_tag, ref_end_tag,
78
- box_start_tag, box_end_tag,
79
- quad_start_tag, quad_end_tag,
80
- image_start_tag, image_end_tag,
81
- image_pad_tag
82
- )
83
 
84
  self.errors = errors # how to handle errors in decoding
85
 
86
  self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
87
- self.special_tokens = {
88
- token: index
89
- for index, token in enumerate(
90
- SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
91
- )
92
- }
93
-
94
  self.img_start_id = self.special_tokens[self.image_start_tag]
95
  self.img_end_id = self.special_tokens[self.image_end_tag]
96
  self.img_pad_id = self.special_tokens[self.image_pad_tag]
@@ -111,9 +98,7 @@ class QWenTokenizer(PreTrainedTokenizer):
111
  len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
112
  ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
113
 
114
- self.decoder = {
115
- v: k for k, v in self.mergeable_ranks.items()
116
- } # type: dict[int, bytes|str]
117
  self.decoder.update({v: k for k, v in self.special_tokens.items()})
118
 
119
  self.tokenizer = enc # type: tiktoken.Encoding
@@ -128,9 +113,7 @@ class QWenTokenizer(PreTrainedTokenizer):
128
  def get_vocab(self) -> Dict[bytes, int]:
129
  return self.mergeable_ranks
130
 
131
- def convert_tokens_to_ids(
132
- self, tokens: Union[bytes, str, List[Union[bytes, str]]]
133
- ) -> List[int]:
134
  ids = []
135
  if isinstance(tokens, (str, bytes)):
136
  if tokens in self.special_tokens:
@@ -146,11 +129,11 @@ class QWenTokenizer(PreTrainedTokenizer):
146
 
147
  def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
148
  if not special_tokens and new_tokens:
149
- raise ValueError('Adding regular tokens is not supported')
150
  for token in new_tokens:
151
  surface_form = token.content if isinstance(token, AddedToken) else token
152
  if surface_form not in SPECIAL_TOKENS:
153
- raise ValueError('Adding unknown special tokens is not supported')
154
  return 0
155
 
156
  def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
@@ -197,9 +180,7 @@ class QWenTokenizer(PreTrainedTokenizer):
197
  text = unicodedata.normalize("NFC", text)
198
 
199
  # this implementation takes a detour: text -> token id -> token surface forms
200
- for t in self.tokenizer.encode(
201
- text, allowed_special=allowed_special, disallowed_special=disallowed_special
202
- ):
203
  tokens.append(self.decoder[t])
204
  return tokens
205
 
 
12
  from typing import Collection, Dict, List, Set, Tuple, Union
13
 
14
  import tiktoken
15
+ from transformers import AddedToken, PreTrainedTokenizer
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
37
  def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
  with open(tiktoken_bpe_file, "rb") as f:
39
  contents = f.read()
40
+ return {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
41
+
 
 
42
 
43
  class QWenTokenizer(PreTrainedTokenizer):
44
  """QWen tokenizer."""
 
49
  self,
50
  vocab_file,
51
  errors="replace",
52
+ image_start_tag="<img>",
53
+ image_end_tag="</img>",
54
+ image_pad_tag="<imgpad>",
55
+ ref_start_tag="<ref>",
56
+ ref_end_tag="</ref>",
57
+ box_start_tag="<box>",
58
+ box_end_tag="</box>",
59
+ quad_start_tag="<quad>",
60
+ quad_end_tag="</quad>",
61
  **kwargs,
62
  ):
63
  super().__init__(**kwargs)
64
+
65
  self.image_start_tag = image_start_tag
66
  self.image_end_tag = image_end_tag
67
  self.image_pad_tag = image_pad_tag
 
71
  self.box_end_tag = box_end_tag
72
  self.quad_start_tag = quad_start_tag
73
  self.quad_end_tag = quad_end_tag
74
+ self.IMAGE_ST = (ref_start_tag, ref_end_tag, box_start_tag, box_end_tag, quad_start_tag, quad_end_tag, image_start_tag, image_end_tag, image_pad_tag)
 
 
 
 
 
 
75
 
76
  self.errors = errors # how to handle errors in decoding
77
 
78
  self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
79
+ self.special_tokens = {token: index for index, token in enumerate(SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks))}
80
+
 
 
 
 
 
81
  self.img_start_id = self.special_tokens[self.image_start_tag]
82
  self.img_end_id = self.special_tokens[self.image_end_tag]
83
  self.img_pad_id = self.special_tokens[self.image_pad_tag]
 
98
  len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
99
  ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
100
 
101
+ self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
 
 
102
  self.decoder.update({v: k for k, v in self.special_tokens.items()})
103
 
104
  self.tokenizer = enc # type: tiktoken.Encoding
 
113
  def get_vocab(self) -> Dict[bytes, int]:
114
  return self.mergeable_ranks
115
 
116
+ def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
 
 
117
  ids = []
118
  if isinstance(tokens, (str, bytes)):
119
  if tokens in self.special_tokens:
 
129
 
130
  def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
131
  if not special_tokens and new_tokens:
132
+ raise ValueError("Adding regular tokens is not supported")
133
  for token in new_tokens:
134
  surface_form = token.content if isinstance(token, AddedToken) else token
135
  if surface_form not in SPECIAL_TOKENS:
136
+ raise ValueError("Adding unknown special tokens is not supported")
137
  return 0
138
 
139
  def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
 
180
  text = unicodedata.normalize("NFC", text)
181
 
182
  # this implementation takes a detour: text -> token id -> token surface forms
183
+ for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
 
 
184
  tokens.append(self.decoder[t])
185
  return tokens
186