happyme531 committed
Commit 385f65d
1 Parent(s): b6a79c4

Split part of the vision encoder to the CPU and optimize Transpose ops. (Reupload to correct path)
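
The split works as follows: the transpose-heavy front of the vision encoder runs in ONNX Runtime on the CPU, and the remainder runs as an RKNN graph on the NPU, with the intermediate features reshaped into the static layout the RKNN graph expects. A minimal sketch of the pattern (not the author's exact code), using the part1/part2 file names and the intermediate shape that appear in rknnrun.py below:

import numpy as np
import onnxruntime as ort
from rknnlite.api.rknn_lite import RKNNLite

# Part 1 (CPU): the transpose-heavy front of the vision encoder stays in ONNX Runtime.
part1 = ort.InferenceSession("vision_encoder_part1.onnx", providers=["CPUExecutionProvider"])

# Part 2 (NPU): the remaining encoder layers, compiled to a fixed-shape RKNN graph.
part2 = RKNNLite(verbose=False)
assert part2.load_rknn("./vision_encoder_part2.rknn") == 0
assert part2.init_runtime() == 0

def encode_image(pixel_values: np.ndarray) -> np.ndarray:
    # The CPU half produces an intermediate feature map...
    feats = part1.run(None, {"pixel_values": pixel_values})[0]
    # ...which is reshaped to the 4-D layout the RKNN graph was compiled with,
    # avoiding extra Transpose/Reshape work on the NPU side.
    return part2.inference(inputs=[feats.reshape(1, 128, 1, 36864)])[0]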

Files changed (1):
  1. onnx/rknnrun.py +157 -14
onnx/rknnrun.py CHANGED
@@ -1,9 +1,12 @@
+ import random
  from rknnlite.api.rknn_lite import RKNNLite
  from transformers import AutoProcessor
- from PIL import Image
+ from PIL import Image, ImageDraw
  import numpy as np
  import onnxruntime as ort
  import time
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
  # set current working directory to the directory of this file
  import os
  os.chdir(os.path.dirname(os.path.abspath(__file__)))
@@ -17,7 +20,7 @@ rknn_encoder = RKNNLite(verbose=False)
  rknn_decoder_prefill = RKNNLite(verbose=False)

  # Load RKNN models
- ret = rknn_vision_encoder.load_rknn('./vision_encoder.rknn')
+ ret = rknn_vision_encoder.load_rknn('./vision_encoder_part2.rknn')
  ret = rknn_encoder.load_rknn('./encoder_model.rknn')
  ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')

@@ -26,27 +29,44 @@ ret = rknn_vision_encoder.init_runtime()
  ret = rknn_encoder.init_runtime()
  ret = rknn_decoder_prefill.init_runtime()

- text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
+ text_embed = ort.InferenceSession("embed_tokens_fp16.onnx", providers=['CPUExecutionProvider'])
  decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
- # vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
+ vision_encoder = ort.InferenceSession("vision_encoder_part1.onnx", providers=['CPUExecutionProvider'])
+ prompt_tokens_list = [15, 17, 21, 25]

  # 1. prepare inputs
  processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)

  # 2. prepare image
- image = Image.open("./lena.png")
-
- # resize image to 512x512
- image = image.resize((512, 512))
+ image = Image.open("./test.jpg")
+ original_image = image.copy()
+ original_size = image.size
+ # resize image to 768x768
+ image = image.resize((768, 768))

  # 3. prepare text
  prompt = "<MORE_DETAILED_CAPTION>"
- inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
+
+ ## try tokenize first
+ input_tokens_len = processor.tokenizer(prompt, return_tensors="np")["input_ids"].shape[1]
+ print("input_tokens_len: ", input_tokens_len)
+ ## select the closest greater value
+ pad_to = 0
+ for i in prompt_tokens_list:
+     if i >= input_tokens_len:
+         pad_to = i
+         break
+ print("pad_to: ", pad_to)
+ inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False, padding="max_length", max_length=pad_to + 577, truncation=True)
  for k, v in inputs.items():
      print(k, v.shape)

  # 4. run vision encoder using RKNN
  start_time = time.time()
- image_features = rknn_vision_encoder.inference(inputs=[inputs["pixel_values"]], data_format='nchw')[0]
+ image_features0 = vision_encoder.run(None, {
+     "pixel_values": inputs["pixel_values"]
+ })[0]
+ image_features = rknn_vision_encoder.inference(inputs=[image_features0.reshape(1, 128, 1, 36864)])[0]
+
  end_time = time.time()
  vision_encoder_time = (end_time - start_time) * 1000
  total_time += vision_encoder_time
@@ -87,6 +107,10 @@ print(encoder_hidden_states.shape)

  # 7. run decoder prefill stage using RKNN
  start_time = time.time()
+ next_token = processor.tokenizer.bos_token_id
+ next_input_embeds = text_embed.run(None, {
+     "input_ids": np.array([[next_token]], dtype=np.int64)
+ })[0]
  decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states, inputs_embeds[:, -1:]])
  end_time = time.time()
  decoder_prefill_time = (end_time - start_time) * 1000
@@ -99,7 +123,7 @@ encoder_kv = decoder_outs[1:]

  # 8. run decoder decode stage (autoregressive) (using onnxruntime)
  generated_tokens = []
- max_new_tokens = 32
+ max_new_tokens = 512
  decoder_decode_total_time = 0
  while generated_tokens.__len__() < max_new_tokens:
      # get the output of the previous step
@@ -111,7 +135,7 @@ while generated_tokens.__len__() < max_new_tokens:

      # select the next token with argmax (greedy decoding)
      next_token = np.argmax(next_token_logits, axis=-1)[0]
-     # print("next_token: ", next_token)
+     print("next_token: ", next_token)
      # append the newly generated token to the results
      generated_tokens.append(next_token)

@@ -119,7 +143,7 @@ while generated_tokens.__len__() < max_new_tokens:
      if next_token == 2: # </s>
          break

-     # prepare the input for the next step
+     # prepare the input for the next step
      start_time = time.time()
      next_input_embeds = text_embed.run(None, {
          "input_ids": np.array([[next_token]], dtype=np.int64)
@@ -171,11 +195,130 @@ print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
  print("generated_tokens: ", generated_tokens)
  generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
  print("Generated Text:", generated_text)
- parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt.split(">")[0].strip() + ">", image_size=original_size)
  print("Parsed Answer:", parsed_answer)

  print(f"Total inference time: {total_time:.2f} ms")

+ # postprocess
+ from PIL import Image, ImageDraw, ImageFont
+
+ def plot_bbox(image, data):
+     # Convert the image to a PIL Image if it's not already
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     # Create a drawing context
+     draw = ImageDraw.Draw(image)
+
+     # Load a larger font
+     try:
+         font = ImageFont.truetype("arial.ttf", 20)  # try to load the Arial font at size 20
+     except IOError:
+         font = ImageFont.load_default().font_variant(size=20)  # if Arial is unavailable, scale up the default font
+
+     # Plot each bounding box
+     for bbox, label in zip(data['bboxes'], data['labels']):
+         # Unpack the bounding box coordinates
+         x1, y1, x2, y2 = bbox
+         # Draw the rectangle with a thicker outline
+         draw.rectangle([x1, y1, x2, y2], outline="red", width=3)  # increase the line width to 3
+
+         # Annotate the label
+         left, top, right, bottom = font.getbbox(label)
+         text_width = right - left
+         text_height = bottom - top
+
+         # enlarge the text background box
+         padding = 5
+         draw.rectangle([x1, y1 - text_height - padding*2, x1 + text_width + padding*2, y1], fill="red")
+         draw.text((x1 + padding, y1 - text_height - padding), label, fill="white", font=font)
+
+     # Save the image
+     image.save("result_image.jpg")
+
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
+             'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
+
+ def draw_polygons(image, prediction, fill_mask=False):
+     """
+     Draws segmentation masks with polygons on an image.
+
+     Parameters:
+     - image: PIL Image to draw on.
+     - prediction: Dictionary containing 'polygons' and 'labels' keys.
+                   'polygons' is a list of lists, each containing vertices of a polygon.
+                   'labels' is a list of labels corresponding to each polygon.
+     - fill_mask: Boolean indicating whether to fill the polygons with color.
+     """
+     draw = ImageDraw.Draw(image)
+
+     # Set up scale factor if needed (use 1 if not scaling)
+     scale = 1
+
+     # Iterate over polygons and labels
+     for polygons, label in zip(prediction['polygons'], prediction['labels']):
+         color = random.choice(colormap)
+         fill_color = random.choice(colormap) if fill_mask else None
+
+         for _polygon in polygons:
+             _polygon = np.array(_polygon).reshape(-1, 2)
+             if len(_polygon) < 3:
+                 print('Invalid polygon:', _polygon)
+                 continue
+
+             _polygon = (_polygon * scale).reshape(-1).tolist()
+
+             # Draw the polygon
+             if fill_mask:
+                 draw.polygon(_polygon, outline=color, fill=fill_color)
+             else:
+                 draw.polygon(_polygon, outline=color)
+
+             # Draw the label text
+             draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
+
+     # Save or display the image
+     # image.show()  # Display the image
+     # display(image)
+     image.save("result_image.jpg")
+
+ def draw_ocr_bboxes(image, prediction, scale=1):
+     draw = ImageDraw.Draw(image)
+
+     # Load a larger font
+     try:
+         font = ImageFont.truetype("arial.ttf", 18)  # try to load the Arial font at size 18
+     except IOError:
+         font = ImageFont.load_default().font_variant(size=18)  # if Arial is unavailable, scale up the default font
+     bboxes, labels = prediction['quad_boxes'], prediction['labels']
+     for box, label in zip(bboxes, labels):
+         color = random.choice(colormap)
+         new_box = (np.array(box) * scale).tolist()
+         draw.polygon(new_box, width=3, outline=color)
+         draw.text((new_box[0]+8, new_box[1]+2),
+                   "{}".format(label),
+                   align="right",
+                   fill=color)
+
+     # display(image)
+     image.save("result_image.jpg")
+
+ # draw_polygons(original_image, parsed_answer['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
+ # plot_bbox(original_image, parsed_answer[prompt.split(">")[0].strip() + ">"])
+ # draw_ocr_bboxes(original_image, parsed_answer["<OCR_WITH_REGION>"], scale=1)
+
  # Release RKNNLite instances
  rknn_vision_encoder.release()
  rknn_encoder.release()
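
A note on the new prefill path: the prompt is padded to the smallest bucket in prompt_tokens_list that fits it, presumably because the RKNN prefill graph is compiled for a few fixed sequence lengths (max_length = pad_to + 577; the 577 appears to budget for the image-feature tokens prepended to the text). A minimal sketch of the bucket selection; note that rknnrun.py itself leaves pad_to at 0 when the prompt exceeds every bucket and relies on truncation=True, whereas this variant falls back to the largest bucket:

prompt_tokens_list = [15, 17, 21, 25]  # prompt lengths the prefill graph supports

def select_pad_length(input_tokens_len: int) -> int:
    # Return the smallest supported bucket that fits the prompt.
    for bucket in prompt_tokens_list:
        if bucket >= input_tokens_len:
            return bucket
    # Prompt longer than every bucket: fall back to the largest and truncate.
    return prompt_tokens_list[-1]

assert select_pad_length(4) == 15
assert select_pad_length(16) == 17
assert select_pad_length(30) == 25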