import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    """Chat wrapper around the LLaVA-Phi causal LM plus an optional CLIP image encoder.

    The tokenizer and CLIP processor are loaded eagerly in ``__init__``; the
    quantized language model and the CLIP model are loaded lazily by
    ``ensure_models_loaded`` (called from the GPU-decorated methods).
    Conversation turns are kept in ``self.history`` as
    ``(user_message, response)`` tuples.

    NOTE(review): a single instance is shared by every Gradio session in
    ``create_demo``, so ``self.history`` is shared across users — confirm this
    is acceptable for a public demo.
    """

    # Number of most recent (user, response) turns replayed as prompt context.
    HISTORY_TURNS = 3
    # Token budget for the prompt; longer prompts are truncated from the left.
    MAX_PROMPT_TOKENS = 512

    # Sampling settings used when an image accompanies the message.
    # NOTE(review): for decoder-only models Hugging Face's ``min_length`` counts
    # the prompt tokens too, so it rarely has an effect here; ``min_new_tokens``
    # may have been intended — confirm before changing.
    _IMAGE_GENERATION_KWARGS = {
        "max_new_tokens": 256,
        "min_length": 20,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "repetition_penalty": 1.5,
        "no_repeat_ngram_size": 3,
    }
    # Sampling settings used for text-only messages.
    _TEXT_GENERATION_KWARGS = {
        "max_new_tokens": 150,
        "min_length": 20,
        "temperature": 0.6,
        "top_p": 0.85,
        "top_k": 30,
        "repetition_penalty": 1.8,
        "no_repeat_ngram_size": 4,
    }

    def __init__(self, model_id="sagar007/Lava_phi"):
        """Load the tokenizer and CLIP processor; defer the heavy models.

        Args:
            model_id: Hugging Face repo id of the main causal LM.
        """
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Truncate from the left: when history pushes the prompt past
        # MAX_PROMPT_TOKENS, drop the oldest context instead of the current
        # message and the trailing "gpt:" cue (tokenizers default to "right").
        self.tokenizer.truncation_side = "left"

        try:
            # Use CLIPProcessor directly instead of AutoProcessor
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        self.history = []
        self.model = None
        self.clip = None

    @spaces.GPU
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context.

        Loads the 4-bit quantized main model on first call (failure is logged
        and re-raised) and the CLIP model (failure is logged and leaves
        ``self.clip`` as ``None`` so callers fall back to text-only).
        """
        if self.model is None:
            # Load main model with updated quantization config
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info("Successfully loaded main model")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Use CLIPModel directly instead of AutoModel
                self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
                logging.info("Successfully loaded CLIP model")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    @spaces.GPU
    def process_image(self, image):
        """Process image through CLIP if available.

        Args:
            image: A file path, a ``numpy.ndarray``, or a PIL image.

        Returns:
            The output of ``CLIPModel.get_image_features``, or ``None`` when CLIP
            is unavailable or any step fails (errors are logged, not raised).
        """
        try:
            self.ensure_models_loaded()

            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None

            # Convert image to correct format
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, numpy.ndarray):
                image = Image.fromarray(image)

            # Ensure image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            with torch.no_grad():
                try:
                    # Process image with error handling
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                    logging.info("Successfully processed image through CLIP")
                    return image_features
                except Exception as e:
                    logging.error(f"Error during image processing: {str(e)}")
                    return None

        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None

    @staticmethod
    def _clean_response(text):
        """Extract the model's reply from decoded generate() output.

        Keeps only the text after the last "gpt:" cue, cuts off any following
        "human:" turn the model produced, and strips surrounding whitespace.
        """
        if "gpt:" in text:
            text = text.split("gpt:")[-1].strip()
        if "human:" in text:
            text = text.split("human:")[0].strip()
        return text.strip()

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        """Generate a reply to ``message`` (optionally about ``image``).

        The last ``HISTORY_TURNS`` turns of ``self.history`` are prepended as
        context in a ``human: ... / gpt: ...`` format. The exchange is appended
        to ``self.history`` on success.

        Args:
            message: The user's text.
            image: Optional image (path, ndarray, or PIL image). If CLIP cannot
                encode it, a note is prepended to the prompt and generation
                proceeds text-only.

        Returns:
            The cleaned response text, or ``"Error: <details>"`` if anything
            raises (the exception is logged, not propagated).
        """
        try:
            self.ensure_models_loaded()

            image_features = None
            if image is not None:
                image_features = self.process_image(image)
                prompt_message = message
                if image_features is None:
                    prompt_message = "Note: Image processing is not available - continuing with text only.\n" + message
                # NOTE(review): the slot between "human: " and the newline is
                # empty; an image placeholder token (e.g. "<image>") looks like
                # it was lost here — confirm against the model's prompt format.
                current_turn = f"human: \n{prompt_message}\ngpt:"
                generation_kwargs = self._IMAGE_GENERATION_KWARGS
            else:
                current_turn = f"human: {message}\ngpt:"
                generation_kwargs = self._TEXT_GENERATION_KWARGS

            context = "".join(
                f"human: {turn[0]}\ngpt: {turn[1]}\n"
                for turn in self.history[-self.HISTORY_TURNS:]
            )

            inputs = self.tokenizer(
                context + current_turn,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.MAX_PROMPT_TOKENS,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            if image_features is not None:
                # Assumes the remote-code model's generate() accepts an
                # ``image_features`` kwarg — TODO confirm; a stock HF causal LM
                # rejects unknown model kwargs.
                inputs["image_features"] = image_features

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **generation_kwargs,
                    do_sample=True,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            response = self._clean_response(
                self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            )

            # Record the user's actual message, not the prompt-only note.
            self.history.append((message, response))
            return response

        except Exception as e:
            logging.exception(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def clear_history(self):
        """Forget all stored conversation turns."""
        self.history = []
        return None


def create_demo():
    """Build the Gradio Blocks chat UI around one shared ``LLaVAPhiModel``.

    Returns:
        The ``gr.Blocks`` demo. Construction errors are logged and re-raised.
    """
    try:
        model = LLaVAPhiModel()

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (ZeroGPU)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                # Integer scales (Gradio expects ints) in the original 0.7 : 0.15 : 0.15 ratio.
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            def respond(message, chat_history, image):
                """Append the model's reply to the chat and clear the textbox.

                Returns a value for each wired output: ``(textbox, chatbot)``.
                """
                # The chatbot value can be None after "Clear" reset it.
                chat_history = chat_history or []
                if not message and image is None:
                    return "", chat_history

                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                """Reset model history, the chatbot, and the uploaded image."""
                model.clear_history()
                return None, None

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )

            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )