happyme531 commited on
Commit
95dfa6c
1 Parent(s): 432425c

Upload 21 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ onnx/decoder_model.rknn filter=lfs diff=lfs merge=lfs -text
37
+ onnx/encoder_model.rknn filter=lfs diff=lfs merge=lfs -text
38
+ onnx/vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "florence2",
3
+ "architectures": [
4
+ "Florence2ForConditionalGeneration"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_florence2.Florence2Config",
8
+ "AutoModelForCausalLM": "modeling_florence2.Florence2ForConditionalGeneration"
9
+ },
10
+ "bos_token_id": 0,
11
+ "eos_token_id": 2,
12
+ "ignore_index": -100,
13
+ "model_type": "florence2",
14
+ "pad_token_id": 1,
15
+ "projection_dim": 768,
16
+ "text_config": {
17
+ "vocab_size": 51289,
18
+ "activation_dropout": 0.1,
19
+ "activation_function": "gelu",
20
+ "add_bias_logits": false,
21
+ "add_final_layer_norm": false,
22
+ "attention_dropout": 0.1,
23
+ "bos_token_id": 0,
24
+ "classif_dropout": 0.1,
25
+ "classifier_dropout": 0.0,
26
+ "d_model": 768,
27
+ "decoder_attention_heads": 12,
28
+ "decoder_ffn_dim": 3072,
29
+ "decoder_layerdrop": 0.0,
30
+ "decoder_layers": 6,
31
+ "decoder_start_token_id": 2,
32
+ "dropout": 0.1,
33
+ "early_stopping": true,
34
+ "encoder_attention_heads": 12,
35
+ "encoder_ffn_dim": 3072,
36
+ "encoder_layerdrop": 0.0,
37
+ "encoder_layers": 6,
38
+ "eos_token_id": 2,
39
+ "forced_eos_token_id": 2,
40
+ "forced_bos_token_id": 0,
41
+ "gradient_checkpointing": false,
42
+ "init_std": 0.02,
43
+ "is_encoder_decoder": true,
44
+ "label2id": {
45
+ "LABEL_0": 0,
46
+ "LABEL_1": 1,
47
+ "LABEL_2": 2
48
+ },
49
+ "max_position_embeddings": 1024,
50
+ "no_repeat_ngram_size": 3,
51
+ "normalize_before": false,
52
+ "num_hidden_layers": 6,
53
+ "pad_token_id": 1,
54
+ "scale_embedding": false,
55
+ "num_beams": 3
56
+ },
57
+ "vision_config": {
58
+ "model_type": "davit",
59
+ "drop_path_rate": 0.1,
60
+ "patch_size": [7, 3, 3, 3],
61
+ "patch_stride": [4, 2, 2, 2],
62
+ "patch_padding": [3, 1, 1, 1],
63
+ "patch_prenorm": [false, true, true, true],
64
+ "enable_checkpoint": false,
65
+ "dim_embed": [128, 256, 512, 1024],
66
+ "num_heads": [4, 8, 16, 32],
67
+ "num_groups": [4, 8, 16, 32],
68
+ "depths": [1, 1, 9, 1],
69
+ "window_size": 12,
70
+ "projection_dim": 768,
71
+ "visual_temporal_embedding": {
72
+ "type": "COSINE",
73
+ "max_temporal_embeddings": 100
74
+ },
75
+ "image_pos_embed": {
76
+ "type": "learned_abs_2d",
77
+ "max_pos_embeddings": 50
78
+ },
79
+ "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
80
+ },
81
+ "vocab_size": 51289,
82
+ "torch_dtype": "float16",
83
+ "transformers_version": "4.41.0.dev0",
84
+ "is_encoder_decoder": true
85
+ }
configuration_florence2.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+ """ Florence-2 configuration"""
16
+
17
+ from typing import Optional
18
+
19
+ from transformers import AutoConfig
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ class Florence2VisionConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
28
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
29
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
36
+ The dropout rate of the drop path layer.
37
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
38
+ The patch size of the image.
39
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
40
+ The patch stride of the image.
41
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
42
+ The patch padding of the image.
43
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
44
+ Whether to apply layer normalization before the patch embedding layer.
45
+ enable_checkpoint (`bool`, *optional*, defaults to False):
46
+ Whether to enable checkpointing.
47
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
48
+ The dimension of the embedding layer.
49
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
50
+ The number of attention heads.
51
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
52
+ The number of groups.
53
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
54
+ The depth of the model.
55
+ window_size (`int`, *optional*, defaults to 12):
56
+ The window size of the model.
57
+ projection_dim (`int`, *optional*, defaults to 1024):
58
+ The dimension of the projection layer.
59
+ visual_temporal_embedding (`dict`, *optional*):
60
+ The configuration of the visual temporal embedding.
61
+ image_pos_embed (`dict`, *optional*):
62
+ The configuration of the image position embedding.
63
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
64
+ The source of the image feature.
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
69
+
70
+ >>> # Initializing a Florence2 Vision style configuration
71
+ >>> configuration = Florence2VisionConfig()
72
+
73
+ >>> # Initializing a model (with random weights)
74
+ >>> model = Florence2VisionModel(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "florence2_vision"
81
+ keys_to_ignore_at_inference = ["past_key_values"]
82
+
83
+ def __init__(
84
+ self,
85
+ drop_path_rate=0.1,
86
+ patch_size=[7, 3, 3, 3],
87
+ patch_stride=[4, 2, 2, 2],
88
+ patch_padding=[3, 1, 1, 1],
89
+ patch_prenorm=[False, True, True, True],
90
+ enable_checkpoint=False,
91
+ dim_embed=[256, 512, 1024, 2048],
92
+ num_heads=[8, 16, 32, 64],
93
+ num_groups=[8, 16, 32, 64],
94
+ depths=[1, 1, 9, 1],
95
+ window_size=12,
96
+ projection_dim=1024,
97
+ visual_temporal_embedding=None,
98
+ image_pos_embed=None,
99
+ image_feature_source=["spatial_avg_pool", "temporal_avg_pool"],
100
+ **kwargs,
101
+ ):
102
+ self.drop_path_rate = drop_path_rate
103
+ self.patch_size = patch_size
104
+ self.patch_stride = patch_stride
105
+ self.patch_padding = patch_padding
106
+ self.patch_prenorm = patch_prenorm
107
+ self.enable_checkpoint = enable_checkpoint
108
+ self.dim_embed = dim_embed
109
+ self.num_heads = num_heads
110
+ self.num_groups = num_groups
111
+ self.depths = depths
112
+ self.window_size = window_size
113
+ self.projection_dim = projection_dim
114
+ self.visual_temporal_embedding = visual_temporal_embedding
115
+ self.image_pos_embed = image_pos_embed
116
+ self.image_feature_source = image_feature_source
117
+
118
+ super().__init__(**kwargs)
119
+
120
+
121
+
122
+ class Florence2LanguageConfig(PretrainedConfig):
123
+ r"""
124
+ This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
125
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
126
+ defaults will yield a similar configuration to that of the BART
127
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
128
+
129
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
130
+ documentation from [`PretrainedConfig`] for more information.
131
+
132
+
133
+ Args:
134
+ vocab_size (`int`, *optional*, defaults to 51289):
135
+ Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
136
+ `inputs_ids` passed when calling [`Florence2LanguageModel`].
137
+ d_model (`int`, *optional*, defaults to 1024):
138
+ Dimensionality of the layers and the pooler layer.
139
+ encoder_layers (`int`, *optional*, defaults to 12):
140
+ Number of encoder layers.
141
+ decoder_layers (`int`, *optional*, defaults to 12):
142
+ Number of decoder layers.
143
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
144
+ Number of attention heads for each attention layer in the Transformer encoder.
145
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
146
+ Number of attention heads for each attention layer in the Transformer decoder.
147
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
148
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
149
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
150
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
151
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
152
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
153
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
154
+ dropout (`float`, *optional*, defaults to 0.1):
155
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
156
+ attention_dropout (`float`, *optional*, defaults to 0.0):
157
+ The dropout ratio for the attention probabilities.
158
+ activation_dropout (`float`, *optional*, defaults to 0.0):
159
+ The dropout ratio for activations inside the fully connected layer.
160
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
161
+ The dropout ratio for classifier.
162
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
163
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
164
+ just in case (e.g., 512 or 1024 or 2048).
165
+ init_std (`float`, *optional*, defaults to 0.02):
166
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
167
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
168
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
169
+ for more details.
170
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
171
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
172
+ for more details.
173
+ scale_embedding (`bool`, *optional*, defaults to `False`):
174
+ Scale embeddings by diving by sqrt(d_model).
175
+ use_cache (`bool`, *optional*, defaults to `True`):
176
+ Whether or not the model should return the last key/values attentions (not used by all models).
177
+ num_labels (`int`, *optional*, defaults to 3):
178
+ The number of labels to use in [`Florence2LanguageForSequenceClassification`].
179
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
180
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
181
+ `eos_token_id`.
182
+
183
+ Example:
184
+
185
+ ```python
186
+ >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
187
+
188
+ >>> # Initializing a Florence2 Language style configuration
189
+ >>> configuration = Florence2LanguageConfig()
190
+
191
+ >>> # Initializing a model (with random weights)
192
+ >>> model = Florence2LangaugeModel(configuration)
193
+
194
+ >>> # Accessing the model configuration
195
+ >>> configuration = model.config
196
+ ```"""
197
+
198
+ model_type = "florence2_language"
199
+ keys_to_ignore_at_inference = ["past_key_values"]
200
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=51289,
205
+ max_position_embeddings=1024,
206
+ encoder_layers=12,
207
+ encoder_ffn_dim=4096,
208
+ encoder_attention_heads=16,
209
+ decoder_layers=12,
210
+ decoder_ffn_dim=4096,
211
+ decoder_attention_heads=16,
212
+ encoder_layerdrop=0.0,
213
+ decoder_layerdrop=0.0,
214
+ activation_function="gelu",
215
+ d_model=1024,
216
+ dropout=0.1,
217
+ attention_dropout=0.0,
218
+ activation_dropout=0.0,
219
+ init_std=0.02,
220
+ classifier_dropout=0.0,
221
+ scale_embedding=False,
222
+ use_cache=True,
223
+ num_labels=3,
224
+ pad_token_id=1,
225
+ bos_token_id=0,
226
+ eos_token_id=2,
227
+ is_encoder_decoder=True,
228
+ decoder_start_token_id=2,
229
+ forced_eos_token_id=2,
230
+ **kwargs,
231
+ ):
232
+ self.vocab_size = vocab_size
233
+ self.max_position_embeddings = max_position_embeddings
234
+ self.d_model = d_model
235
+ self.encoder_ffn_dim = encoder_ffn_dim
236
+ self.encoder_layers = encoder_layers
237
+ self.encoder_attention_heads = encoder_attention_heads
238
+ self.decoder_ffn_dim = decoder_ffn_dim
239
+ self.decoder_layers = decoder_layers
240
+ self.decoder_attention_heads = decoder_attention_heads
241
+ self.dropout = dropout
242
+ self.attention_dropout = attention_dropout
243
+ self.activation_dropout = activation_dropout
244
+ self.activation_function = activation_function
245
+ self.init_std = init_std
246
+ self.encoder_layerdrop = encoder_layerdrop
247
+ self.decoder_layerdrop = decoder_layerdrop
248
+ self.classifier_dropout = classifier_dropout
249
+ self.use_cache = use_cache
250
+ self.num_hidden_layers = encoder_layers
251
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
252
+
253
+ super().__init__(
254
+ num_labels=num_labels,
255
+ pad_token_id=pad_token_id,
256
+ bos_token_id=bos_token_id,
257
+ eos_token_id=eos_token_id,
258
+ is_encoder_decoder=is_encoder_decoder,
259
+ decoder_start_token_id=decoder_start_token_id,
260
+ forced_eos_token_id=forced_eos_token_id,
261
+ **kwargs,
262
+ )
263
+
264
+ # ensure backward compatibility for BART CNN models
265
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
266
+ self.forced_bos_token_id = self.bos_token_id
267
+ warnings.warn(
268
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
269
+ "The config can simply be saved and uploaded again to be fixed."
270
+ )
271
+
272
+ class Florence2Config(PretrainedConfig):
273
+ r"""
274
+ This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
275
+ Florence-2 model according to the specified arguments, defining the model architecture.
276
+
277
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
278
+ documentation from [`PretrainedConfig`] for more information.
279
+
280
+ Args:
281
+ vision_config (`Florence2VisionConfig`, *optional*):
282
+ Custom vision config or dict
283
+ text_config (`Union[AutoConfig, dict]`, *optional*):
284
+ The config object of the text backbone.
285
+ ignore_index (`int`, *optional*, defaults to -100):
286
+ The ignore index for the loss function.
287
+ vocab_size (`int`, *optional*, defaults to 51289):
288
+ Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
289
+ `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
290
+ projection_dim (`int`, *optional*, defaults to 1024):
291
+ Dimension of the multimodal projection space.
292
+
293
+ Example:
294
+
295
+ ```python
296
+ >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
297
+
298
+ >>> # Initializing a clip-like vision config
299
+ >>> vision_config = CLIPVisionConfig()
300
+
301
+ >>> # Initializing a Bart config
302
+ >>> text_config = BartConfig()
303
+
304
+ >>> # Initializing a Florence-2 configuration
305
+ >>> configuration = Florence2Config(vision_config, text_config)
306
+
307
+ >>> # Initializing a model from the florence-2 configuration
308
+ >>> model = Florence2ForConditionalGeneration(configuration)
309
+
310
+ >>> # Accessing the model configuration
311
+ >>> configuration = model.config
312
+ ```"""
313
+
314
+ model_type = "florence2"
315
+ is_composition = False
316
+
317
+ def __init__(
318
+ self,
319
+ vision_config=None,
320
+ text_config=None,
321
+ ignore_index=-100,
322
+ vocab_size=51289,
323
+ projection_dim=1024,
324
+ **kwargs,
325
+ ):
326
+ self.ignore_index = ignore_index
327
+ self.vocab_size = vocab_size
328
+ self.projection_dim = projection_dim
329
+ if vision_config is not None:
330
+ vision_config = PretrainedConfig(**vision_config)
331
+ self.vision_config = vision_config
332
+ self.vocab_size = self.vocab_size
333
+
334
+ self.text_config = text_config
335
+ if text_config is not None:
336
+ self.text_config = Florence2LanguageConfig(**text_config)
337
+
338
+
339
+ super().__init__(**kwargs)
340
+
modeling_florence2.py ADDED
The diff for this file is too large to render. See raw diff
 
onnx/car.jpg ADDED
onnx/convert.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ import os
8
+ import urllib
9
+ import traceback
10
+ import time
11
+ import sys
12
+ import numpy as np
13
+ import cv2
14
+ from rknn.api import RKNN
15
+ from math import exp
16
+ from sys import exit
17
+
18
+ batch_size = 1
19
+ # embed_seq_len = 590
20
+
21
+ vision_size = (512, 512)
22
+
23
+ vision_tokens = 257
24
+ prompt_tokens = 14
25
+
26
+ encoder_seq_len = vision_tokens + prompt_tokens
27
+ decoder_seq_len = 1
28
+
29
+ def convert_decoder():
30
+ rknn = RKNN(verbose=True)
31
+
32
+ ONNX_MODEL="decoder_model.onnx"
33
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
34
+ DATASET="dataset.txt"
35
+ QUANTIZE=False
36
+
37
+ # pre-process config
38
+ print('--> Config model')
39
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
40
+ print('done')
41
+
42
+ # Load ONNX model
43
+ print('--> Loading model')
44
+ ret = rknn.load_onnx(model=ONNX_MODEL,
45
+ inputs=["encoder_attention_mask",
46
+ "encoder_hidden_states",
47
+ "inputs_embeds",
48
+ ],
49
+ input_size_list=[[batch_size, encoder_seq_len],
50
+ [batch_size, encoder_seq_len, 768],
51
+ [batch_size, decoder_seq_len, 768]],
52
+ )
53
+ if ret != 0:
54
+ print('Load model failed!')
55
+ exit(ret)
56
+ print('done')
57
+
58
+ # Build model
59
+ print('--> Building model')
60
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
61
+ if ret != 0:
62
+ print('Build model failed!')
63
+ exit(ret)
64
+ print('done')
65
+
66
+ #export
67
+ print('--> Export RKNN model')
68
+ ret = rknn.export_rknn(RKNN_MODEL)
69
+ if ret != 0:
70
+ print('Export RKNN model failed!')
71
+ exit(ret)
72
+ print('done')
73
+
74
+ def convert_encoder():
75
+ rknn = RKNN(verbose=True)
76
+
77
+ ONNX_MODEL="encoder_model.onnx"
78
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
79
+ DATASET="dataset.txt"
80
+ QUANTIZE=False
81
+
82
+ # pre-process config
83
+ print('--> Config model')
84
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
85
+ print('done')
86
+
87
+ # Load ONNX model
88
+ print('--> Loading model')
89
+ ret = rknn.load_onnx(model=ONNX_MODEL,
90
+ inputs=["attention_mask", "inputs_embeds"],
91
+ input_size_list=[[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]],
92
+ )
93
+ if ret != 0:
94
+ print('Load model failed!')
95
+ exit(ret)
96
+ print('done')
97
+
98
+ # Build model
99
+ print('--> Building model')
100
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
101
+ if ret != 0:
102
+ print('Build model failed!')
103
+ exit(ret)
104
+ print('done')
105
+
106
+ # Export RKNN model
107
+ print('--> Export RKNN model')
108
+ ret = rknn.export_rknn(RKNN_MODEL)
109
+ if ret != 0:
110
+ print('Export RKNN model failed!')
111
+ exit(ret)
112
+ print('done')
113
+
114
+ def convert_embed():
115
+ rknn = RKNN(verbose=True)
116
+
117
+ ONNX_MODEL="embed_tokens.onnx"
118
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
119
+ DATASET="dataset.txt"
120
+ QUANTIZE=False
121
+
122
+ # pre-process config
123
+ print('--> Config model')
124
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
125
+ print('done')
126
+
127
+ # Load ONNX model
128
+ print('--> Loading model')
129
+ ret = rknn.load_onnx(model=ONNX_MODEL,
130
+ inputs=["input_ids"],
131
+ input_size_list=[[batch_size, embed_seq_len]],
132
+ )
133
+ if ret != 0:
134
+ print('Load model failed!')
135
+ exit(ret)
136
+ print('done')
137
+
138
+ # Build model
139
+ print('--> Building model')
140
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
141
+ if ret != 0:
142
+ print('Build model failed!')
143
+ exit(ret)
144
+ print('done')
145
+
146
+ # Export RKNN model
147
+ print('--> Export RKNN model')
148
+ ret = rknn.export_rknn(RKNN_MODEL)
149
+ if ret != 0:
150
+ print('Export RKNN model failed!')
151
+ exit(ret)
152
+ print('done')
153
+
154
+ def convert_vision():
155
+ rknn = RKNN(verbose=True)
156
+
157
+ ONNX_MODEL="vision_encoder.onnx"
158
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
159
+ DATASET="dataset.txt"
160
+ QUANTIZE=False
161
+
162
+ # pre-process config
163
+ print('--> Config model')
164
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
165
+ print('done')
166
+
167
+ # Load ONNX model
168
+ print('--> Loading model')
169
+ ret = rknn.load_onnx(model=ONNX_MODEL,
170
+ inputs=["pixel_values"],
171
+ input_size_list=[[batch_size, 3, vision_size[0], vision_size[1]]],
172
+ )
173
+ if ret != 0:
174
+ print('Load model failed!')
175
+ exit(ret)
176
+ print('done')
177
+
178
+ # Build model
179
+ print('--> Building model')
180
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
181
+ if ret != 0:
182
+ print('Build model failed!')
183
+ exit(ret)
184
+ print('done')
185
+
186
+ # Export RKNN model
187
+ print('--> Export RKNN model')
188
+ ret = rknn.export_rknn(RKNN_MODEL)
189
+ if ret != 0:
190
+ print('Export RKNN model failed!')
191
+ exit(ret)
192
+ print('done')
193
+
194
+
195
+ import argparse
196
+ # python convert.py <decoder|encoder|vision|all>
197
+ if __name__ == "__main__":
198
+ parser = argparse.ArgumentParser()
199
+ parser.add_argument("model", type=str, help="Model to convert")
200
+ args = parser.parse_args()
201
+ if args.model == "decoder":
202
+ convert_decoder()
203
+ elif args.model == "encoder":
204
+ convert_encoder()
205
+ # elif args.model == "embed": # embed is faster with cpu
206
+ # convert_embed()
207
+ elif args.model == "vision":
208
+ convert_vision()
209
+ elif args.model == "all":
210
+ convert_decoder()
211
+ convert_encoder()
212
+ # convert_embed()
213
+ convert_vision()
214
+ else:
215
+ print("Invalid model")
216
+ exit(1)
onnx/decoder_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16b40ea746ea09802a549be74c2eedc937c76025a1ea9baa040617ba0605306d
3
+ size 388077195
onnx/decoder_model.rknn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331a6a05a524c72ac7287a494d6cadd425266888be5ff9375649c8760417f611
3
+ size 194821309
onnx/decoder_model_merged_q4.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be7a2f33e65f8d65538024772fda4d1c5a7752d60a7159aadf53f9f4798b90fa
3
+ size 64393474
onnx/embed_tokens.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90cae3deb6406938c676a35b5246db02b478c9cc8cf93508361be80c05babf95
3
+ size 157560044
onnx/encoder_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb0bccc232c64290397f5e1235eb3e1fa6ccf8c5afed9216480ee4eed80737fc
3
+ size 173380723
onnx/encoder_model.rknn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36af46c308219399dfe5f1df53c2093c3247c9dc248dc3c3167ab88975cf62c
3
+ size 87231735
onnx/lena.png ADDED
onnx/onnxrun.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ from transformers import AutoProcessor
3
+ from PIL import Image
4
+ import numpy as np
5
+ # set current working directory to the directory of this file
6
+ import os
7
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
8
+
9
+ # embeddings
10
+ vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
11
+ text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
12
+ # encoder
13
+ encoder = ort.InferenceSession("encoder_model.onnx", providers=['CPUExecutionProvider'])
14
+ # decoder
15
+ decoder_prefill = ort.InferenceSession("decoder_model.onnx", providers=['CPUExecutionProvider'])
16
+ decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
17
+
18
+ # 1. prepare inputs
19
+ processor = AutoProcessor.from_pretrained("/home/zt/rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
20
+
21
+ # 2. prepare image
22
+ image = Image.open("./lena.png")
23
+
24
+ # resize image to 512x512
25
+ image = image.resize((512, 512))
26
+ # 3. prepare text
27
+ prompt = "<MORE_DETAILED_CAPTION>"
28
+ inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
29
+ for k, v in inputs.items():
30
+ print(k, v.shape)
31
+
32
+ # 4. run vision encoder
33
+ image_features = vision_encoder.run(None, {
34
+ "pixel_values": inputs["pixel_values"]
35
+ })
36
+ for output in image_features:
37
+ print(output.shape)
38
+ image_features = image_features[0]
39
+ np.save("image_features.npy", image_features)
40
+ # 5. run text embed
41
+ inputs_embeds = text_embed.run(None, {
42
+ "input_ids": inputs["input_ids"]
43
+ })
44
+ for output in inputs_embeds:
45
+ print(output.shape)
46
+
47
+ inputs_embeds = inputs_embeds[0]
48
+
49
+ # 6. concat image features and text embed
50
+ batch_size, image_token_length = image_features.shape[:-1]
51
+ image_attention_mask = np.ones((batch_size, image_token_length))
52
+ task_prefix_embeds = inputs_embeds
53
+ task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
54
+ if len(task_prefix_attention_mask.shape) == 3:
55
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
56
+ inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
57
+ attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
58
+
59
+ # 6. run encoder
60
+ encoder_out = encoder.run(None, {
61
+ "inputs_embeds": inputs_embeds,
62
+ "attention_mask": attention_mask.astype(np.int64)
63
+ })
64
+ for output in encoder_out:
65
+ print(output.shape)
66
+
67
+ encoder_hidden_states = encoder_out[0]
68
+
69
+ # 7. run decoder prefill stage
70
+ decoder_outs = decoder_prefill.run(None, {
71
+ "inputs_embeds": inputs_embeds[:, -1:],
72
+ "encoder_hidden_states": encoder_hidden_states,
73
+ "encoder_attention_mask": attention_mask.astype(np.int64)
74
+ })
75
+ for output in decoder_outs:
76
+ print(output.shape)
77
+
78
+ encoder_kv = decoder_outs[1:];
79
+
80
+ # 8. run decoder decode stage(autoregressive)
81
+ generated_tokens = []
82
+ max_new_tokens = 32
83
+ while generated_tokens.__len__() < max_new_tokens:
84
+ # 获取上一步的输出
85
+ logits = decoder_outs[0]
86
+ decoder_kv = decoder_outs[1:]
87
+
88
+ # 选择最后一个token的logits
89
+ next_token_logits = logits[:, -1, :]
90
+
91
+ # 使用argmax选择下一个token (贪心算法)
92
+ next_token = np.argmax(next_token_logits, axis=-1)[0]
93
+ print("next_token: ", next_token)
94
+ # 将新生成的token添加到结果中
95
+ generated_tokens.append(next_token)
96
+
97
+ # 如果生成了结束符,则停止生成
98
+ if next_token == 2: # </s>
99
+ break
100
+
101
+ # 准备下一步的输入
102
+ next_input_embeds = text_embed.run(None, {
103
+ "input_ids": np.array([[next_token]], dtype=np.int64)
104
+ })[0]
105
+
106
+ # 运行decoder的decode阶段
107
+ decoder_outs = decoder_decode.run(None, {
108
+ "use_cache_branch": np.array([True], dtype=np.bool_),
109
+ "inputs_embeds": next_input_embeds,
110
+ "encoder_hidden_states": encoder_hidden_states,
111
+ "encoder_attention_mask": attention_mask.astype(np.int64),
112
+ "past_key_values.0.decoder.key": decoder_kv[0],
113
+ "past_key_values.0.decoder.value": decoder_kv[1],
114
+ "past_key_values.0.encoder.key": encoder_kv[2],
115
+ "past_key_values.0.encoder.value": encoder_kv[3],
116
+ "past_key_values.1.decoder.key": decoder_kv[4],
117
+ "past_key_values.1.decoder.value": decoder_kv[5],
118
+ "past_key_values.1.encoder.key": encoder_kv[6],
119
+ "past_key_values.1.encoder.value": encoder_kv[7],
120
+ "past_key_values.2.decoder.key": decoder_kv[8],
121
+ "past_key_values.2.decoder.value": decoder_kv[9],
122
+ "past_key_values.2.encoder.key": encoder_kv[10],
123
+ "past_key_values.2.encoder.value": encoder_kv[11],
124
+ "past_key_values.3.decoder.key": decoder_kv[12],
125
+ "past_key_values.3.decoder.value": decoder_kv[13],
126
+ "past_key_values.3.encoder.key": encoder_kv[14],
127
+ "past_key_values.3.encoder.value": encoder_kv[15],
128
+ "past_key_values.4.decoder.key": decoder_kv[16],
129
+ "past_key_values.4.decoder.value": decoder_kv[17],
130
+ "past_key_values.4.encoder.key": encoder_kv[18],
131
+ "past_key_values.4.encoder.value": encoder_kv[19],
132
+ "past_key_values.5.decoder.key": decoder_kv[20],
133
+ "past_key_values.5.decoder.value": decoder_kv[21],
134
+ "past_key_values.5.encoder.key": encoder_kv[22],
135
+ "past_key_values.5.encoder.value": encoder_kv[23],
136
+ })
137
+ for output in decoder_outs:
138
+ print(output.shape)
139
+
140
+ # print("generated_token: ", processor.decode(next_token, skip_special_tokens=False))
141
+
142
+ # 删除第一个token
143
+ # generated_tokens = generated_tokens[1:]
144
+ # 将生成的tokens转换为文本
145
+ print("generated_tokens: ", generated_tokens)
146
+ generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
147
+ print("Generated Text:", generated_text)
148
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
149
+ print("Parsed Answer:", parsed_answer)
onnx/rknnrun.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rknnlite.api.rknn_lite import RKNNLite
2
+ from transformers import AutoProcessor
3
+ from PIL import Image
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ import time
7
+ # set current working directory to the directory of this file
8
+ import os
9
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
10
+
11
+ # 初始化总时间计数器
12
+ total_time = 0
13
+
14
+ # Initialize RKNNLite instances
15
+ rknn_vision_encoder = RKNNLite(verbose=False)
16
+ rknn_encoder = RKNNLite(verbose=False)
17
+ rknn_decoder_prefill = RKNNLite(verbose=False)
18
+
19
+ # Load RKNN models
20
+ ret = rknn_vision_encoder.load_rknn('./vision_encoder.rknn')
21
+ ret = rknn_encoder.load_rknn('./encoder_model.rknn')
22
+ ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')
23
+
24
+ # Init runtime environment for each model
25
+ ret = rknn_vision_encoder.init_runtime()
26
+ ret = rknn_encoder.init_runtime()
27
+ ret = rknn_decoder_prefill.init_runtime()
28
+
29
+ text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
30
+ decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
31
+ # vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
32
+
33
+ # 1. prepare inputs
34
+ processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
35
+
36
+ # 2. prepare image
37
+ image = Image.open("./lena.png")
38
+
39
+ # resize image to 512x512
40
+ image = image.resize((512, 512))
41
+ # 3. prepare text
42
+ prompt = "<MORE_DETAILED_CAPTION>"
43
+ inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
44
+ for k, v in inputs.items():
45
+ print(k, v.shape)
46
+
47
+ # 4. run vision encoder using RKNN
48
+ start_time = time.time()
49
+ image_features = rknn_vision_encoder.inference(inputs=[inputs["pixel_values"]])[0]
50
+ end_time = time.time()
51
+ vision_encoder_time = (end_time - start_time) * 1000
52
+ total_time += vision_encoder_time
53
+ print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
54
+ print(image_features.shape)
55
+ np.save("image_features.npy", image_features)
56
+
57
+ # 5. run text embed using RKNN
58
+ start_time = time.time()
59
+ inputs_embeds = text_embed.run(None, {
60
+ "input_ids": inputs["input_ids"]
61
+ })[0]
62
+ end_time = time.time()
63
+ text_embed_time = (end_time - start_time) * 1000
64
+ total_time += text_embed_time
65
+ print(f"Text embed time: {text_embed_time:.2f} ms")
66
+ print(inputs_embeds.shape)
67
+
68
+ # 6. concat image features and text embed
69
+ batch_size, image_token_length = image_features.shape[:-1]
70
+ image_attention_mask = np.ones((batch_size, image_token_length))
71
+ task_prefix_embeds = inputs_embeds
72
+ task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
73
+ if len(task_prefix_attention_mask.shape) == 3:
74
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
75
+ inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
76
+ attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
77
+
78
+ # 6. run encoder using RKNN
79
+ start_time = time.time()
80
+ encoder_out = rknn_encoder.inference(inputs=[attention_mask.astype(np.int64),inputs_embeds])
81
+ end_time = time.time()
82
+ encoder_time = (end_time - start_time) * 1000
83
+ total_time += encoder_time
84
+ print(f"Encoder time: {encoder_time:.2f} ms")
85
+ encoder_hidden_states = encoder_out[0]
86
+ print(encoder_hidden_states.shape)
87
+
88
+ # 7. run decoder prefill stage using RKNN
89
+ start_time = time.time()
90
+ decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states,inputs_embeds[:, -1:]])
91
+ end_time = time.time()
92
+ decoder_prefill_time = (end_time - start_time) * 1000
93
+ total_time += decoder_prefill_time
94
+ print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
95
+ # for output in decoder_outs:
96
+ # print(output.shape)
97
+
98
+ encoder_kv = decoder_outs[1:]
99
+
100
+ # 8. run decoder decode stage(autoregressive) (using onnxruntime)
101
+ generated_tokens = []
102
+ max_new_tokens = 32
103
+ decoder_decode_total_time = 0
104
+ while generated_tokens.__len__() < max_new_tokens:
105
+ # 获取上一步的输出
106
+ logits = decoder_outs[0]
107
+ decoder_kv = decoder_outs[1:]
108
+
109
+ # 选择最后一个token的logits
110
+ next_token_logits = logits[:, -1, :]
111
+
112
+ # 使用argmax选择下一个token (贪心算法)
113
+ next_token = np.argmax(next_token_logits, axis=-1)[0]
114
+ # print("next_token: ", next_token)
115
+ # 将新生成的token添加到结果中
116
+ generated_tokens.append(next_token)
117
+
118
+ # 如果生成了结束符,则停止生成
119
+ if next_token == 2: # </s>
120
+ break
121
+
122
+ # 准备下一步的输入
123
+ start_time = time.time()
124
+ next_input_embeds = text_embed.run(None, {
125
+ "input_ids": np.array([[next_token]], dtype=np.int64)
126
+ })[0]
127
+ end_time = time.time()
128
+ text_embed_time = (end_time - start_time) * 1000
129
+ decoder_decode_total_time += text_embed_time
130
+
131
+ # 运行decoder的decode阶段
132
+ start_time = time.time()
133
+ decoder_outs = decoder_decode.run(None, {
134
+ "use_cache_branch": np.array([True], dtype=np.bool_),
135
+ "inputs_embeds": next_input_embeds,
136
+ "encoder_hidden_states": encoder_hidden_states,
137
+ "encoder_attention_mask": attention_mask.astype(np.int64),
138
+ "past_key_values.0.decoder.key": decoder_kv[0],
139
+ "past_key_values.0.decoder.value": decoder_kv[1],
140
+ "past_key_values.0.encoder.key": encoder_kv[2],
141
+ "past_key_values.0.encoder.value": encoder_kv[3],
142
+ "past_key_values.1.decoder.key": decoder_kv[4],
143
+ "past_key_values.1.decoder.value": decoder_kv[5],
144
+ "past_key_values.1.encoder.key": encoder_kv[6],
145
+ "past_key_values.1.encoder.value": encoder_kv[7],
146
+ "past_key_values.2.decoder.key": decoder_kv[8],
147
+ "past_key_values.2.decoder.value": decoder_kv[9],
148
+ "past_key_values.2.encoder.key": encoder_kv[10],
149
+ "past_key_values.2.encoder.value": encoder_kv[11],
150
+ "past_key_values.3.decoder.key": decoder_kv[12],
151
+ "past_key_values.3.decoder.value": decoder_kv[13],
152
+ "past_key_values.3.encoder.key": encoder_kv[14],
153
+ "past_key_values.3.encoder.value": encoder_kv[15],
154
+ "past_key_values.4.decoder.key": decoder_kv[16],
155
+ "past_key_values.4.decoder.value": decoder_kv[17],
156
+ "past_key_values.4.encoder.key": encoder_kv[18],
157
+ "past_key_values.4.encoder.value": encoder_kv[19],
158
+ "past_key_values.5.decoder.key": decoder_kv[20],
159
+ "past_key_values.5.decoder.value": decoder_kv[21],
160
+ "past_key_values.5.encoder.key": encoder_kv[22],
161
+ "past_key_values.5.encoder.value": encoder_kv[23],
162
+ })
163
+ end_time = time.time()
164
+ decoder_decode_time = (end_time - start_time) * 1000
165
+ decoder_decode_total_time += decoder_decode_time
166
+
167
+ total_time += decoder_decode_total_time
168
+ print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
169
+
170
+ # 将生成的tokens转换为文本
171
+ print("generated_tokens: ", generated_tokens)
172
+ generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
173
+ print("Generated Text:", generated_text)
174
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
175
+ print("Parsed Answer:", parsed_answer)
176
+
177
+ print(f"Total inference time: {total_time:.2f} ms")
178
+
179
+ # Release RKNNLite instances
180
+ rknn_vision_encoder.release()
181
+ rknn_encoder.release()
182
+ rknn_decoder_prefill.release()
onnx/vision_encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d67258cdfdebfa21285dad9e7bd4bd99725236d0aaef9e474a1b24a6ec471351
3
+ size 366549825
onnx/vision_encoder.rknn ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1824e5d3dbb2cdd02244d38e95c9bafb325a8307bf776ffcbe013776ea28a701
3
+ size 230119500
preprocessor_config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_florence2.Florence2Processor"
4
+ },
5
+ "_valid_processor_keys": [
6
+ "images",
7
+ "do_resize",
8
+ "size",
9
+ "resample",
10
+ "do_rescale",
11
+ "rescale_factor",
12
+ "do_normalize",
13
+ "image_mean",
14
+ "image_std",
15
+ "return_tensors",
16
+ "data_format",
17
+ "input_data_format",
18
+ "do_convert_rgb"
19
+ ],
20
+ "do_convert_rgb": null,
21
+ "do_normalize": true,
22
+ "do_rescale": true,
23
+ "do_resize": true,
24
+ "do_center_crop": false,
25
+ "image_processor_type": "CLIPImageProcessor",
26
+ "image_seq_length": 577,
27
+ "image_mean": [0.485, 0.456, 0.406],
28
+ "image_std": [0.229, 0.224, 0.225],
29
+ "processor_class": "Florence2Processor",
30
+ "resample": 3,
31
+ "size": {
32
+ "height": 768,
33
+ "width":768
34
+ },
35
+ "crop_size": {
36
+ "height": 768,
37
+ "width": 768
38
+ }
39
+ }
processing_florence2.py ADDED
@@ -0,0 +1,1088 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Florence-2.
17
+ """
18
+
19
+ import re
20
+ import logging
21
+ from typing import List, Optional, Union
22
+ import numpy as np
23
+
24
+ import torch
25
+
26
+ from transformers.feature_extraction_utils import BatchFeature
27
+ from transformers.image_utils import ImageInput, is_valid_image
28
+ from transformers.processing_utils import ProcessorMixin
29
+ from transformers.tokenization_utils_base import (
30
+ PaddingStrategy,
31
+ PreTokenizedInput,
32
+ TextInput,
33
+ TruncationStrategy,
34
+ )
35
+ from transformers.utils import TensorType
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # Copied from transformers.models.idefics2.processing_idefics2.is_url
41
+ def is_url(val) -> bool:
42
+ return isinstance(val, str) and val.startswith("http")
43
+
44
+ # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
45
+ def is_image_or_image_url(elem):
46
+ return is_url(elem) or is_valid_image(elem)
47
+
48
+
49
+ def _is_str_or_image(elem):
50
+ return isinstance(elem, (str)) or is_image_or_image_url(elem)
51
+
52
+
53
+ class Florence2Processor(ProcessorMixin):
54
+ r"""
55
+ Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
56
+
57
+ [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
58
+ [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
59
+
60
+ Args:
61
+ image_processor ([`CLIPImageProcessor`], *optional*):
62
+ The image processor is a required input.
63
+ tokenizer ([`BartTokenizerFast`], *optional*):
64
+ The tokenizer is a required input.
65
+ """
66
+
67
+ attributes = ["image_processor", "tokenizer"]
68
+ image_processor_class = "CLIPImageProcessor"
69
+ tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
70
+
71
+ def __init__(
72
+ self,
73
+ image_processor=None,
74
+ tokenizer=None,
75
+ ):
76
+ if image_processor is None:
77
+ raise ValueError("You need to specify an `image_processor`.")
78
+ if tokenizer is None:
79
+ raise ValueError("You need to specify a `tokenizer`.")
80
+ if not hasattr(image_processor, "image_seq_length"):
81
+ raise ValueError("Image processor is missing an `image_seq_length` attribute.")
82
+
83
+ self.image_seq_length = image_processor.image_seq_length
84
+
85
+ tokens_to_add = {
86
+ 'additional_special_tokens': \
87
+ tokenizer.additional_special_tokens + \
88
+ ['<od>', '</od>', '<ocr>', '</ocr>'] + \
89
+ [f'<loc_{x}>' for x in range(1000)] + \
90
+ ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
91
+ }
92
+ tokenizer.add_special_tokens(tokens_to_add)
93
+
94
+ self.tasks_answer_post_processing_type = {
95
+ '<OCR>': 'pure_text',
96
+ '<OCR_WITH_REGION>': 'ocr',
97
+ '<CAPTION>': 'pure_text',
98
+ '<DETAILED_CAPTION>': 'pure_text',
99
+ '<MORE_DETAILED_CAPTION>': 'pure_text',
100
+ '<OD>': 'description_with_bboxes',
101
+ '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
102
+ '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
103
+ '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
104
+ '<REGION_TO_SEGMENTATION>': 'polygons',
105
+ '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
106
+ '<REGION_TO_CATEGORY>': 'pure_text',
107
+ '<REGION_TO_DESCRIPTION>': 'pure_text',
108
+ '<REGION_TO_OCR>': 'pure_text',
109
+ '<REGION_PROPOSAL>': 'bboxes'
110
+ }
111
+
112
+ self.task_prompts_without_inputs = {
113
+ '<OCR>': 'What is the text in the image?',
114
+ '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
115
+ '<CAPTION>': 'What does the image describe?',
116
+ '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
117
+ '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
118
+ '<OD>': 'Locate the objects with category name in the image.',
119
+ '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
120
+ '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
121
+ }
122
+
123
+ self.task_prompts_with_input = {
124
+ '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
125
+ '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
126
+ '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
127
+ '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
128
+ '<REGION_TO_CATEGORY>': 'What is the region {input}?',
129
+ '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
130
+ '<REGION_TO_OCR>': 'What text is in the region {input}?',
131
+ }
132
+
133
+ self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
134
+
135
+
136
+ super().__init__(image_processor, tokenizer)
137
+
138
+ def _construct_prompts(self, text):
139
+ # replace the task tokens with the task prompts if task token is in the text
140
+ prompts = []
141
+ for _text in text:
142
+ # 1. fixed task prompts without additional inputs
143
+ for task_token, task_prompt in self.task_prompts_without_inputs.items():
144
+ if task_token in _text:
145
+ assert _text == task_token, f"Task token {task_token} should be the only token in the text."
146
+ _text = task_prompt
147
+ break
148
+ # 2. task prompts with additional inputs
149
+ for task_token, task_prompt in self.task_prompts_with_input.items():
150
+ if task_token in _text:
151
+ _text = task_prompt.format(input=_text.replace(task_token, ''))
152
+ break
153
+ prompts.append(_text)
154
+ return prompts
155
+
156
+ def __call__(
157
+ self,
158
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
159
+ images: ImageInput = None,
160
+ tokenize_newline_separately: bool = True,
161
+ padding: Union[bool, str, PaddingStrategy] = False,
162
+ truncation: Union[bool, str, TruncationStrategy] = None,
163
+ max_length=None,
164
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
165
+ do_resize: bool = None,
166
+ do_normalize: bool = None,
167
+ image_mean: Optional[Union[float, List[float]]] = None,
168
+ image_std: Optional[Union[float, List[float]]] = None,
169
+ data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
170
+ input_data_format: Optional[
171
+ Union[str, "ChannelDimension"] # noqa: F821
172
+ ] = None,
173
+ resample: "PILImageResampling" = None, # noqa: F821
174
+ do_convert_rgb: bool = None,
175
+ do_thumbnail: bool = None,
176
+ do_align_long_axis: bool = None,
177
+ do_rescale: bool = None,
178
+ ) -> BatchFeature:
179
+ """
180
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
181
+ and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
182
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
183
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
184
+ of the above two methods for more information.
185
+
186
+ Args:
187
+ text (`str`, `List[str]`, `List[List[str]]`):
188
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
189
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
190
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
191
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
192
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
193
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
194
+ number of channels, H and W are image height and width.
195
+ tokenize_newline_separately (`bool`, defaults to `True`):
196
+ Adds a separately tokenized '\n' at the end of the prompt.
197
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
198
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
199
+ index) among:
200
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
201
+ sequence if provided).
202
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
203
+ acceptable input length for the model if that argument is not provided.
204
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
205
+ lengths).
206
+ max_length (`int`, *optional*):
207
+ Maximum length of the returned list and optionally padding length (see above).
208
+ truncation (`bool`, *optional*):
209
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
210
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
211
+ If set, will return tensors of a particular framework. Acceptable values are:
212
+
213
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
214
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
215
+ - `'np'`: Return NumPy `np.ndarray` objects.
216
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
217
+
218
+ Returns:
219
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
220
+
221
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
222
+ is provided, the `input_ids` will also contain the suffix input ids.
223
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
224
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
225
+ `None`).
226
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
227
+ - **labels** -- Labels compatible with training if `suffix` is not None
228
+ """
229
+
230
+ return_token_type_ids = False
231
+
232
+ if images is None:
233
+ raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
234
+ if text is None:
235
+ logger.warning_once(
236
+ "You are using Florence-2 without a text prompt."
237
+ )
238
+ text = ""
239
+
240
+ if isinstance(text, List) and isinstance(images, List):
241
+ if len(images) < len(text):
242
+ raise ValueError(
243
+ f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
244
+ )
245
+ if _is_str_or_image(text):
246
+ text = [text]
247
+ elif isinstance(text, list) and _is_str_or_image(text[0]):
248
+ pass
249
+
250
+ pixel_values = self.image_processor(
251
+ images,
252
+ do_resize=do_resize,
253
+ do_normalize=do_normalize,
254
+ return_tensors=return_tensors,
255
+ image_mean=image_mean,
256
+ image_std=image_std,
257
+ input_data_format=input_data_format,
258
+ data_format=data_format,
259
+ resample=resample,
260
+ do_convert_rgb=do_convert_rgb,
261
+ )["pixel_values"]
262
+
263
+ if max_length is not None:
264
+ max_length -= self.image_seq_length # max_length has to account for the image tokens
265
+
266
+ text = self._construct_prompts(text)
267
+
268
+ inputs = self.tokenizer(
269
+ text,
270
+ return_tensors=return_tensors,
271
+ padding=padding,
272
+ max_length=max_length,
273
+ truncation=truncation,
274
+ return_token_type_ids=return_token_type_ids,
275
+ )
276
+
277
+ return_data = {**inputs, "pixel_values": pixel_values}
278
+
279
+ if return_token_type_ids:
280
+ labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
281
+ return_data.update({"labels": labels})
282
+ return BatchFeature(data=return_data)
283
+
284
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
285
+ def batch_decode(self, *args, **kwargs):
286
+ """
287
+ This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
288
+ refer to the docstring of this method for more information.
289
+ """
290
+ return self.tokenizer.batch_decode(*args, **kwargs)
291
+
292
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
293
+ def decode(self, *args, **kwargs):
294
+ """
295
+ This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
296
+ the docstring of this method for more information.
297
+ """
298
+ return self.tokenizer.decode(*args, **kwargs)
299
+
300
+ @property
301
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
302
+ def model_input_names(self):
303
+ tokenizer_input_names = self.tokenizer.model_input_names
304
+ image_processor_input_names = self.image_processor.model_input_names
305
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
+
307
+ def post_process_generation(self, text, task, image_size):
308
+ """
309
+ Post-process the output of the model to each of the task outputs.
310
+
311
+ Args:
312
+ text (`str`): The text to post-process.
313
+ task (`str`): The task to post-process the text for.
314
+ image_size (`Tuple[int, int]`): The size of the image. height x width.
315
+ """
316
+
317
+ task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
+ task_answer = self.post_processor(
319
+ text=text,
320
+ image_size=image_size,
321
+ parse_tasks=task_answer_post_processing_type,
322
+ )[task_answer_post_processing_type]
323
+
324
+ if task_answer_post_processing_type == 'pure_text':
325
+ final_answer = task_answer
326
+ # remove the special tokens
327
+ final_answer = final_answer.replace('<s>', '').replace('</s>', '')
328
+ elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
329
+ od_instances = task_answer
330
+ bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
+ labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
+ final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
333
+ elif task_answer_post_processing_type in ['ocr']:
334
+ bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
+ labels = [str(_od_instance['text']) for _od_instance in task_answer]
336
+ final_answer = {'quad_boxes': bboxes, 'labels': labels}
337
+ elif task_answer_post_processing_type in ['phrase_grounding']:
338
+ bboxes = []
339
+ labels = []
340
+ for _grounded_phrase in task_answer:
341
+ for _bbox in _grounded_phrase['bbox']:
342
+ bboxes.append(_bbox)
343
+ labels.append(_grounded_phrase['cat_name'])
344
+ final_answer = {'bboxes': bboxes, 'labels': labels}
345
+ elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
346
+ labels = []
347
+ polygons = []
348
+ for result in task_answer:
349
+ label = result['cat_name']
350
+ _polygons = result['polygons']
351
+ labels.append(label)
352
+ polygons.append(_polygons)
353
+ final_answer = {'polygons': polygons, 'labels': labels}
354
+ elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
355
+ bboxes = []
356
+ bboxes_labels = []
357
+ polygons = []
358
+ polygons_labels = []
359
+ for result in task_answer:
360
+ label = result['cat_name']
361
+ if 'polygons' in result:
362
+ _polygons = result['polygons']
363
+ polygons.append(_polygons)
364
+ polygons_labels.append(label)
365
+ else:
366
+ _bbox = result['bbox']
367
+ bboxes.append(_bbox)
368
+ bboxes_labels.append(label)
369
+ final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
370
+ else:
371
+ raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
372
+
373
+ final_answer = {
374
+ task: final_answer}
375
+ return final_answer
376
+
377
+ class BoxQuantizer(object):
378
+ def __init__(self, mode, bins):
379
+ self.mode = mode
380
+ self.bins = bins
381
+
382
+ def quantize(self, boxes: torch.Tensor, size):
383
+ bins_w, bins_h = self.bins # Quantization bins.
384
+ size_w, size_h = size # Original image size.
385
+ size_per_bin_w = size_w / bins_w
386
+ size_per_bin_h = size_h / bins_h
387
+ xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
388
+
389
+ if self.mode == 'floor':
390
+ quantized_xmin = (
391
+ xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
392
+ quantized_ymin = (
393
+ ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
394
+ quantized_xmax = (
395
+ xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
396
+ quantized_ymax = (
397
+ ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
398
+
399
+ elif self.mode == 'round':
400
+ raise NotImplementedError()
401
+
402
+ else:
403
+ raise ValueError('Incorrect quantization type.')
404
+
405
+ quantized_boxes = torch.cat(
406
+ (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
407
+ ).int()
408
+
409
+ return quantized_boxes
410
+
411
+ def dequantize(self, boxes: torch.Tensor, size):
412
+ bins_w, bins_h = self.bins # Quantization bins.
413
+ size_w, size_h = size # Original image size.
414
+ size_per_bin_w = size_w / bins_w
415
+ size_per_bin_h = size_h / bins_h
416
+ xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
417
+
418
+ if self.mode == 'floor':
419
+ # Add 0.5 to use the center position of the bin as the coordinate.
420
+ dequantized_xmin = (xmin + 0.5) * size_per_bin_w
421
+ dequantized_ymin = (ymin + 0.5) * size_per_bin_h
422
+ dequantized_xmax = (xmax + 0.5) * size_per_bin_w
423
+ dequantized_ymax = (ymax + 0.5) * size_per_bin_h
424
+
425
+ elif self.mode == 'round':
426
+ raise NotImplementedError()
427
+
428
+ else:
429
+ raise ValueError('Incorrect quantization type.')
430
+
431
+ dequantized_boxes = torch.cat(
432
+ (dequantized_xmin, dequantized_ymin,
433
+ dequantized_xmax, dequantized_ymax), dim=-1
434
+ )
435
+
436
+ return dequantized_boxes
437
+
438
+
439
+ class CoordinatesQuantizer(object):
440
+ """
441
+ Quantize coornidates (Nx2)
442
+ """
443
+
444
+ def __init__(self, mode, bins):
445
+ self.mode = mode
446
+ self.bins = bins
447
+
448
+ def quantize(self, coordinates: torch.Tensor, size):
449
+ bins_w, bins_h = self.bins # Quantization bins.
450
+ size_w, size_h = size # Original image size.
451
+ size_per_bin_w = size_w / bins_w
452
+ size_per_bin_h = size_h / bins_h
453
+ assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
454
+ x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
455
+
456
+ if self.mode == 'floor':
457
+ quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
458
+ quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
459
+
460
+ elif self.mode == 'round':
461
+ raise NotImplementedError()
462
+
463
+ else:
464
+ raise ValueError('Incorrect quantization type.')
465
+
466
+ quantized_coordinates = torch.cat(
467
+ (quantized_x, quantized_y), dim=-1
468
+ ).int()
469
+
470
+ return quantized_coordinates
471
+
472
+ def dequantize(self, coordinates: torch.Tensor, size):
473
+ bins_w, bins_h = self.bins # Quantization bins.
474
+ size_w, size_h = size # Original image size.
475
+ size_per_bin_w = size_w / bins_w
476
+ size_per_bin_h = size_h / bins_h
477
+ assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
478
+ x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
479
+
480
+ if self.mode == 'floor':
481
+ # Add 0.5 to use the center position of the bin as the coordinate.
482
+ dequantized_x = (x + 0.5) * size_per_bin_w
483
+ dequantized_y = (y + 0.5) * size_per_bin_h
484
+
485
+ elif self.mode == 'round':
486
+ raise NotImplementedError()
487
+
488
+ else:
489
+ raise ValueError('Incorrect quantization type.')
490
+
491
+ dequantized_coordinates = torch.cat(
492
+ (dequantized_x, dequantized_y), dim=-1
493
+ )
494
+
495
+ return dequantized_coordinates
496
+
497
+
498
+ class Florence2PostProcesser(object):
499
+ """
500
+ Florence-2 post process for converting text prediction to various tasks results.
501
+
502
+ Args:
503
+ config: A dict of configs.
504
+ tokenizer: A tokenizer for decoding text to spans.
505
+ sample config:
506
+ UNIFIED_POST_PROCESS:
507
+ # commom configs
508
+ NUM_BBOX_HEIGHT_BINS: 1000
509
+ NUM_BBOX_WIDTH_BINS: 1000
510
+ COORDINATES_HEIGHT_BINS: 1000
511
+ COORDINATES_WIDTH_BINS: 1000
512
+ # task specific configs, override the common configs
513
+ PRASE_TASKS:
514
+ - TASK_NAME: 'video_dense_caption'
515
+ PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
516
+ SCORE_MODE: 'avg_cat_name_scores'
517
+ NUM_BINS: 100
518
+ - TASK_NAME: 'od'
519
+ PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
520
+ SCORE_MODE: 'avg_cat_name_scores'
521
+
522
+ Returns:
523
+ parsed_dict (dict): A dict of parsed results.
524
+ """
525
+ def __init__(
526
+ self,
527
+ tokenizer=None
528
+ ):
529
+ parse_tasks = []
530
+ parse_task_configs = {}
531
+ config = self._create_default_config()
532
+ for task in config['PARSE_TASKS']:
533
+ parse_tasks.append(task['TASK_NAME'])
534
+ parse_task_configs[task['TASK_NAME']] = task
535
+
536
+ self.config = config
537
+ self.parse_tasks = parse_tasks
538
+ self.parse_tasks_configs = parse_task_configs
539
+
540
+ self.tokenizer = tokenizer
541
+ if self.tokenizer is not None:
542
+ self.all_special_tokens = set(self.tokenizer.all_special_tokens)
543
+
544
+ self.init_quantizers()
545
+ self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
546
+
547
+ def _create_black_list_of_phrase_grounding(self):
548
+ black_list = {}
549
+
550
+ if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
551
+ black_list = set(
552
+ ['it', 'I', 'me', 'mine',
553
+ 'you', 'your', 'yours',
554
+ 'he', 'him', 'his',
555
+ 'she', 'her', 'hers',
556
+ 'they', 'them', 'their', 'theirs',
557
+ 'one', 'oneself',
558
+ 'we', 'us', 'our', 'ours',
559
+ 'you', 'your', 'yours',
560
+ 'they', 'them', 'their', 'theirs',
561
+ 'mine', 'yours', 'his', 'hers', 'its',
562
+ 'ours', 'yours', 'theirs',
563
+ 'myself', 'yourself', 'himself', 'herself', 'itself',
564
+ 'ourselves', 'yourselves', 'themselves',
565
+ 'this', 'that',
566
+ 'these', 'those',
567
+ 'who', 'whom', 'whose', 'which', 'what',
568
+ 'who', 'whom', 'whose', 'which', 'that',
569
+ 'all', 'another', 'any', 'anybody', 'anyone', 'anything',
570
+ 'each', 'everybody', 'everyone', 'everything',
571
+ 'few', 'many', 'nobody', 'none', 'one', 'several',
572
+ 'some', 'somebody', 'someone', 'something',
573
+ 'each other', 'one another',
574
+ 'myself', 'yourself', 'himself', 'herself', 'itself',
575
+ 'ourselves', 'yourselves', 'themselves',
576
+ 'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
577
+ 'other objects', 'lots', 'a set',
578
+ ]
579
+ )
580
+
581
+ return black_list
582
+
583
+ def _create_default_config(self):
584
+ config = {
585
+ 'NUM_BBOX_HEIGHT_BINS': 1000,
586
+ 'NUM_BBOX_WIDTH_BINS': 1000,
587
+ 'BOX_QUANTIZATION_MODE': 'floor',
588
+ 'COORDINATES_HEIGHT_BINS': 1000,
589
+ 'COORDINATES_WIDTH_BINS': 1000,
590
+ 'COORDINATES_QUANTIZATION_MODE': 'floor',
591
+ 'PARSE_TASKS': [
592
+ {
593
+ 'TASK_NAME': 'od',
594
+ 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
595
+ },
596
+ {
597
+ 'TASK_NAME': 'ocr',
598
+ 'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
599
+ 'AREA_THRESHOLD': 0.00
600
+ },
601
+ {
602
+ 'TASK_NAME': 'phrase_grounding',
603
+ 'FILTER_BY_BLACK_LIST': True
604
+ },
605
+ {
606
+ 'TASK_NAME': 'pure_text',
607
+ },
608
+ {
609
+ 'TASK_NAME': 'description_with_bboxes',
610
+ },
611
+ {
612
+ 'TASK_NAME': 'description_with_polygons',
613
+ },
614
+ {
615
+ 'TASK_NAME': 'polygons',
616
+ },
617
+ {
618
+ 'TASK_NAME': 'bboxes',
619
+ },
620
+ {
621
+ 'TASK_NAME': 'description_with_bboxes_or_polygons',
622
+ }
623
+ ]
624
+ }
625
+
626
+ return config
627
+
628
+ def init_quantizers(self):
629
+ # we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
630
+ num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
631
+ num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
632
+ box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
633
+ self.box_quantizer = BoxQuantizer(
634
+ box_quantization_mode,
635
+ (num_bbox_width_bins, num_bbox_height_bins),
636
+ )
637
+
638
+ num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
639
+ num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
640
+ box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
641
+ self.coordinates_quantizer = CoordinatesQuantizer(
642
+ box_quantization_mode,
643
+ (num_bbox_width_bins, num_bbox_height_bins),
644
+ )
645
+
646
+ def decode_with_spans(self, tokenizer, token_ids):
647
+ filtered_tokens = tokenizer.convert_ids_to_tokens(
648
+ token_ids, skip_special_tokens=False)
649
+ assert len(filtered_tokens) == len(token_ids)
650
+
651
+ # To avoid mixing byte-level and unicode for byte-level BPT
652
+ # we need to build string separately for added tokens and byte-level tokens
653
+ # cf. https://github.com/huggingface/transformers/issues/1133
654
+ sub_texts = []
655
+ for token in filtered_tokens:
656
+ if token in self.all_special_tokens:
657
+ sub_texts.append(token)
658
+ else:
659
+ if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
+ sub_text = tokenizer.convert_tokens_to_string([token])
661
+ elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
+ # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
+ # Note: Do not strip sub_text as it may have functional whitespace
664
+ sub_text = token.replace('▁', ' ')
665
+ else:
666
+ raise ValueError(f'type {type(tokenizer)} not supported')
667
+ sub_texts.append(sub_text)
668
+
669
+ text = ''
670
+ spans = []
671
+ for sub_text in sub_texts:
672
+ span = (len(text), len(text) + len(sub_text)) # [start index, end index).
673
+ text += sub_text
674
+ spans.append(span)
675
+
676
+ # Text format:
677
+ # 1. T5Tokenizer/T5TokenizerFast:
678
+ # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
+ # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
+ # 2. BartTokenizer (need to double check):
681
+ # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
+ # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
+ return text, spans
684
+
685
+ def parse_od_from_text_and_spans(
686
+ self,
687
+ text,
688
+ pattern,
689
+ image_size,
690
+ phrase_centric=False
691
+ ):
692
+ parsed = list(re.finditer(pattern, text))
693
+
694
+ instances = []
695
+ for i in range(len(parsed)):
696
+ # Prepare instance.
697
+ instance = {}
698
+
699
+ if phrase_centric:
700
+ bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
701
+ else:
702
+ bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
703
+ instance['bbox'] = self.box_quantizer.dequantize(
704
+ boxes=torch.tensor(bbox_bins),
705
+ size=image_size
706
+ ).tolist()
707
+
708
+ if phrase_centric:
709
+ instance['cat_name'] = parsed[i].group(1).lower().strip()
710
+ else:
711
+ instance['cat_name'] = parsed[i].group(5).lower().strip()
712
+ instances.append(instance)
713
+
714
+ return instances
715
+
716
+ def parse_ocr_from_text_and_spans(self,
717
+ text,
718
+ pattern,
719
+ image_size,
720
+ area_threshold=-1.0,
721
+ ):
722
+ bboxes = []
723
+ labels = []
724
+ text = text.replace('<s>', '')
725
+ # ocr with regions
726
+ parsed = re.findall(pattern, text)
727
+ instances = []
728
+ image_width, image_height = image_size
729
+
730
+ for ocr_line in parsed:
731
+ ocr_content = ocr_line[0]
732
+ quad_box = ocr_line[1:]
733
+ quad_box = [int(i) for i in quad_box]
734
+ quad_box = self.coordinates_quantizer.dequantize(
735
+ torch.tensor(np.array(quad_box).reshape(-1, 2)),
736
+ size=image_size
737
+ ).reshape(-1).tolist()
738
+
739
+ if area_threshold > 0:
740
+ x_coords = [i for i in quad_box[0::2]]
741
+ y_coords = [i for i in quad_box[1::2]]
742
+
743
+ # apply the Shoelace formula
744
+ area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
745
+
746
+ if area < (image_width * image_height) * area_threshold:
747
+ continue
748
+
749
+ bboxes.append(quad_box)
750
+ labels.append(ocr_content)
751
+ instances.append({
752
+ 'quad_box': quad_box,
753
+ 'text': ocr_content,
754
+ })
755
+ return instances
756
+
757
+ def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
758
+ # ignore <s> </s> and <pad>
759
+ cur_span = 0
760
+ if text.startswith('<s>'):
761
+ cur_span += 3
762
+
763
+ text = text.replace('<s>', '')
764
+ text = text.replace('</s>', '')
765
+ text = text.replace('<pad>', '')
766
+
767
+ pattern = r"([^<]+(?:<loc_\d+>){4,})"
768
+ phrases = re.findall(pattern, text)
769
+
770
+ # pattern should be text pattern and od pattern
771
+ pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
772
+ box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
773
+
774
+ instances = []
775
+ for pharse_text in phrases:
776
+ phrase_text_strip = pharse_text.replace('<ground>', '', 1)
777
+ phrase_text_strip = pharse_text.replace('<obj>', '', 1)
778
+
779
+ if phrase_text_strip == '':
780
+ cur_span += len(pharse_text)
781
+ continue
782
+
783
+ # Prepare instance.
784
+ instance = {}
785
+
786
+ # parse phrase, get string
787
+ phrase = re.search(pattern, phrase_text_strip)
788
+ if phrase is None:
789
+ cur_span += len(pharse_text)
790
+ continue
791
+
792
+ # parse bboxes by box_pattern
793
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
794
+ if len(bboxes_parsed) == 0:
795
+ cur_span += len(pharse_text)
796
+ continue
797
+
798
+ phrase = phrase.group()
799
+ # remove leading and trailing spaces
800
+ phrase = phrase.strip()
801
+
802
+ if phrase in self.black_list_of_phrase_grounding:
803
+ cur_span += len(pharse_text)
804
+ continue
805
+
806
+ # a list of list
807
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
808
+ instance['bbox'] = self.box_quantizer.dequantize(
809
+ boxes=torch.tensor(bbox_bins),
810
+ size=image_size
811
+ ).tolist()
812
+
813
+ # exclude non-ascii characters
814
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
815
+ instance['cat_name'] = phrase
816
+
817
+ instances.append(instance)
818
+
819
+ return instances
820
+
821
+ def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
+ # temporary parse solution, split by '.'
823
+ # ignore <s> </s> and <pad>
824
+
825
+ text = text.replace('<s>', '')
826
+ text = text.replace('</s>', '')
827
+ text = text.replace('<pad>', '')
828
+
829
+ if allow_empty_phrase:
830
+ pattern = rf"(?:(?:<loc_\d+>){{4,}})"
831
+ else:
832
+ pattern = r"([^<]+(?:<loc_\d+>){4,})"
833
+ phrases = re.findall(pattern, text)
834
+
835
+ # pattern should be text pattern and od pattern
836
+ pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
837
+ box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
838
+
839
+ instances = []
840
+ for pharse_text in phrases:
841
+ phrase_text_strip = pharse_text.replace('<ground>', '', 1)
842
+ phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
+
844
+ if phrase_text_strip == '' and not allow_empty_phrase:
845
+ continue
846
+
847
+ # parse phrase, get string
848
+ phrase = re.search(pattern, phrase_text_strip)
849
+ if phrase is None:
850
+ continue
851
+
852
+ phrase = phrase.group()
853
+ # remove leading and trailing spaces
854
+ phrase = phrase.strip()
855
+
856
+ # parse bboxes by box_pattern
857
+ bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
+ if len(bboxes_parsed) == 0:
859
+ continue
860
+
861
+ # a list of list
862
+ bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
863
+
864
+ bboxes = self.box_quantizer.dequantize(
865
+ boxes=torch.tensor(bbox_bins),
866
+ size=image_size
867
+ ).tolist()
868
+
869
+ phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
+ for _bboxes in bboxes:
871
+ # Prepare instance.
872
+ instance = {}
873
+ instance['bbox'] = _bboxes
874
+ # exclude non-ascii characters
875
+ instance['cat_name'] = phrase
876
+ instances.append(instance)
877
+
878
+ return instances
879
+
880
+ def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
881
+ allow_empty_phrase=False,
882
+ polygon_sep_token='<sep>',
883
+ polygon_start_token='<poly>',
884
+ polygon_end_token='</poly>',
885
+ with_box_at_start=False,
886
+ ):
887
+
888
+ # ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
889
+ # ignore <s> </s> and <pad>
890
+
891
+ text = text.replace('<s>', '')
892
+ text = text.replace('</s>', '')
893
+ text = text.replace('<pad>', '')
894
+
895
+ if allow_empty_phrase:
896
+ pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
897
+ else:
898
+ # [^<]+: This part matches one or more characters that are not the < symbol.
899
+ # The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
900
+ #
901
+ pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
902
+ phrases = re.findall(pattern, text)
903
+
904
+ phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
905
+ box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
906
+
907
+ # one polygons instance is separated by polygon_start_token and polygon_end_token
908
+ polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
909
+
910
+ instances = []
911
+ for phrase_text in phrases:
912
+
913
+ # exclude loc_\d+>
914
+ # need to get span if want to include category score
915
+ phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
916
+
917
+ # phrase = phrase.replace('<poly>', '')
918
+ # phrase = phrase.replace('poly>', '')
919
+
920
+ if phrase_text_strip == '' and not allow_empty_phrase:
921
+ continue
922
+
923
+
924
+ # parse phrase, get string
925
+ phrase = re.search(phrase_string_pattern, phrase_text_strip)
926
+ if phrase is None:
927
+ continue
928
+ phrase = phrase.group()
929
+ # remove leading and trailing spaces
930
+ phrase = phrase.strip()
931
+
932
+ # parse bboxes by box_pattern
933
+
934
+ # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
935
+ if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
936
+ polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
937
+ else:
938
+ polygons_instances_parsed = [phrase_text]
939
+
940
+ for _polygons_instances_parsed in polygons_instances_parsed:
941
+ # Prepare instance.
942
+ instance = {}
943
+
944
+ # polygons_parsed= list(re.finditer(box_pattern, phrase_text))
945
+ if isinstance(_polygons_instances_parsed, str):
946
+ polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
947
+ else:
948
+ polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
949
+ if len(polygons_parsed) == 0:
950
+ continue
951
+
952
+ # a list of list (polygon)
953
+ bbox = []
954
+ polygons = []
955
+ for _polygon_parsed in polygons_parsed:
956
+ # group 1: whole <loc_\d+>...</loc_\d+>
957
+ _polygon = _polygon_parsed.group(1)
958
+ # parse into list of int
959
+ _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
960
+ if with_box_at_start and len(bbox) == 0:
961
+ if len(_polygon) > 4:
962
+ # no valid bbox prediction
963
+ bbox = _polygon[:4]
964
+ _polygon = _polygon[4:]
965
+ else:
966
+ bbox = [0, 0, 0, 0]
967
+ # abandon last element if is not paired
968
+ if len(_polygon) % 2 == 1:
969
+ _polygon = _polygon[:-1]
970
+
971
+ # reshape into (n, 2)
972
+ _polygon = self.coordinates_quantizer.dequantize(
973
+ torch.tensor(np.array(_polygon).reshape(-1, 2)),
974
+ size=image_size
975
+ ).reshape(-1).tolist()
976
+ # reshape back
977
+ polygons.append(_polygon)
978
+
979
+ instance['cat_name'] = phrase
980
+ instance['polygons'] = polygons
981
+ if len(bbox) != 0:
982
+ instance['bbox'] = self.box_quantizer.dequantize(
983
+ boxes=torch.tensor([bbox]),
984
+ size=image_size
985
+ ).tolist()[0]
986
+
987
+ instances.append(instance)
988
+
989
+ return instances
990
+
991
+ def __call__(
992
+ self,
993
+ text=None,
994
+ image_size=None,
995
+ parse_tasks=None,
996
+ ):
997
+ """
998
+ Args:
999
+ text: model outputs
1000
+ image_size: (width, height)
1001
+ parse_tasks: a list of tasks to parse, if None, parse all tasks.
1002
+
1003
+ """
1004
+ if parse_tasks is not None:
1005
+ if isinstance(parse_tasks, str):
1006
+ parse_tasks = [parse_tasks]
1007
+ for _parse_task in parse_tasks:
1008
+ assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
+
1010
+ # sequence or text should be provided
1011
+ assert text is not None, 'text should be provided'
1012
+
1013
+ parsed_dict = {
1014
+ 'text': text
1015
+ }
1016
+
1017
+ for task in self.parse_tasks:
1018
+ if parse_tasks is not None and task not in parse_tasks:
1019
+ continue
1020
+
1021
+ pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1022
+
1023
+ if task == 'ocr':
1024
+ instances = self.parse_ocr_from_text_and_spans(
1025
+ text,
1026
+ pattern=pattern,
1027
+ image_size=image_size,
1028
+ area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0),
1029
+ )
1030
+ parsed_dict['ocr'] = instances
1031
+ elif task == 'phrase_grounding':
1032
+ instances = self.parse_phrase_grounding_from_text_and_spans(
1033
+ text,
1034
+ pattern=pattern,
1035
+ image_size=image_size,
1036
+ )
1037
+ parsed_dict['phrase_grounding'] = instances
1038
+ elif task == 'pure_text':
1039
+ parsed_dict['pure_text'] = text
1040
+ elif task == 'description_with_bboxes':
1041
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
+ text,
1043
+ pattern=pattern,
1044
+ image_size=image_size,
1045
+ )
1046
+ parsed_dict['description_with_bboxes'] = instances
1047
+ elif task == 'description_with_polygons':
1048
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1049
+ text,
1050
+ pattern=pattern,
1051
+ image_size=image_size,
1052
+ )
1053
+ parsed_dict['description_with_polygons'] = instances
1054
+ elif task == 'polygons':
1055
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1056
+ text,
1057
+ pattern=pattern,
1058
+ image_size=image_size,
1059
+ allow_empty_phrase=True,
1060
+ )
1061
+ parsed_dict['polygons'] = instances
1062
+ elif task == 'bboxes':
1063
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1064
+ text,
1065
+ pattern=pattern,
1066
+ image_size=image_size,
1067
+ allow_empty_phrase=True,
1068
+ )
1069
+ parsed_dict['bboxes'] = instances
1070
+ elif task == 'description_with_bboxes_or_polygons':
1071
+ if '<poly>' in text:
1072
+ # only support either polygons or bboxes, not both at the same time
1073
+ instances = self.parse_description_with_polygons_from_text_and_spans(
1074
+ text,
1075
+ pattern=pattern,
1076
+ image_size=image_size,
1077
+ )
1078
+ else:
1079
+ instances = self.parse_description_with_bboxes_from_text_and_spans(
1080
+ text,
1081
+ pattern=pattern,
1082
+ image_size=image_size,
1083
+ )
1084
+ parsed_dict['description_with_bboxes_or_polygons'] = instances
1085
+ else:
1086
+ raise ValueError("task {} is not supported".format(task))
1087
+
1088
+ return parsed_dict
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "model_max_length": 1024
3
+ }
4
+
vocab.json ADDED
The diff for this file is too large to render. See raw diff