# Hugging Face file-viewer residue captured together with the script (not code):
#   happyme531's picture · Upload 21 files · 95dfa6c verified · raw · history blame · 5.83 kB
import onnxruntime as ort
from transformers import AutoProcessor
from PIL import Image
import numpy as np
# Switch the working directory to the folder holding this script, so the
# relative model / image paths used further down resolve from there.
import os
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
def _cpu_session(model_path):
    """Open an ONNX Runtime session for *model_path* on the CPU provider."""
    return ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])


# embeddings (run on pixel_values and on input_ids respectively)
vision_encoder = _cpu_session("vision_encoder.onnx")
text_embed = _cpu_session("embed_tokens.onnx")

# encoder
encoder = _cpu_session("encoder_model.onnx")

# decoder: one graph for the first (prefill) step, a merged q4 graph for the
# cached decode steps
decoder_prefill = _cpu_session("decoder_model.onnx")
decoder_decode = _cpu_session("decoder_model_merged_q4.onnx")
# 1. prepare inputs
# NOTE: the processor is loaded from a machine-specific absolute path, and
# trust_remote_code=True lets the checkpoint's own Python code run.
processor = AutoProcessor.from_pretrained(
    "/home/zt/rk3588-nn/expr/Florence-2-base-ft",
    trust_remote_code=True,
)

# 2. prepare image
# The image is resized by hand to 512x512; the processor call below is told
# not to resize it again (do_resize=False).
image = Image.open("./lena.png").resize((512, 512))

# 3. prepare text
prompt = "<MORE_DETAILED_CAPTION>"
inputs = processor(
    images=image,
    text=prompt,
    do_resize=False,
    return_tensors="np",
)
# Show the name and shape of every array the processor produced.
for name, array in inputs.items():
    print(name, array.shape)
# 4. run vision encoder
vision_outputs = vision_encoder.run(None, {"pixel_values": inputs["pixel_values"]})
for tensor in vision_outputs:
    print(tensor.shape)
# Only the first output is used from here on; it is also dumped to disk.
image_features = vision_outputs[0]
np.save("image_features.npy", image_features)

# 5. run text embed
embed_outputs = text_embed.run(None, {"input_ids": inputs["input_ids"]})
for tensor in embed_outputs:
    print(tensor.shape)
inputs_embeds = embed_outputs[0]
# 6. concat image features and text embed
# All axes except the last one of image_features are unpacked as
# (batch, image-token count), so image_features must be 3-D at this point.
batch_size, image_token_length = image_features.shape[:-1]
task_prefix_embeds = inputs_embeds

# Both masks are all ones (every position is marked valid). They are built as
# int64 straight away because that is the dtype the masks are cast to wherever
# they are fed to a session; the later .astype(np.int64) calls stay valid.
# A former "mask is 3-D" branch was dropped: np.ones with a 2-tuple shape is
# always 2-D, so that branch could never run.
image_attention_mask = np.ones(
    (batch_size, image_token_length), dtype=np.int64
)
task_prefix_attention_mask = np.ones(
    (batch_size, task_prefix_embeds.shape[1]), dtype=np.int64
)

# Image positions come first, then the task-prompt positions (axis 1).
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate(
    [image_attention_mask, task_prefix_attention_mask], axis=1
)
# 6. run encoder
# The encoder and the decoder prefill both take the mask as int64.
attention_mask_i64 = attention_mask.astype(np.int64)
encoder_feed = {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask_i64,
}
encoder_out = encoder.run(None, encoder_feed)
for tensor in encoder_out:
    print(tensor.shape)
encoder_hidden_states = encoder_out[0]

# 7. run decoder prefill stage
# The decoder's first input is the embedding at the last position of the
# concatenated encoder input (presumably the prompt's trailing </s> token,
# acting as the decoder start token -- TODO confirm).
prefill_feed = {
    "inputs_embeds": inputs_embeds[:, -1:],
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": attention_mask_i64,
}
decoder_outs = decoder_prefill.run(None, prefill_feed)
for tensor in decoder_outs:
    print(tensor.shape)
# Keep every prefill output after the first one: the decode loop keeps
# reading its encoder key/value entries from this list on every step.
encoder_kv = decoder_outs[1:]
# 8. run decoder decode stage (autoregressive, greedy)
EOS_TOKEN_ID = 2  # </s>
generated_tokens = []
max_new_tokens = 32

# Loop-invariant decoder inputs, built once instead of on every step.
encoder_attention_mask = attention_mask.astype(np.int64)
use_cache_branch = np.array([True], dtype=np.bool_)

# The outputs after the logits are consumed four per layer, in this order;
# output index 4 * layer + k feeds "past_key_values.<layer>.<slot k>".
kv_slots = ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
num_layers = len(encoder_kv) // len(kv_slots)

while len(generated_tokens) < max_new_tokens:
    # Outputs of the previous step: logits first, then the key/value tensors.
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]

    # Greedy choice: argmax over the logits of the last position. Stored as a
    # plain int so the printed token list does not depend on numpy's repr.
    next_token = int(np.argmax(logits[:, -1, :], axis=-1)[0])
    print("next_token: ", next_token)
    generated_tokens.append(next_token)

    # Stop on end-of-sequence, or once the token budget is used up -- in the
    # latter case another decoder run would only produce unused outputs.
    if next_token == EOS_TOKEN_ID or len(generated_tokens) >= max_new_tokens:
        break

    # Embed the token just chosen; it is the decoder input for the next step.
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]

    feed = {
        "use_cache_branch": use_cache_branch,
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
    }
    # Decoder key/values come from the most recent step; encoder key/values
    # are always taken from the prefill outputs kept in encoder_kv.
    for layer in range(num_layers):
        base = layer * len(kv_slots)
        for offset, slot in enumerate(kv_slots):
            cache = encoder_kv if slot.startswith("encoder") else decoder_kv
            feed[f"past_key_values.{layer}.{slot}"] = cache[base + offset]

    # Run one cached decode step.
    decoder_outs = decoder_decode.run(None, feed)
    for output in decoder_outs:
        print(output.shape)
    # Debug aid (disabled):
    # print("generated_token: ", processor.decode(next_token, skip_special_tokens=False))
# Dropping the first generated token is left disabled:
# generated_tokens = generated_tokens[1:]

# Convert the generated token ids back into text (special tokens are kept).
print("generated_tokens: ", generated_tokens)
decoded_batch = processor.batch_decode([generated_tokens], skip_special_tokens=False)
generated_text = decoded_batch[0]
print("Generated Text:", generated_text)

# Post-process the raw text for the task named by the prompt, passing the
# (width, height) of the resized input image.
resized_image_size = (image.width, image.height)
parsed_answer = processor.post_process_generation(
    generated_text,
    task=prompt,
    image_size=resized_image_size,
)
print("Parsed Answer:", parsed_answer)