# Florence-2 inference on the RK3588 NPU (RKNNLite) with parts on the CPU
# (ONNX Runtime). Part of the vision encoder is split off to the CPU and
# Transpose ops are optimized for the NPU.
import random
from rknnlite.api.rknn_lite import RKNNLite
from transformers import AutoProcessor
from PIL import Image, ImageDraw
import numpy as np
import onnxruntime as ort
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# set current working directory to the directory of this file
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Initialize the total time counter
total_time = 0
# Initialize RKNNLite instances
rknn_vision_encoder = RKNNLite(verbose=False)
rknn_encoder = RKNNLite(verbose=False)
rknn_decoder_prefill = RKNNLite(verbose=False)
# Load RKNN models
ret = rknn_vision_encoder.load_rknn('./vision_encoder_part2.rknn')
ret = rknn_encoder.load_rknn('./encoder_model.rknn')
ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')
# Init runtime environment for each model
ret = rknn_vision_encoder.init_runtime()
ret = rknn_encoder.init_runtime()
ret = rknn_decoder_prefill.init_runtime()
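# load_rknn() / init_runtime() return 0 on success. A minimal sanity check
# (a sketch; only the last return code is verified here):
assert ret == 0, "RKNNLite init_runtime failed"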
text_embed = ort.InferenceSession("embed_tokens_fp16.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
vision_encoder = ort.InferenceSession("vision_encoder_part1.onnx", providers=['CPUExecutionProvider'])
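# The RKNN graphs are exported with fixed input shapes, so the text prompt is
# padded up to one of these precompiled token lengths (see pad_to below).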
prompt_tokens_list = [15, 17, 21, 25]
# 1. prepare inputs
processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
# 2. prepare image
image = Image.open("./test.jpg")
original_image = image.copy()
original_size = image.size
# resize image to 768x768
image = image.resize((768, 768))
# 3. prepare text
prompt = "<MORE_DETAILED_CAPTION>"
## tokenize first to measure the prompt length
input_tokens_len = processor.tokenizer(prompt, return_tensors="np")["input_ids"].shape[1]
print("input_tokens_len: ", input_tokens_len)
## select the smallest supported length that is >= the prompt length
pad_to = 0
for i in prompt_tokens_list:
if i >= input_tokens_len:
pad_to = i
break
print("pad_to: ", pad_to)
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False, padding="max_length", max_length=pad_to + 577, truncation=True)
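# NOTE: max_length is pad_to + 577; the 577 presumably corresponds to the
# number of visual tokens baked into the fixed-shape graphs at export time.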
for k, v in inputs.items():
print(k, v.shape)
# 4. run vision encoder (part 1 via onnxruntime, part 2 via RKNN)
start_time = time.time()
image_features0 = vision_encoder.run(None, {
"pixel_values": inputs["pixel_values"]
})[0]
image_features = rknn_vision_encoder.inference(inputs=[image_features0.reshape(1, 128, 1, 36864)])[0]
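# The vision encoder is split in two: part 1 runs on the CPU via ONNX Runtime
# (presumably to keep Transpose-heavy ops off the NPU), and part 2 runs on the
# NPU via RKNN, with the intermediate activation reshaped to the fixed
# (1, 128, 1, 36864) layout the RKNN graph expects.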
end_time = time.time()
vision_encoder_time = (end_time - start_time) * 1000
total_time += vision_encoder_time
print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
print(image_features.shape)
np.save("image_features.npy", image_features)
# 5. run text embed (via onnxruntime)
start_time = time.time()
inputs_embeds = text_embed.run(None, {
"input_ids": inputs["input_ids"]
})[0]
end_time = time.time()
text_embed_time = (end_time - start_time) * 1000
total_time += text_embed_time
print(f"Text embed time: {text_embed_time:.2f} ms")
print(inputs_embeds.shape)
# 6. concat image features and text embed
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
if len(task_prefix_attention_mask.shape) == 3:
task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
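# Florence-2 places image tokens first, followed by the text prompt, so the
# embeddings and attention masks are concatenated in that order.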
# 7. run encoder using RKNN
start_time = time.time()
encoder_out = rknn_encoder.inference(inputs=[attention_mask.astype(np.int64),inputs_embeds])
end_time = time.time()
encoder_time = (end_time - start_time) * 1000
total_time += encoder_time
print(f"Encoder time: {encoder_time:.2f} ms")
encoder_hidden_states = encoder_out[0]
print(encoder_hidden_states.shape)
# 8. run decoder prefill stage using RKNN
start_time = time.time()
next_token = processor.tokenizer.bos_token_id
next_input_embeds = text_embed.run(None, {
"input_ids": np.array([[next_token]], dtype=np.int64)
})[0]
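# Note: this BOS embedding is recomputed inside the decode loop before it is
# used; the prefill call below takes the embedding of the last prompt token.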
decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states,inputs_embeds[:, -1:]])
end_time = time.time()
decoder_prefill_time = (end_time - start_time) * 1000
total_time += decoder_prefill_time
print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
# for output in decoder_outs:
# print(output.shape)
encoder_kv = decoder_outs[1:]
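# Prefill returns logits followed by the per-layer KV cache. The encoder
# (cross-attention) K/V depend only on encoder_hidden_states, so they are
# computed once at prefill and reused unchanged at every decode step.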
# 9. run decoder decode stage (autoregressive, using onnxruntime)
generated_tokens = []
max_new_tokens = 512
decoder_decode_total_time = 0
while len(generated_tokens) < max_new_tokens:
    # Get the outputs from the previous step
logits = decoder_outs[0]
decoder_kv = decoder_outs[1:]
    # Select the logits of the last token
next_token_logits = logits[:, -1, :]
    # Greedily select the next token with argmax
next_token = np.argmax(next_token_logits, axis=-1)[0]
print("next_token: ", next_token)
    # Append the newly generated token to the results
generated_tokens.append(next_token)
    # Stop once the end-of-sequence token is generated
if next_token == 2: # </s>
break
    # Prepare the input for the next step
start_time = time.time()
next_input_embeds = text_embed.run(None, {
"input_ids": np.array([[next_token]], dtype=np.int64)
})[0]
end_time = time.time()
text_embed_time = (end_time - start_time) * 1000
decoder_decode_total_time += text_embed_time
    # Run the decoder's decode stage
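    # The merged decoder's KV outputs are assumed to be laid out per layer as
    # (decoder.key, decoder.value, encoder.key, encoder.value) across 6 layers:
    # self-attention (decoder.*) entries are refreshed every step, while
    # cross-attention (encoder.*) entries are reused from the prefill outputs.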
start_time = time.time()
decoder_outs = decoder_decode.run(None, {
"use_cache_branch": np.array([True], dtype=np.bool_),
"inputs_embeds": next_input_embeds,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": attention_mask.astype(np.int64),
"past_key_values.0.decoder.key": decoder_kv[0],
"past_key_values.0.decoder.value": decoder_kv[1],
"past_key_values.0.encoder.key": encoder_kv[2],
"past_key_values.0.encoder.value": encoder_kv[3],
"past_key_values.1.decoder.key": decoder_kv[4],
"past_key_values.1.decoder.value": decoder_kv[5],
"past_key_values.1.encoder.key": encoder_kv[6],
"past_key_values.1.encoder.value": encoder_kv[7],
"past_key_values.2.decoder.key": decoder_kv[8],
"past_key_values.2.decoder.value": decoder_kv[9],
"past_key_values.2.encoder.key": encoder_kv[10],
"past_key_values.2.encoder.value": encoder_kv[11],
"past_key_values.3.decoder.key": decoder_kv[12],
"past_key_values.3.decoder.value": decoder_kv[13],
"past_key_values.3.encoder.key": encoder_kv[14],
"past_key_values.3.encoder.value": encoder_kv[15],
"past_key_values.4.decoder.key": decoder_kv[16],
"past_key_values.4.decoder.value": decoder_kv[17],
"past_key_values.4.encoder.key": encoder_kv[18],
"past_key_values.4.encoder.value": encoder_kv[19],
"past_key_values.5.decoder.key": decoder_kv[20],
"past_key_values.5.decoder.value": decoder_kv[21],
"past_key_values.5.encoder.key": encoder_kv[22],
"past_key_values.5.encoder.value": encoder_kv[23],
})
end_time = time.time()
decoder_decode_time = (end_time - start_time) * 1000
decoder_decode_total_time += decoder_decode_time
total_time += decoder_decode_total_time
print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
# Convert the generated tokens to text
print("generated_tokens: ", generated_tokens)
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
print("Generated Text:", generated_text)
parsed_answer = processor.post_process_generation(generated_text, task=prompt.split(">")[0].strip() + ">", image_size=original_size)
print("Parsed Answer:", parsed_answer)
print(f"Total inference time: {total_time:.2f} ms")
# postprocess
from PIL import Image, ImageDraw, ImageFont
def plot_bbox(image, data):
# Convert the image to a PIL Image if it's not already
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
# Create a drawing context
draw = ImageDraw.Draw(image)
# Load a larger font
try:
        font = ImageFont.truetype("arial.ttf", 20)  # Try to load Arial at size 20
except IOError:
        font = ImageFont.load_default().font_variant(size=20)  # Fall back to the default font scaled to 20 if Arial is unavailable
# Plot each bounding box
for bbox, label in zip(data['bboxes'], data['labels']):
# Unpack the bounding box coordinates
x1, y1, x2, y2 = bbox
# Draw the rectangle with thicker outline
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)  # Thicker outline (width 3)
# Annotate the label
left, top, right, bottom = font.getbbox(label)
text_width = right - left
text_height = bottom - top
        # Enlarge the text background box
padding = 5
draw.rectangle([x1, y1 - text_height - padding*2, x1 + text_width + padding*2, y1], fill="red")
draw.text((x1 + padding, y1 - text_height - padding), label, fill="white", font=font)
# Save the image
image.save("result_image.jpg")
colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
def draw_polygons(image, prediction, fill_mask=False):
"""
Draws segmentation masks with polygons on an image.
Parameters:
- image_path: Path to the image file.
- prediction: Dictionary containing 'polygons' and 'labels' keys.
'polygons' is a list of lists, each containing vertices of a polygon.
'labels' is a list of labels corresponding to each polygon.
- fill_mask: Boolean indicating whether to fill the polygons with color.
"""
# Load the image
draw = ImageDraw.Draw(image)
# Set up scale factor if needed (use 1 if not scaling)
scale = 1
# Iterate over polygons and labels
for polygons, label in zip(prediction['polygons'], prediction['labels']):
color = random.choice(colormap)
fill_color = random.choice(colormap) if fill_mask else None
for _polygon in polygons:
_polygon = np.array(_polygon).reshape(-1, 2)
if len(_polygon) < 3:
print('Invalid polygon:', _polygon)
continue
_polygon = (_polygon * scale).reshape(-1).tolist()
# Draw the polygon
if fill_mask:
draw.polygon(_polygon, outline=color, fill=fill_color)
else:
draw.polygon(_polygon, outline=color)
# Draw the label text
draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
# Save or display the image
# image.show() # Display the image
# display(image)
image.save("result_image.jpg")
def draw_ocr_bboxes(image, prediction, scale=1):
draw = ImageDraw.Draw(image)
# Load a larger font
try:
        font = ImageFont.truetype("arial.ttf", 18)  # Try to load Arial at size 18
except IOError:
        font = ImageFont.load_default().font_variant(size=18)  # Fall back to the default font scaled to 18 if Arial is unavailable
bboxes, labels = prediction['quad_boxes'], prediction['labels']
for box, label in zip(bboxes, labels):
color = random.choice(colormap)
new_box = (np.array(box) * scale).tolist()
draw.polygon(new_box, width=3, outline=color)
draw.text((new_box[0]+8, new_box[1]+2),
"{}".format(label),
align="right",
fill=color)
# display(image)
image.save("result_image.jpg")
# draw_polygons(original_image, parsed_answer['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
# plot_bbox(original_image, parsed_answer[prompt.split(">")[0].strip() + ">"])
# draw_ocr_bboxes(original_image, parsed_answer["<OCR_WITH_REGION>"], scale=1)
# Release RKNNLite instances
rknn_vision_encoder.release()
rknn_encoder.release()
rknn_decoder_prefill.release()