# UniTok: inference_i2t.py (image-to-text inference)
import os
import argparse
from threading import Thread

import torch
from PIL import Image
from torchvision import transforms
from transformers import TextIteratorStreamer

from conversation import conv_templates
from unitok.config import Args
from unitok.model import UniTok
from model.builder import load_pretrained_model
from mm_utils import tokenizer_image_token, get_model_name_from_path

# Sentinel token id marking the image slot in input_ids (LLaVA convention).
IMAGE_TOKEN_INDEX = -200


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centering it on the given background color."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
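# Example (illustrative): a 200x100 RGB input yields a 200x200 canvas with the
# original image pasted at vertical offset (200 - 100) // 2 = 50.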


def main(args):
    # Rebuild the UniTok visual tokenizer from its checkpoint.
    ckpt = torch.load(args.unitok_path, map_location='cpu')
    vae_cfg = Args()
    vae_cfg.load_state_dict(ckpt['args'])
    vq_model = UniTok(vae_cfg)
    vq_model.load_state_dict(ckpt['trainer']['unitok'])
    vq_model.to('cuda')
    vq_model.eval()

    # Load the multimodal LLM, its tokenizer, and the image processor.
    model_path = os.path.expanduser(args.mllm_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, vqllm, image_processor, context_len = load_pretrained_model(
        model_path, model_name, load_8bit=args.load_8bit)
    # Prepend the image placeholder to the user prompt and wrap the exchange in
    # the llava_v1 conversation template.
    qs = '<boi><image><eoi>' + '\n' + args.prompt
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Resize to the tokenizer's 256x256 input resolution and normalize pixel
    # values to [-1, 1].
    crop_size = 256
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    print(prompt)

    # Pad the image to a square (the background color approximates the ImageNet
    # channel means) and encode it into discrete visual token indices.
    image = Image.open(args.image_path).convert('RGB')
    pad_image = expand2square(image, (122, 116, 104))
    img = transform(pad_image).unsqueeze(0).to('cuda')
    with torch.no_grad():
        vq_code = vq_model.img_to_idx(img)
    image_codes = vq_code.unsqueeze(0)
    # Tokenize the prompt; tokenizer_image_token splices IMAGE_TOKEN_INDEX into
    # the token ids at the <image> placeholder.
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
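    # Illustrative (assumed) result: a 1-D LongTensor of text token ids with a
    # single -200 entry at the <image> slot, e.g. tensor([..., -200, ...]),
    # which generate_mllm presumably replaces with the visual codes.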
    inputs = {
        "inputs": input_ids.unsqueeze(0).to("cuda:0"),
        "images": image_codes.to("cuda:0"),
        "max_new_tokens": 1024,
        "bos_token_id": tokenizer.bos_token_id,  # begin-of-sequence token
        "eos_token_id": tokenizer.eos_token_id,  # end-of-sequence token
        "pad_token_id": tokenizer.pad_token_id,  # padding token
    }
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    # Run generation in a separate thread so the streamer can be consumed
    # incrementally without blocking.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=vqllm.generate_mllm, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
    print(generated_text)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UniTok image-to-text inference')
    parser.add_argument('--unitok_path', type=str,
                        default=r'D:\projects\liquid_app\UniTok\UniTok_weights\unitok_tokenizer\unitok_tokenizer.pth',
                        help='path to the UniTok tokenizer checkpoint')
    parser.add_argument('--mllm_path', type=str, default=r'C:\debug_ckpts\unitok_mllm',
                        help='path to the multimodal LLM checkpoint')
    parser.add_argument('--prompt', type=str, required=True, help='input text prompt')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--load_8bit', action='store_true', help='load the MLLM in 8-bit to save memory')
    args = parser.parse_args()
    main(args)
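
# Example invocation (paths are illustrative placeholders):
#   python inference_i2t.py \
#       --unitok_path /path/to/unitok_tokenizer.pth \
#       --mllm_path /path/to/unitok_mllm \
#       --prompt "Describe this image." \
#       --image_path ./example.jpg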