# Lava_phi_model / app.py
# Uploaded by sagar007 -- commit "Update app.py" (bd91e22, verified), 12.4 kB.
# (The "raw" / "history blame" labels were navigation links on the hosting page.)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy
# Configure the root logger once at import time so INFO-level messages
# (model loading progress, image processing status) are emitted.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
class LLaVAPhiModel:
    """Chat wrapper around a causal LM plus a CLIP image encoder.

    The tokenizer and CLIP processor are loaded eagerly in ``__init__``;
    the language model and the CLIP model are loaded lazily by
    ``ensure_models_loaded``. Conversation turns are kept in
    ``self.history`` as ``(message, response)`` tuples, and the last five
    turns are prepended to every new prompt.

    NOTE(review): the GPU-decorated methods mutate instance state
    (``self.model``, ``self.clip``, ``self.history``). Whether that state
    survives between calls depends on how ``spaces.GPU`` executes the
    function on the target hardware -- confirm against the ``spaces`` docs.
    """

    def __init__(self, model_id="sagar007/Lava_phi"):
        """Load the tokenizer and CLIP processor and set sampling defaults.

        Args:
            model_id: Hub identifier of the causal LM / tokenizer to use.
        """
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        # Initialize tokenizer; fall back to EOS as the padding token when
        # the tokenizer does not define one.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use CLIPProcessor directly instead of AutoProcessor
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            # Best effort: without a processor the app continues text-only.
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        # Conversation turns as (message, response); starts empty.
        self.history = []
        # Heavy models are loaded lazily by ensure_models_loaded().
        self.model = None
        self.clip = None

        # Sampling defaults used by generate_response(); they can be changed
        # at runtime through update_generation_params().
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

    @spaces.GPU
    def ensure_models_loaded(self):
        """Load the language model and CLIP model if not loaded yet.

        Raises:
            Exception: re-raised unchanged when the main model fails to
                load. A CLIP loading failure is only logged and leaves
                ``self.clip`` as ``None``.
        """
        if self.model is None:
            from transformers import BitsAndBytesConfig

            # Only load_in_8bit is set: the compute-dtype and double-quant
            # options exist for 4-bit mode only (bnb_4bit_*).
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info("Successfully loaded main model")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Use CLIPModel directly instead of AutoModel
                self.clip = CLIPModel.from_pretrained(
                    "openai/clip-vit-base-patch32"
                ).to(self.device)
                logging.info("Successfully loaded CLIP model")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None

    @spaces.GPU
    def process_image(self, image):
        """Encode an image with CLIP.

        Args:
            image: A file path, a ``numpy.ndarray`` or an object exposing
                ``mode`` / ``convert`` (e.g. a PIL image).

        Returns:
            The tensor returned by ``CLIPModel.get_image_features``, or
            ``None`` when CLIP is unavailable or any step fails.
        """
        try:
            self.ensure_models_loaded()
            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None

            # Convert image to a PIL image
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, numpy.ndarray):
                image = Image.fromarray(image)

            # Ensure image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            with torch.no_grad():
                try:
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                    logging.info("Successfully processed image through CLIP")
                    return image_features
                except Exception as e:
                    logging.error(f"Error during image processing: {str(e)}")
                    return None
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None

    def _encode_prompt(self, prompt):
        """Prefix *prompt* with the last five turns and tokenize it onto the device."""
        context = "".join(
            f"human: {turn[0]}\ngpt: {turn[1]}\n" for turn in self.history[-5:]
        )
        encoded = self.tokenizer(
            context + prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        )
        return {k: v.to(self.device) for k, v in encoded.items()}

    def _sample(self, inputs, max_new_tokens, no_repeat_ngram_size):
        """Run ``model.generate`` with the instance's current sampling parameters."""
        with torch.no_grad():
            return self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                min_length=20,
                temperature=self.temperature,
                do_sample=True,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=self.repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                use_cache=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        """Generate a reply to *message*, optionally conditioned on *image*.

        Args:
            message: The user's text (``None`` is treated as an empty string).
            image: Optional image accepted by ``process_image``.

        Returns:
            The cleaned reply text. On any exception the error is logged
            and the string ``"Error: <details>"`` is returned instead.

        Side effects:
            Appends ``(message, response)`` to ``self.history`` on success.
        """
        try:
            self.ensure_models_loaded()
            message = message or ""

            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None
                if not has_image:
                    message = (
                        "Note: Image processing is not available - continuing with text only.\n"
                        + message
                    )
                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
                inputs = self._encode_prompt(prompt)
                if has_image:
                    # NOTE(review): assumes the remote-code model's generate()
                    # accepts an `image_features` keyword -- confirm.
                    inputs["image_features"] = image_features
                outputs = self._sample(inputs, max_new_tokens=256, no_repeat_ngram_size=3)
            else:
                prompt = f"human: {message}\ngpt:"
                inputs = self._encode_prompt(prompt)
                outputs = self._sample(inputs, max_new_tokens=200, no_repeat_ngram_size=4)

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up response: keep only the text after the last "gpt:"
            # marker, cut anything from a following "human:" marker on,
            # and drop leftover image placeholders.
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        """Forget all stored conversation turns. Returns ``None``."""
        self.history = []
        return None

    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update the sampling parameters used by ``generate_response``.

        Args:
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            top_k: Number of highest-probability tokens kept; coerced to
                ``int`` because UI sliders may deliver a float.
            repetition_penalty: Penalty applied to repeated tokens.

        Returns:
            A human-readable confirmation string.
        """
        self.temperature = float(temperature)
        self.top_p = float(top_p)
        self.top_k = int(top_k)
        self.repetition_penalty = float(repetition_penalty)
        return (
            f"Generation parameters updated: temp={self.temperature}, top_p={self.top_p}, "
            f"top_k={self.top_k}, rep_penalty={self.repetition_penalty}"
        )
def create_demo():
    """Build the Gradio Blocks UI for chatting with ``LLaVAPhiModel``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).

    Raises:
        Exception: any error raised while creating the model or the UI is
            logged and re-raised unchanged.
    """
    try:
        model = LLaVAPhiModel()

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                "# LLaVA-Phi Demo (Optimized for Accuracy)\n"
                "Chat with a vision-language model that can understand both text and images."
            )
            chatbot = gr.Chatbot(height=400)

            with gr.Row():
                # Integer scales (14:3:3) keep the 0.7 / 0.15 / 0.15 split;
                # Gradio documents `scale` as an integer.
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Generation parameter controls
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")
                # Shows the confirmation text returned by the model wrapper.
                params_status = gr.Markdown()

            def respond(message, chat_history, image):
                """Handle a submit: return (new textbox value, updated chat history)."""
                chat_history = chat_history or []
                if not message and image is None:
                    # Nothing to send. Two values are still returned because
                    # two outputs (msg, chatbot) are wired to this handler.
                    return "", chat_history
                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                """Reset the model's history and clear the chatbot and image widgets."""
                model.clear_history()
                return None, None

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                """Forward the slider values to the model and return its status text."""
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [params_status],
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 and
    # request a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo = create_demo()
    demo.launch(**launch_options)