#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from PIL import Image

minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
sys.path.insert(0, minigpt4_path)
from minigpt4.models.blip2 import Blip2Base
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
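
# Glue between MiniGPT-4's vision stack (EVA ViT + Q-Former) and a llama.cpp model
# driven through the embd_input MyModel wrapper: encode_img projects image features
# into the LLaMA embedding space, and chat_with_image feeds them to the model as
# float embeddings in between text prompt segments.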

class MiniGPT4(Blip2Base):
    """
    MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
    """
    def __init__(
        self,
        args,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp32",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0
    ):
        super().__init__()
        self.img_size = img_size
        self.low_resource = low_resource
        self.preprocessor = Blip2ImageEvalProcessor(img_size)

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        print('Loading Q-Former Done')

        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, 5120  # self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.model = MyModel(["main", *args])
        # system prompt
        self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
                               "You will be able to see the image once I provide it to you. Please answer my questions."
                               "###")
    def encode_img(self, image):
        image = self.preprocessor(image)
        image = image.unsqueeze(0)
        device = image.device
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
            # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama
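
    # load_projection: load only the llama_proj weights from a MiniGPT-4 checkpoint
    # (e.g. pretrained_minigpt4.pth); the LLM weights themselves are loaded
    # separately by llama.cpp, so the rest of the checkpoint is not needed here.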
    def load_projection(self, path):
        state = torch.load(path)["model"]
        self.llama_proj.load_state_dict({
            "weight": state["llama_proj.weight"],
            "bias": state["llama_proj.bias"]})
    def chat(self, question):
        self.model.eval_string("Human: ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")

    def chat_with_image(self, image, question):
        with torch.no_grad():
            embd_image = self.encode_img(image)
        embd_image = embd_image.cpu().numpy()[0]
        self.model.eval_string("Human: <Img>")
        self.model.eval_float(embd_image.T)
        self.model.eval_string("</Img> ")
        self.model.eval_string(question)
        self.model.eval_string("\n### Assistant:")
        return self.model.generate_with_print(end="###")
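
# Example usage below assumes a ggml Vicuna-13B model under ./models/, the
# MiniGPT-4 repo checked out next to this script, and the projection checkpoint
# pretrained_minigpt4.pth in the same directory as this file.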

if __name__ == "__main__":
    a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
    a.load_projection(os.path.join(
        os.path.dirname(__file__),
        "pretrained_minigpt4.pth"))
    response = a.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")
    a.chat("what is the color of it?")