import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Phi-2 base model on CPU; the multimodal projection layers and a
# LoRA adapter are attached further below.
model_name = "microsoft/phi-2"
phi2_model_pretrained = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='cpu'
)
phi2_model_pretrained.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token  # Phi-2 has no dedicated pad token
tokenizer.bos_token = tokenizer.eos_token
def convert_text_input_embeds(text):
    """Tokenize text and map it to Phi-2 input embeddings of shape [1, seq_len, 2560]."""
    in_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    in_embeds = phi2_model_pretrained.get_input_embeddings()(in_tokens.input_ids)
    return in_embeds
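# Minimal sanity check (a sketch, not part of the app): the embedding width should
# match Phi-2's hidden size of 2560. Uncomment to verify locally.
# example_embeds = convert_text_input_embeds("Hello world")
# print(example_embeds.shape)  # expected: torch.Size([1, seq_len, 2560])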
import whisperx

# WhisperX handles audio context: the audio is transcribed to text, and the text is
# then embedded with the same Phi-2 embedding table as any other text input.
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')

def convert_audio_file_text_embeds(fname):
    """Transcribe an audio file and return the concatenated segment text."""
    result = whisper_model.transcribe(fname)
    full_text = ''.join(seg['text'] for seg in result['segments'])
    return full_text.strip()
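# Example usage (sketch): transcribe a local file, then embed the transcript.
# The path below is hypothetical; substitute any audio file.
# transcript = convert_audio_file_text_embeds("sample.wav")
# audio_embeds = convert_text_input_embeds(transcript)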
from transformers import CLIPVisionModel, CLIPImageProcessor

# CLIP ViT-B/32 is the vision tower; its patch features are later projected into
# Phi-2's embedding space.
vision_tower_name = 'openai/clip-vit-base-patch32'  # patch features: torch.Size([1, 49, 768])
image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
vision_tower = CLIPVisionModel.from_pretrained(vision_tower_name)

def feature_select(image_forward_outs):
    """Take the last hidden layer and drop the CLS token, keeping the 49 patch tokens."""
    image_features = image_forward_outs.hidden_states[-1]  # last layer
    image_features = image_features[:, 1:, :]              # drop CLS token
    return image_features  # [1, 49, 768]
def image_CLIP_embed(image):
    """Run an image through the frozen CLIP vision tower and return its patch features."""
    _ = vision_tower.requires_grad_(False)  # vision tower stays frozen at inference
    image = image_processor(images=image, return_tensors="pt")
    image_forward_out = vision_tower(image['pixel_values'].to(device=vision_tower.device),
                                     output_hidden_states=True)
    image_feature = feature_select(image_forward_out)
    return image_feature
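# Quick shape check (sketch; assumes Pillow is installed and an image file exists):
# from PIL import Image
# img = Image.open("sample.jpg")       # hypothetical path
# print(image_CLIP_embed(img).shape)   # expected: torch.Size([1, 49, 768])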
import torch.nn as nn
import torch.nn.functional as F

class CustomGELU(nn.Module):
    """GELU variant that clones its input first; not used by SimpleResBlock below, which uses nn.GELU."""
    def forward(self, x):
        return F.gelu(x.clone())
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a two-layer MLP with a residual connection (shape-preserving)."""
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
class CLIPembed_projection(nn.Module):
    """Project CLIP patch features (768-d) into Phi-2's embedding space (2560-d)."""
    def __init__(self, input_dim_CLIP=768, input_dim_phi2=2560):
        super().__init__()
        self.input_dim_CLIP = input_dim_CLIP
        self.input_dim_phi2 = input_dim_phi2
        self.projection_img = nn.Linear(self.input_dim_CLIP, self.input_dim_phi2, bias=False)
        self.resblock = SimpleResBlock(self.input_dim_phi2)

    def forward(self, x):
        x = self.projection_img(x)
        x = self.resblock(x)
        return x
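# Shape sketch (uncomment to check; uses random features in place of a real image):
# dummy = torch.randn(1, 49, 768)
# print(CLIPembed_projection()(dummy).shape)  # expected: torch.Size([1, 49, 2560])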
# Instantiate the projection and load the stage-2 weights for the linear projection
# and the residual block.
Image_projection_layer = CLIPembed_projection()
location_projection_img_p1 = './weights/stage_2/run2_projection_img.pth'
location_projection_img_p2 = './weights/stage_2/run2_resblock.pth'

# load projection_img, resblock from stage 2
Image_projection_layer.projection_img.load_state_dict(torch.load(location_projection_img_p1, map_location='cpu'))
Image_projection_layer.resblock.load_state_dict(torch.load(location_projection_img_p2, map_location='cpu'))
def img_input_embed(image):
    """CLIP-encode an image and project it into Phi-2's embedding space."""
    clip_embed = image_CLIP_embed(image)
    post_projection = Image_projection_layer(clip_embed)
    return post_projection
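# Example (sketch): the full image path yields 49 Phi-2-sized image tokens.
# img_tokens = img_input_embed(img)  # img as in the earlier PIL sketch
# print(img_tokens.shape)            # expected: torch.Size([1, 49, 2560])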
device = 'cpu'

# Load the QLoRA adapter for Phi-2 from the Hugging Face Hub.
user = "LN1996"  # put your user name here
model_name = "peft-qlora-run2"
model_id = f"{user}/{model_name}"

import peft
phi2_model_pretrained_peft = peft.PeftModel.from_pretrained(phi2_model_pretrained, model_id)
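# Optional check (sketch): report adapter vs. total parameters after attaching the adapter.
# phi2_model_pretrained_peft.print_trainable_parameters()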
def input_multimodel(image=None, audio=None, text=None, query=None):
    """Build a combined embedding sequence from the provided context and generate an answer."""
    # Gradio passes empty strings for cleared text boxes; treat those as missing.
    if text is not None and len(text) == 0:
        text = None
    if query is not None and len(query) == 0:
        query = None
    if query is None:
        return 'Please ask a query'
    if image is None and audio is None and text is None:
        return 'Please provide context in form of image, audio, text'
    # Start the prompt with a "Context: " prefix, append embeddings for each provided
    # modality, then add the question and an "Answer: " cue.
    bos = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds_stage_2 = phi2_model_pretrained_peft.get_input_embeddings()(bos.input_ids)

    if image is not None:
        image_embeds = img_input_embed(image)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, image_embeds), dim=1)

    if audio is not None:
        audio_transcribed = convert_audio_file_text_embeds(audio)
        audio_embeds = convert_text_input_embeds(audio_transcribed)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, audio_embeds), dim=1)

    if text is not None:
        text_embeds = convert_text_input_embeds(text)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, text_embeds), dim=1)

    qus = tokenizer(" Question: " + query, return_tensors="pt",
                    return_attention_mask=False)
    qus_embeds = phi2_model_pretrained_peft.get_input_embeddings()(qus.input_ids)
    input_embeds_stage_2 = torch.cat((input_embeds_stage_2, qus_embeds), dim=1)

    ans = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    ans_embeds = phi2_model_pretrained_peft.get_input_embeddings()(ans.input_ids)
    input_embeds_stage_2 = torch.cat((input_embeds_stage_2, ans_embeds), dim=1)
    # Generate from the assembled embeddings and keep only the text before the first EOS.
    result = phi2_model_pretrained_peft.generate(inputs_embeds=input_embeds_stage_2,
                                                 bos_token_id=tokenizer.bos_token_id)
    process = tokenizer.batch_decode(result)[0]
    process = process.split(tokenizer.eos_token)
    if process[0] == '':
        return process[1]
    else:
        return process[0]
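# Example call (sketch): text-only context, no image or audio.
# print(input_multimodel(image=None, audio=None,
#                        text="The Eiffel Tower is in Paris.",
#                        query="Where is the Eiffel Tower?"))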
import gradio as gr

title = "Multi-Modal Phi-2"
description = "A simple Gradio interface for a custom multi-modal (image, text, audio) version of Microsoft Phi-2"

demo = gr.Interface(input_multimodel,
                    inputs=[gr.Image(label="Input context Image"),
                            gr.Audio(label="Input context Audio", sources=["microphone", "upload"], type="filepath"),
                            gr.Textbox(label="Input context Text"),
                            gr.Textbox(label="Input Query"),
                            ],
                    outputs=[
                        gr.Textbox(label='Answer'),
                    ],
                    title=title,
                    description=description,
                    )
demo.launch()