import logging

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

# Setup logging
logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    """Chat wrapper pairing a causal LM with a CLIP image encoder.

    The tokenizer and CLIP processor are created eagerly. The LM, the CLIP
    model and the image->text projection layer are loaded lazily by
    ``ensure_models_loaded`` on first use.
    """

    # Number of past (user, response) turns replayed as prompt context.
    HISTORY_TURNS = 5
    # Maximum prompt length in tokens passed to the tokenizer.
    MAX_PROMPT_TOKENS = 1024

    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Truncate from the left: when history pushes the prompt past
        # MAX_PROMPT_TOKENS, drop the oldest context rather than the current
        # message and the trailing "gpt:" cue (default side is "right").
        self.tokenizer.truncation_side = "left"

        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []  # list of (user_message, response) tuples
        self.model = None
        self.clip = None
        # Linear layer aligning CLIP image features with the LM's text embeddings.
        self.projection = None

    @spaces.GPU
    def ensure_models_loaded(self):
        """Load the LM, CLIP model and projection layer if not yet loaded.

        Safe to call repeatedly: each component is created only once.
        """
        if self.model is None:
            from transformers import BitsAndBytesConfig

            # Only ``load_in_8bit`` applies to 8-bit loading. The previous
            # ``bnb_8bit_compute_dtype`` / ``bnb_8bit_use_double_quant`` keys
            # are not BitsAndBytesConfig parameters and were ignored.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model")

        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")

        if self.projection is None:
            # NOTE(review): this layer is randomly initialised and never
            # trained or loaded from a checkpoint, so projected image
            # features carry no learned alignment with the LM embeddings.
            # Load trained projector weights here if they exist.
            embed_dim = self.model.config.hidden_size
            clip_dim = self.clip.config.projection_dim
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)

    @spaces.GPU
    def process_image(self, image):
        """Encode an image with CLIP and project it into the LM embedding space.

        Args:
            image: A file path, a numpy array, or a PIL image.

        Returns:
            The projected features (one row per image, ``hidden_size``
            columns, in the projection layer's dtype), or ``None`` when
            encoding fails. Failures are logged, not raised.
        """
        try:
            self.ensure_models_loaded()
            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None

            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
                # Project image features to text embedding space
                projected_features = self.projection(image_features)

            logging.info("Successfully processed image through CLIP")
            return projected_features
        except Exception as e:
            # Best-effort: generate_response falls back to text-only on None.
            logging.exception("Error in process_image: %s", e)
            return None

    def _build_prompt(self, current_turn):
        """Return the recent history followed by the current human turn."""
        context = "".join(
            f"human: {user}\ngpt: {reply}\n"
            for user, reply in self.history[-self.HISTORY_TURNS:]
        )
        return f"{context}human: {current_turn}\ngpt:"

    def _tokenize(self, text):
        """Tokenize ``text`` and move every tensor to ``self.device``."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS,
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def _prepend_image(self, inputs, image_features):
        """Swap ``input_ids`` for embeddings prefixed by one image "token"."""
        text_embeds = self.model.get_input_embeddings()(inputs["input_ids"])
        # The projection runs in float32 while the LM embeddings are loaded
        # in bfloat16; cast so the concatenated inputs match the model dtype.
        image_embeds = image_features.to(
            device=text_embeds.device, dtype=text_embeds.dtype
        ).unsqueeze(1)  # [batch, 1, embed_dim]

        mask = inputs["attention_mask"]
        image_mask = torch.ones(
            (mask.shape[0], 1), dtype=mask.dtype, device=mask.device
        )

        # Keep any other tokenizer outputs, but drop input_ids since
        # generate() now receives inputs_embeds instead.
        prepared = {k: v for k, v in inputs.items() if k != "input_ids"}
        prepared["inputs_embeds"] = torch.cat([image_embeds, text_embeds], dim=1)
        prepared["attention_mask"] = torch.cat([image_mask, mask], dim=1)
        return prepared

    def _decode_response(self, inputs, output_ids):
        """Decode the newly generated text and trim invented follow-up turns."""
        generated = output_ids[0]
        if "input_ids" in inputs:
            # With input_ids, generate() echoes the prompt first; keep only
            # the new tokens so earlier "gpt:" markers in the history cannot
            # confuse the extraction below.
            generated = generated[inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(generated, skip_special_tokens=True)
        # Cut at the first turn the model starts writing on the user's behalf.
        response = response.split("human:")[0]
        if "gpt:" in response:
            response = response.split("gpt:")[-1]
        return response.strip()

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        """Generate a reply to ``message``, optionally conditioned on ``image``.

        The (possibly annotated) message and the reply are appended to
        ``self.history``. Errors are logged and returned as ``"Error: ..."``
        strings rather than raised, so the UI always receives text.
        """
        try:
            self.ensure_models_loaded()

            with torch.no_grad():
                if image is not None:
                    image_features = self.process_image(image)
                    has_image = image_features is not None
                    if not has_image:
                        message = "Note: Image processing is not available - continuing with text only.\n" + message
                    # NOTE(review): the image turn places the message on its
                    # own line after "human: "; the empty slot looks like a
                    # stripped "<image>" placeholder. Confirm the intended
                    # prompt format.
                    inputs = self._tokenize(self._build_prompt(f"\n{message}"))
                    if has_image:
                        inputs = self._prepend_image(inputs, image_features)
                else:
                    inputs = self._tokenize(self._build_prompt(message))

                # NOTE(review): min_length counts prompt tokens when
                # input_ids are passed, but only new tokens on the
                # inputs_embeds path; min_new_tokens would be consistent.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    min_length=20,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            response = self._decode_response(inputs, outputs)
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.exception("Error generating response: %s", e)
            return f"Error: {str(e)}"

    def clear_history(self):
        """Forget all previous turns; returns None so a bound UI output clears."""
        self.history = []
        return None


def create_demo():
    """Build a minimal Gradio UI around a single shared LLaVAPhiModel.

    NOTE(review): the original returned an undefined ``demo`` (its UI was
    omitted). This minimal layout makes the script runnable; replace it
    with the project's full UI if one exists.
    """
    model = LLaVAPhiModel()

    with gr.Blocks(title="LLaVA-Phi Demo") as demo:
        gr.Markdown("# LLaVA-Phi Demo")
        with gr.Row():
            image_input = gr.Image(type="pil", label="Image (optional)")
            with gr.Column():
                message_input = gr.Textbox(label="Message", lines=3)
                submit_button = gr.Button("Submit", variant="primary")
                clear_button = gr.Button("Clear history")
        response_output = gr.Textbox(label="Response", lines=8)

        submit_button.click(
            fn=model.generate_response,
            inputs=[message_input, image_input],
            outputs=response_output,
        )
        clear_button.click(fn=model.clear_history, inputs=None, outputs=response_output)

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)