import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)


class LLaVAPhiModel:
    def __init__(self, model_id="microsoft/phi-1_5"):  # Matches the model name in config
        self.device = "cuda"  # GPU is provided by the @spaces.GPU context on Hugging Face Spaces
        self.model_id = model_id
        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use CLIPProcessor with the vision backbone named in the config
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        # Conversation history; the most recent turns are prepended to each prompt
        self.history = []

        # Models are loaded lazily inside the GPU context (see ensure_models_loaded)
        self.model = None
        self.clip = None

        # Default generation parameters - can be updated from config
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

        # Max input length; default value, updated from config via apply_config
        self.max_length = 512
    @spaces.GPU
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context"""
        if self.model is None:
            # Use 4-bit quantization according to config
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 matches the config's mixed_precision setting
                bnb_4bit_use_double_quant=False
            )
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info(f"Successfully loaded main model: {self.model_id}")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Load the CLIP vision model named in the config
                clip_model_name = "openai/clip-vit-base-patch32"
                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None
    def apply_lora_config(self, lora_params):
        """Apply LoRA configuration to the model - to be called during training"""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=lora_params.get("r", 16),
            lora_alpha=lora_params.get("lora_alpha", 32),
            lora_dropout=lora_params.get("lora_dropout", 0.05),
            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Convert model to PEFT/LoRA model
        self.model = get_peft_model(self.model, lora_config)
        logging.info("Applied LoRA configuration to the model")
        return self.model
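
    # A minimal usage sketch for apply_lora_config (illustrative only, not part of the app flow).
    # The dict keys below are exactly the ones the method reads; the values shown are the
    # method's own defaults and would normally come from a training config:
    #
    #   model = LLaVAPhiModel()
    #   model.ensure_models_loaded()
    #   model.apply_lora_config({
    #       "r": 16,
    #       "lora_alpha": 32,
    #       "lora_dropout": 0.05,
    #       "target_modules": ["Wqkv", "out_proj"],
    #   })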
    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()

            # Prepare prompt based on whether we have an image
            has_image = image is not None

            if has_image:
                # For image+text input
                prompt = f"human: <image>\n{message}\ngpt:"
                # Check if the model has a vision encoding capability
                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
                    logging.warning("Model doesn't have standard image encoding methods")
                    has_image = False
                    prompt = f"human: {message}\ngpt:"
            else:
                # For text-only input
                prompt = f"human: {message}\ngpt:"

            # Include previous conversation context (last 5 turns)
            context = ""
            for turn in self.history[-5:]:
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
            full_prompt = context + prompt

            # Tokenize the input text
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # LLaVA-Phi specific image handling
            if has_image:
                try:
                    # Convert image to a PIL Image in RGB mode
                    if isinstance(image, str):
                        image = Image.open(image)
                    elif isinstance(image, numpy.ndarray):
                        image = Image.fromarray(image)
                    if image.mode != 'RGB':
                        image = image.convert('RGB')

                    # Encode the image with CLIP. Note: these features are computed but not
                    # injected into the language model below (the base phi model has no
                    # projector for them), so generation is conditioned on the text prompt only.
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )

                    # Some LLaVA models have a prepare_inputs_for_generation method
                    if hasattr(self.model, "prepare_inputs_for_generation"):
                        logging.info("Using model's prepare_inputs_for_generation for image handling")

                    # Generate with image context
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                except Exception as e:
                    logging.error(f"Error handling image: {str(e)}")
                    # Fall back to text-only generation
                    logging.info("Falling back to text-only generation")
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
            else:
                # Text-only generation
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                        min_length=20,
                        temperature=self.temperature,
                        do_sample=True,
                        top_p=self.top_p,
                        top_k=self.top_k,
                        repetition_penalty=self.repetition_penalty,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )

            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response

        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"
    def clear_history(self):
        self.history = []
        return None

    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update generation parameters to control hallucination tendency"""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    def apply_config(self, config):
        """Apply settings from a config file"""
        model_params = config.get("model_params", {})
        self.model_id = model_params.get("model_name", self.model_id)
        self.max_length = model_params.get("max_length", 512)

        # Training parameters are read but not applied here; specific updates could be added if needed
        training_params = config.get("training_params", {})

        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"
def create_demo(config=None):
    try:
        # Initialize the model, then apply config file settings if provided
        model = LLaVAPhiModel()
        if config:
            model.apply_config(config)

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Generation parameter controls
            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")

            # Debugging information box
            debug_info = gr.Textbox(label="Debug Info", interactive=False)

            # Show the active configuration, if one was loaded
            if config:
                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
                gr.Markdown(f"**Current Configuration:** {config_info}")

            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to do; keep the textbox and history unchanged
                    return message, chat_history, "No input provided"
                try:
                    response = model.generate_response(message, image)
                    chat_history.append((message, response))
                    debug_msg = "Response generated successfully"
                    return "", chat_history, debug_msg
                except Exception as e:
                    debug_msg = f"Error: {str(e)}"
                    return message, chat_history, debug_msg

            def clear_chat():
                model.clear_history()
                return None, None, "Chat history cleared"

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image, debug_info],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [debug_info]
            )

        return demo

    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise
if __name__ == "__main__":
    # Load config file
    import json
    try:
        with open("config.json", "r") as f:
            config = json.load(f)
        logging.info("Successfully loaded config file")
    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        config = None

    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )