import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []
        self.model = None
        self.clip = None
        # Linear projection layer to align CLIP features with text embeddings
        self.projection = None
    def ensure_models_loaded(self):
        if self.model is None:
            from transformers import BitsAndBytesConfig
            # Load the language model in 8-bit to reduce GPU memory usage
            # (compute-dtype / double-quant options only apply to 4-bit loading)
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model")
        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            # Projection layer mapping CLIP image features into the language model's
            # embedding space. Note: nn.Linear is randomly initialized here; for
            # meaningful image grounding, trained projection weights must be loaded.
            embed_dim = self.model.config.hidden_size   # e.g. 2048 for Phi-1.5
            clip_dim = self.clip.config.projection_dim  # 512 for CLIP ViT-B/32
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
    def process_image(self, image):
        try:
            self.ensure_models_loaded()
            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
                # Project image features to text embedding space
                projected_features = self.projection(image_features)
            logging.info("Successfully processed image through CLIP")
            return projected_features
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None
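    # Assumption: the app runs on a Hugging Face ZeroGPU Space (suggested by the
    # `spaces` import above); @spaces.GPU requests a GPU for the duration of the call
    # and should have no effect on other hardware. Remove it if that assumption is wrong.
    @spaces.GPU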
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()
            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None
                if not has_image:
                    message = "Note: Image processing is not available - continuing with text only.\n" + message
                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                if has_image:
                    # Convert input_ids to embeddings
                    embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
                    # Match the text embeddings' dtype/device so concatenation works with the quantized model
                    image_features = image_features.to(device=embeddings.device, dtype=embeddings.dtype)
                    # Prepend the projected image feature as a single extra "token"
                    image_features_expanded = image_features.unsqueeze(1)  # Shape: [batch, 1, embed_dim]
                    combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
                    inputs["inputs_embeds"] = combined_embeddings
                    # Extend the attention mask to cover the extra image position
                    inputs["attention_mask"] = torch.cat(
                        [torch.ones(inputs["attention_mask"].shape[0], 1,
                                    dtype=inputs["attention_mask"].dtype, device=self.device),
                         inputs["attention_mask"]],
                        dim=1
                    )
                    # Remove input_ids since we're using inputs_embeds
                    del inputs["input_ids"]
            else:
                prompt = f"human: {message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    min_length=20,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None
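
# Example (sketch) of using the model class directly, outside the Gradio UI;
# "photo.jpg" is a hypothetical local image path, not a file from this repo.
#   model = LLaVAPhiModel()
#   print(model.generate_response("What is shown in this image?", image="photo.jpg"))
#   print(model.generate_response("Summarize our conversation so far."))
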
def create_demo():
    model = LLaVAPhiModel()
    # Rest of your Gradio setup remains the same
    # ... (omitted for brevity)
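    # Since the original layout is omitted above, the block below is only a minimal
    # placeholder sketch (an assumption): a textbox plus optional image input wired to
    # generate_response, and a button that clears the conversation history.
    with gr.Blocks(title="LLaVA-Phi") as demo:
        gr.Markdown("# LLaVA-Phi: Chat about images")
        with gr.Row():
            image_input = gr.Image(type="pil", label="Image (optional)")
            with gr.Column():
                chat_input = gr.Textbox(label="Message", placeholder="Ask about the image or just chat...")
                chat_output = gr.Textbox(label="Response")
                send_btn = gr.Button("Send")
                clear_btn = gr.Button("Clear history")
        send_btn.click(fn=model.generate_response, inputs=[chat_input, image_input], outputs=chat_output)
        clear_btn.click(fn=model.clear_history, inputs=None, outputs=chat_output)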
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)