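"""
UniTok MLLM inference script.

Encodes an input image into discrete codes with the UniTok visual tokenizer,
splices the codes into a LLaVA-style conversation prompt, and streams the
multimodal LLM's text response to stdout.
"""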
import argparse
import os
from threading import Thread

import torch
from PIL import Image
from torchvision import transforms
from transformers import TextIteratorStreamer

from conversation import conv_templates
from unitok.config import Args
from unitok.model import UniTok
from model.builder import load_pretrained_model
from mm_utils import tokenizer_image_token, get_model_name_from_path

# Placeholder id spliced into input_ids wherever an <image> token appears
# (LLaVA convention).
IMAGE_TOKEN_INDEX = -200

def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centering it over `background_color`."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
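
# A minimal sanity check for expand2square (illustrative only; the fill color
# below is arbitrary, not the mean-pixel value used in main()):
#
#   img = Image.new('RGB', (640, 480), (255, 0, 0))
#   squared = expand2square(img, (0, 0, 0))
#   assert squared.size == (640, 640)
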
def main(args):
    # Load the UniTok visual tokenizer from its training checkpoint.
    ckpt = torch.load(args.unitok_path, map_location='cpu')
    vae_cfg = Args()
    vae_cfg.load_state_dict(ckpt['args'])
    vq_model = UniTok(vae_cfg)
    vq_model.load_state_dict(ckpt['trainer']['unitok'])
    vq_model.to('cuda')
    vq_model.eval()

    # Load the multimodal LLM and its text tokenizer.
    model_path = os.path.expanduser(args.mllm_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, vqllm, image_processor, context_len = load_pretrained_model(
        model_path, model_name, load_8bit=args.load_8bit)
    # Build the conversation prompt: image placeholder tokens, then the user's text.
    qs = '<boi><image><eoi>' + '\n' + args.prompt
    conv = conv_templates['llava_v1'].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    print(prompt)

    # Resize to the tokenizer's input resolution and normalize to [-1, 1].
    crop_size = 256
    transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # Pad the image to a square, then encode it into discrete UniTok indices.
    image = Image.open(args.image_path).convert('RGB')
    pad_image = expand2square(image, (122, 116, 104))
    img = transform(pad_image).unsqueeze(0).to('cuda')
    with torch.no_grad():
        vq_code = vq_model.img_to_idx(img)
    image_codes = vq_code.unsqueeze(0)
    # Tokenize the prompt, replacing the <image> placeholder with IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    inputs = {
        "inputs": input_ids.unsqueeze(0).to("cuda:0"),
        "images": image_codes.to("cuda:0"),
        "max_new_tokens": 1024,
        "bos_token_id": tokenizer.bos_token_id,  # begin-of-sequence token
        "eos_token_id": tokenizer.eos_token_id,  # end-of-sequence token
        "pad_token_id": tokenizer.pad_token_id,  # padding token
    }
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    # Run generation in a separate thread so the generated text can be fetched
    # from the streamer in a non-blocking way.
    generation_kwargs = dict(inputs, streamer=streamer)
    thread = Thread(target=vqllm.generate_mllm, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        print(new_text, end='', flush=True)  # emit tokens as they arrive
    print()
    thread.join()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UniTok MLLM image-to-text inference.')
    parser.add_argument('--unitok_path', type=str, required=False,
                        default=r'D:\projects\liquid_app\UniTok\UniTok_weights\unitok_tokenizer\unitok_tokenizer.pth',
                        help='path to the UniTok tokenizer checkpoint (.pth)')
    parser.add_argument('--mllm_path', type=str, required=False,
                        default=r'C:\debug_ckpts\unitok_mllm',
                        help='path to the multimodal LLM checkpoint directory')
    parser.add_argument('--prompt', type=str, required=True, help='input text prompt')
    parser.add_argument('--image_path', type=str, required=True, help='input image path')
    parser.add_argument('--load_8bit', action='store_true', default=False,
                        help='load the MLLM in 8-bit to save memory')
    args = parser.parse_args()
    main(args)
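
# Example invocation (a sketch; the script filename and checkpoint paths are
# placeholders, substitute your own):
#
#   python inference.py \
#       --unitok_path /path/to/unitok_tokenizer.pth \
#       --mllm_path /path/to/unitok_mllm \
#       --prompt "Describe this image in detail." \
#       --image_path ./example.jpg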