# Lava_phi_model / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np
# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
def __init__(self, model_id="sagar007/Lava_phi"):
self.device = "cuda"
self.model_id = model_id
logging.info("Initializing LLaVA-Phi model...")
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.history = []
self.model = None
self.clip = None
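        # Heavy models are loaded lazily inside the @spaces.GPU-decorated methods below,
        # so the Space can start up (and the UI can render) before a GPU is attached.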
# Add a linear projection layer to align CLIP features with text embeddings
self.projection = None
@spaces.GPU
def ensure_models_loaded(self):
if self.model is None:
from transformers import BitsAndBytesConfig
            # 8-bit weight quantization via bitsandbytes (LLM.int8()). The compute-dtype
            # and double-quant knobs only exist for 4-bit loading (bnb_4bit_*), so plain
            # 8-bit loading is used here.
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True
            )
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id,
quantization_config=quantization_config,
device_map="auto",
                torch_dtype=torch.float16,
trust_remote_code=True
)
self.model.config.pad_token_id = self.tokenizer.eos_token_id
logging.info("Successfully loaded main model")
if self.clip is None:
self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
logging.info("Successfully loaded CLIP model")
# Initialize projection layer (CLIP features: 512-dim, model embedding size: e.g., 2048 for Phi)
embed_dim = self.model.config.hidden_size # e.g., 2048 for Phi-1.5
clip_dim = self.clip.config.projection_dim # 512 for CLIP
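            # NOTE: the Linear layer below starts from randomly initialized weights;
            # nothing in this file loads trained projection weights, so image features
            # are only meaningful if such weights are restored elsewhere.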
self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
@spaces.GPU
def process_image(self, image):
try:
self.ensure_models_loaded()
if self.clip is None or self.processor is None:
logging.warning("CLIP model or processor not available")
return None
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode != 'RGB':
image = image.convert('RGB')
with torch.no_grad():
image_inputs = self.processor(images=image, return_tensors="pt")
image_features = self.clip.get_image_features(
pixel_values=image_inputs.pixel_values.to(self.device)
)
# Project image features to text embedding space
projected_features = self.projection(image_features)
logging.info("Successfully processed image through CLIP")
return projected_features
except Exception as e:
logging.error(f"Error in process_image: {str(e)}")
return None
@spaces.GPU(duration=120)
def generate_response(self, message, image=None):
try:
self.ensure_models_loaded()
if image is not None:
image_features = self.process_image(image)
has_image = image_features is not None
if not has_image:
message = "Note: Image processing is not available - continuing with text only.\n" + message
prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
context = ""
for turn in self.history[-5:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
if has_image:
# Convert input_ids to embeddings
embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
                # Concatenate image features with text embeddings, matching the embedding
                # dtype so torch.cat and the quantized model see a consistent precision
                image_features_expanded = image_features.unsqueeze(1).to(embeddings.dtype)  # [batch, 1, embed_dim]
                combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
inputs["inputs_embeds"] = combined_embeddings
                # Update attention mask to account for the extra image token,
                # keeping the mask's dtype and device consistent
                inputs["attention_mask"] = torch.cat(
                    [torch.ones(
                        inputs["attention_mask"].shape[0], 1,
                        dtype=inputs["attention_mask"].dtype,
                        device=self.device
                    ),
                    inputs["attention_mask"]],
                    dim=1
                )
# Remove input_ids since we're using inputs_embeds
del inputs["input_ids"]
else:
prompt = f"human: {message}\ngpt:"
context = ""
for turn in self.history[-5:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
                    min_new_tokens=20,
temperature=0.3,
do_sample=True,
top_p=0.92,
top_k=50,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
use_cache=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
if "gpt:" in response:
response = response.split("gpt:")[-1].strip()
if "human:" in response:
response = response.split("human:")[0].strip()
if "<image>" in response:
response = response.replace("<image>", "").strip()
self.history.append((message, response))
return response
except Exception as e:
logging.error(f"Error generating response: {str(e)}")
return f"Error: {str(e)}"
def clear_history(self):
self.history = []
return None
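
# Illustrative usage (assumption, not part of the original app): the class can also be
# driven without the UI, e.g. from a notebook or script:
#   model = LLaVAPhiModel()
#   print(model.generate_response("Describe this image.", image="example.jpg"))
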
def create_demo():
model = LLaVAPhiModel()
    # The full Gradio UI from the original Space was omitted from this listing for brevity.
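    # Minimal sketch of a chat UI so create_demo() still returns something runnable.
    # Assumption: the original Space used a similar Blocks layout; the component names
    # and wiring below are illustrative, not the author's exact interface.
    with gr.Blocks(title="LLaVA-Phi") as demo:
        gr.Markdown("# LLaVA-Phi: chat about images with a Phi-based model")
        chatbot = gr.Chatbot(height=400)
        with gr.Row():
            image_input = gr.Image(type="pil", label="Image (optional)")
            with gr.Column():
                text_input = gr.Textbox(label="Message", placeholder="Ask about the image, or just chat...")
                with gr.Row():
                    send_btn = gr.Button("Send")
                    clear_btn = gr.Button("Clear")

        def on_send(message, image, chat_history):
            # Run the model and append the new turn to the visible chat history.
            response = model.generate_response(message, image)
            return "", chat_history + [(message, response)]

        def on_clear():
            # Reset both the model-side history and the UI components.
            model.clear_history()
            return None, []

        send_btn.click(on_send, [text_input, image_input, chatbot], [text_input, chatbot])
        text_input.submit(on_send, [text_input, image_input, chatbot], [text_input, chatbot])
        clear_btn.click(on_clear, None, [image_input, chatbot])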
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)