import logging

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Setup logging
logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    """Chat wrapper around a causal LM plus a CLIP image encoder.

    The language model and its tokenizer are loaded from ``model_id``; CLIP
    (``openai/clip-vit-base-patch32``) produces the image features that are
    handed to ``generate`` when an image accompanies the message.
    Conversation turns are kept on the instance in ``self.history`` as
    ``(message, response)`` tuples.
    """

    # Number of most recent turns replayed in front of each new prompt.
    HISTORY_TURNS = 3

    # Token budget for the tokenized prompt (history + current turn).
    MAX_PROMPT_TOKENS = 512

    # Text placed after "human: " when an image is supplied, and removed from
    # the decoded reply.
    # NOTE(review): this literal is empty in the file as received, which made
    # the original `"" in response` check always true and its `replace` a
    # no-op. An image tag (e.g. "<image>") may have been lost — confirm
    # against the prompt format the checkpoint was trained with.
    IMAGE_PLACEHOLDER = ""

    # Sampling settings used when an image is part of the request.
    IMAGE_GENERATION_KWARGS = {
        "max_new_tokens": 256,
        "min_length": 20,
        "temperature": 0.7,
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 40,
        "repetition_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "use_cache": True,
    }

    # Sampling settings used for text-only requests.
    TEXT_GENERATION_KWARGS = {
        "max_new_tokens": 150,
        "min_length": 20,
        "temperature": 0.6,
        "do_sample": True,
        "top_p": 0.85,
        "top_k": 30,
        "repetition_penalty": 1.8,
        "no_repeat_ngram_size": 4,
        "use_cache": True,
    }

    def __init__(self, model_id="sagar007/Lava_phi"):
        """Load the language model, tokenizer and CLIP encoder.

        Args:
            model_id: Hugging Face Hub identifier of the causal LM and its
                tokenizer.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Initialize quantization config
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

        try:
            # Load model directly from Hugging Face Hub
            logging.info(f"Loading model from {model_id}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

            # The prompt ends with the current message and the "gpt:" cue, so
            # an over-long prompt must lose its oldest text, not its tail.
            self.tokenizer.truncation_side = "left"

            # Set up padding token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.tokenizer.eos_token_id

            # Load CLIP model and processor
            logging.info("Loading CLIP model and processor...")
            self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)

            # Store conversation history
            self.history = []
        except Exception as e:
            logging.exception("Error initializing model: %s", e)
            raise

    def process_image(self, image):
        """Process image through CLIP and return its image features.

        Args:
            image: Image accepted by the CLIP processor (the demo passes a
                PIL image).

        Returns:
            The tensor returned by ``self.clip.get_image_features``.
        """
        with torch.no_grad():
            image_inputs = self.processor(images=image, return_tensors="pt")
            image_features = self.clip.get_image_features(
                pixel_values=image_inputs.pixel_values.to(self.device)
            )
        return image_features

    def _build_prompt(self, message, with_image):
        """Return recent history followed by the current turn and "gpt:" cue."""
        context = "".join(
            f"human: {user_text}\ngpt: {reply}\n"
            for user_text, reply in self.history[-self.HISTORY_TURNS:]
        )
        if with_image:
            current_turn = f"human: {self.IMAGE_PLACEHOLDER}\n{message}\ngpt:"
        else:
            current_turn = f"human: {message}\ngpt:"
        return context + current_turn

    def _clean_response(self, text):
        """Trim the reply and drop anything from a follow-up "human:" turn on."""
        response = text.strip()
        if "human:" in response:
            response = response.split("human:")[0].strip()
        if self.IMAGE_PLACEHOLDER:
            response = response.replace(self.IMAGE_PLACEHOLDER, "").strip()
        return response

    def _extract_response(self, sequence, prompt_ids):
        """Decode the reply from one generated sequence.

        When the sequence starts with the prompt tokens, only the tokens after
        them are decoded, so text the model invents for later turns cannot be
        mistaken for the answer. Otherwise the whole sequence is decoded and
        the text after the last "gpt:" marker is used.
        """
        prompt_len = prompt_ids.shape[-1]
        # Compare as Python lists so the check does not depend on both
        # tensors living on the same device.
        echoes_prompt = (
            sequence.shape[-1] >= prompt_len
            and sequence[:prompt_len].tolist() == prompt_ids.tolist()
        )
        if echoes_prompt:
            text = self.tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
        else:
            text = self.tokenizer.decode(sequence, skip_special_tokens=True)
            if "gpt:" in text:
                text = text.split("gpt:")[-1]
        return self._clean_response(text)

    def generate_response(self, message, image=None):
        """Generate a reply to ``message``, optionally conditioned on ``image``.

        Args:
            message: The user's text for this turn.
            image: Optional image; when given, its CLIP features are passed
                to ``generate`` as ``image_features``.

        Returns:
            The cleaned reply. If anything fails, the error is logged with
            its traceback and ``"Error: <message>"`` is returned instead;
            failed turns are not added to the history.
        """
        try:
            has_image = image is not None
            full_prompt = self._build_prompt(message, with_image=has_image)

            # Prepare text inputs
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.MAX_PROMPT_TOKENS,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_ids = inputs["input_ids"][0]

            if has_image:
                # Add image features to inputs
                inputs["image_features"] = self.process_image(image)
                sampling = self.IMAGE_GENERATION_KWARGS
            else:
                sampling = self.TEXT_GENERATION_KWARGS

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **sampling,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            response = self._extract_response(outputs[0], prompt_ids)

            # Update history
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.exception("Error generating response: %s", e)
            return f"Error: {str(e)}"

    def clear_history(self):
        """Forget all stored conversation turns; returns ``None``."""
        self.history = []
        return None
def create_demo():
    """Build the Gradio Blocks interface for the chat demo.

    A single ``LLaVAPhiModel`` is created here and captured by the event
    callbacks, so every interaction with the returned demo uses that one
    instance and its conversation history.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Initialize model
    model = LLaVAPhiModel()

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown(
            """
            # LLaVA-Phi Demo
            Chat with a vision-language model that can understand both text and images.
            """
        )

        chatbot = gr.Chatbot(height=400)
        with gr.Row():
            # Integer scales in the same 0.7 : 0.15 : 0.15 ratio as before;
            # Gradio expects `scale` to be an integer.
            with gr.Column(scale=14):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and/or upload an image",
                    container=False,
                )
            with gr.Column(scale=3, min_width=0):
                clear = gr.Button("Clear")
            with gr.Column(scale=3, min_width=0):
                submit = gr.Button("Submit", variant="primary")

        image = gr.Image(type="pil", label="Upload Image (Optional)")

        def respond(message, chat_history, image):
            """Answer one turn; returns (new textbox value, chat history)."""
            # The chatbot value can be None (e.g. right after a clear).
            chat_history = chat_history or []
            if not message and image is None:
                # Two outputs are wired to this callback, so two values must
                # be returned even when there is nothing to do.
                return "", chat_history

            response = model.generate_response(message, image)
            chat_history.append((message, response))
            return "", chat_history

        def clear_chat():
            """Reset the model's history and blank the chatbot and image."""
            model.clear_history()
            return None, None

        submit.click(
            respond,
            [msg, chatbot, image],
            [msg, chatbot],
        )

        clear.click(
            clear_chat,
            None,
            [chatbot, image],
        )

        msg.submit(
            respond,
            [msg, chatbot, image],
            [msg, chatbot],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )