Split part of vision encoder to CPU and optimize Transpose ops. (Reupload to correct path)
import random
from rknnlite.api.rknn_lite import RKNNLite
from transformers import AutoProcessor
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import onnxruntime as ort
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# set current working directory to the directory of this file
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# initialize the total-time counter
total_time = 0
# Initialize RKNNLite instances
rknn_vision_encoder = RKNNLite(verbose=False)
rknn_encoder = RKNNLite(verbose=False)
rknn_decoder_prefill = RKNNLite(verbose=False)
# Load RKNN models
ret = rknn_vision_encoder.load_rknn('./vision_encoder_part2.rknn')
ret = rknn_encoder.load_rknn('./encoder_model.rknn')
ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')
# Init runtime environment for each model
ret = rknn_vision_encoder.init_runtime()
ret = rknn_encoder.init_runtime()
ret = rknn_decoder_prefill.init_runtime()
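# NOTE: load_rknn()/init_runtime() return status codes; a robust script would
# check them instead of silently continuing, e.g.:
# if ret != 0:
#     raise RuntimeError("RKNN model load or runtime init failed")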
# ONNX Runtime sessions for the parts that stay on the CPU
text_embed = ort.InferenceSession("embed_tokens_fp16.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
vision_encoder = ort.InferenceSession("vision_encoder_part1.onnx", providers=['CPUExecutionProvider'])
prompt_tokens_list = [15, 17, 21, 25]
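# NOTE (assumption): the RKNN graphs are compiled for fixed input shapes, so
# the prompt is padded up to one of these pre-built token-length buckets
# rather than padded dynamically.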
# 1. prepare inputs
processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
# 2. prepare image
image = Image.open("./test.jpg")
original_image = image.copy()
original_size = image.size
# resize image to 768x768
image = image.resize((768, 768))
# 3. prepare text
prompt = "<MORE_DETAILED_CAPTION>"
## tokenize first to measure the prompt length
input_tokens_len = processor.tokenizer(prompt, return_tensors="np")["input_ids"].shape[1]
print("input_tokens_len: ", input_tokens_len)
## select the closest bucket that is not smaller than the prompt length
pad_to = 0
for i in prompt_tokens_list:
    if i >= input_tokens_len:
        pad_to = i
        break
# guard against prompts that exceed every bucket (pad_to would stay 0)
assert pad_to > 0, "prompt is longer than the largest pre-built bucket"
print("pad_to: ", pad_to)
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False, padding="max_length", max_length=pad_to + 577, truncation=True)
for k, v in inputs.items():
    print(k, v.shape)
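# Per the commit message, part1 of the vision encoder stays on the CPU (ONNX
# Runtime) while part2 runs on the NPU (RKNN); the reshape below matches
# part2's fixed input shape (assumption).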
# 4. run vision encoder (CPU part1 + NPU part2)
start_time = time.time()
image_features0 = vision_encoder.run(None, {
    "pixel_values": inputs["pixel_values"]
})[0]
image_features = rknn_vision_encoder.inference(inputs=[image_features0.reshape(1, 128, 1, 36864)])[0]
end_time = time.time()
vision_encoder_time = (end_time - start_time) * 1000
total_time += vision_encoder_time
print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
print(image_features.shape)
np.save("image_features.npy", image_features)
# 5. run text embedding (ONNX Runtime, CPU)
start_time = time.time()
inputs_embeds = text_embed.run(None, {
    "input_ids": inputs["input_ids"]
})[0]
end_time = time.time()
text_embed_time = (end_time - start_time) * 1000
total_time += text_embed_time
print(f"Text embed time: {text_embed_time:.2f} ms")
print(inputs_embeds.shape)
# 6. concatenate image features and text embeddings
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
if len(task_prefix_attention_mask.shape) == 3:
    task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
# 7. run encoder using RKNN
start_time = time.time()
encoder_out = rknn_encoder.inference(inputs=[attention_mask.astype(np.int64), inputs_embeds])
end_time = time.time()
encoder_time = (end_time - start_time) * 1000
total_time += encoder_time
print(f"Encoder time: {encoder_time:.2f} ms")
encoder_hidden_states = encoder_out[0]
print(encoder_hidden_states.shape)
# 8. run decoder prefill stage using RKNN
start_time = time.time()
decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states, inputs_embeds[:, -1:]])
end_time = time.time()
decoder_prefill_time = (end_time - start_time) * 1000
total_time += decoder_prefill_time
print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
# for output in decoder_outs:
#     print(output.shape)
encoder_kv = decoder_outs[1:]
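# decoder_outs is [logits] followed, per layer, by decoder.key, decoder.value,
# encoder.key, encoder.value. The cross-attention (encoder) entries depend only
# on encoder_hidden_states, so they are computed once at prefill and reused on
# every decode step; the self-attention (decoder) entries are refreshed from
# each decode step's outputs.
# The literal past_key_values feed spelled out in the loop below could also be
# built programmatically. A minimal sketch (assumption: 6 decoder layers and
# the output ordering described above; build_past_kv_feed is a hypothetical
# helper, not part of the original script):
def build_past_kv_feed(decoder_kv, encoder_kv, num_layers=6):
    feed = {}
    for layer in range(num_layers):
        base = 4 * layer
        # self-attention cache: grows by one position per generated token
        feed[f"past_key_values.{layer}.decoder.key"] = decoder_kv[base]
        feed[f"past_key_values.{layer}.decoder.value"] = decoder_kv[base + 1]
        # cross-attention cache: static, taken from the prefill outputs
        feed[f"past_key_values.{layer}.encoder.key"] = encoder_kv[base + 2]
        feed[f"past_key_values.{layer}.encoder.value"] = encoder_kv[base + 3]
    return feed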
# 9. run decoder decode stage (autoregressive, using onnxruntime)
generated_tokens = []
max_new_tokens = 512
decoder_decode_total_time = 0
while len(generated_tokens) < max_new_tokens:
    # grab the outputs of the previous step
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]
    # take the logits of the last token
    next_token_logits = logits[:, -1, :]
    # pick the next token greedily with argmax
    next_token = np.argmax(next_token_logits, axis=-1)[0]
    print("next_token: ", next_token)
    # append the newly generated token to the result
    generated_tokens.append(next_token)
    # stop when the end-of-sequence token is generated
    if next_token == 2:  # </s>
        break
    # prepare the input for the next step
    start_time = time.time()
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]
    end_time = time.time()
    text_embed_time = (end_time - start_time) * 1000
    decoder_decode_total_time += text_embed_time
    # run the decode stage of the decoder
    start_time = time.time()
    decoder_outs = decoder_decode.run(None, {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask.astype(np.int64),
        "past_key_values.0.decoder.key": decoder_kv[0],
        "past_key_values.0.decoder.value": decoder_kv[1],
        "past_key_values.0.encoder.key": encoder_kv[2],
        "past_key_values.0.encoder.value": encoder_kv[3],
        "past_key_values.1.decoder.key": decoder_kv[4],
        "past_key_values.1.decoder.value": decoder_kv[5],
        "past_key_values.1.encoder.key": encoder_kv[6],
        "past_key_values.1.encoder.value": encoder_kv[7],
        "past_key_values.2.decoder.key": decoder_kv[8],
        "past_key_values.2.decoder.value": decoder_kv[9],
        "past_key_values.2.encoder.key": encoder_kv[10],
        "past_key_values.2.encoder.value": encoder_kv[11],
        "past_key_values.3.decoder.key": decoder_kv[12],
        "past_key_values.3.decoder.value": decoder_kv[13],
        "past_key_values.3.encoder.key": encoder_kv[14],
        "past_key_values.3.encoder.value": encoder_kv[15],
        "past_key_values.4.decoder.key": decoder_kv[16],
        "past_key_values.4.decoder.value": decoder_kv[17],
        "past_key_values.4.encoder.key": encoder_kv[18],
        "past_key_values.4.encoder.value": encoder_kv[19],
        "past_key_values.5.decoder.key": decoder_kv[20],
        "past_key_values.5.decoder.value": decoder_kv[21],
        "past_key_values.5.encoder.key": encoder_kv[22],
        "past_key_values.5.encoder.value": encoder_kv[23],
    })
    end_time = time.time()
    decoder_decode_time = (end_time - start_time) * 1000
    decoder_decode_total_time += decoder_decode_time
total_time += decoder_decode_total_time
print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
# convert the generated tokens to text
print("generated_tokens: ", generated_tokens)
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
print("Generated Text:", generated_text)
parsed_answer = processor.post_process_generation(generated_text, task=prompt.split(">")[0].strip() + ">", image_size=original_size)
print("Parsed Answer:", parsed_answer)
print(f"Total inference time: {total_time:.2f} ms")
# postprocess
def plot_bbox(image, data):
    # Convert the image to a PIL Image if it's not already
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Create a drawing context
    draw = ImageDraw.Draw(image)
    # Load a larger font
    try:
        font = ImageFont.truetype("arial.ttf", 20)  # try Arial at size 20
    except IOError:
        font = ImageFont.load_default().font_variant(size=20)  # fall back to the default font, enlarged
    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Draw the rectangle with a thicker outline (width 3)
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Annotate the label
        left, top, right, bottom = font.getbbox(label)
        text_width = right - left
        text_height = bottom - top
        # enlarge the text background box
        padding = 5
        draw.rectangle([x1, y1 - text_height - padding*2, x1 + text_width + padding*2, y1], fill="red")
        draw.text((x1 + padding, y1 - text_height - padding), label, fill="white", font=font)
    # Save the image
    image.save("result_image.jpg")
colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
            'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
def draw_polygons(image, prediction, fill_mask=False):
    """
    Draws segmentation masks with polygons on an image.

    Parameters:
    - image: PIL Image to draw on.
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
                  'polygons' is a list of lists, each containing vertices of a polygon.
                  'labels' is a list of labels corresponding to each polygon.
    - fill_mask: Boolean indicating whether to fill the polygons with color.
    """
    # Create a drawing context
    draw = ImageDraw.Draw(image)
    # Set up scale factor if needed (use 1 if not scaling)
    scale = 1
    # Iterate over polygons and labels
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        color = random.choice(colormap)
        fill_color = random.choice(colormap) if fill_mask else None
        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print('Invalid polygon:', _polygon)
                continue
            _polygon = (_polygon * scale).reshape(-1).tolist()
            # Draw the polygon
            if fill_mask:
                draw.polygon(_polygon, outline=color, fill=fill_color)
            else:
                draw.polygon(_polygon, outline=color)
            # Draw the label text
            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
    # Save or display the image
    # image.show()  # Display the image
    image.save("result_image.jpg")
def draw_ocr_bboxes(image, prediction, scale=1):
    draw = ImageDraw.Draw(image)
    # Load a larger font
    try:
        font = ImageFont.truetype("arial.ttf", 18)  # try Arial at size 18
    except IOError:
        font = ImageFont.load_default().font_variant(size=18)  # fall back to the default font, enlarged
    bboxes, labels = prediction['quad_boxes'], prediction['labels']
    for box, label in zip(bboxes, labels):
        color = random.choice(colormap)
        new_box = (np.array(box) * scale).tolist()
        draw.polygon(new_box, width=3, outline=color)
        draw.text((new_box[0] + 8, new_box[1] + 2),
                  "{}".format(label),
                  align="right",
                  fill=color)
    # display(image)
    image.save("result_image.jpg")
# Uncomment the post-processing call that matches the task prompt:
# draw_polygons(original_image, parsed_answer['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
# plot_bbox(original_image, parsed_answer[prompt.split(">")[0].strip() + ">"])
# draw_ocr_bboxes(original_image, parsed_answer["<OCR_WITH_REGION>"], scale=1)
# Release RKNNLite instances
rknn_vision_encoder.release()
rknn_encoder.release()
rknn_decoder_prefill.release()