import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
    def __init__(self, model_id="microsoft/phi-1_5"):  # Updated to match config
        self.device = "cuda"
        self.model_id = model_id
        logging.info(f"Initializing LLaVA-Phi model with {model_id}...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use CLIPProcessor with the correct model name from config
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        # Conversation history; the most recent turns are included as context
        self.history = []

        # Models are loaded lazily in ensure_models_loaded()
        self.model = None
        self.clip = None

        # Default generation parameters - can be updated from config
        self.temperature = 0.3
        self.top_p = 0.92
        self.top_k = 50
        self.repetition_penalty = 1.2

        # Max input length; default value, will be updated from config
        self.max_length = 512

    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context."""
        if self.model is None:
            # Use 4-bit quantization according to config
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,  # Changed to match config
                bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 to match config's mixed_precision
                bnb_4bit_use_double_quant=False
            )
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info(f"Successfully loaded main model: {self.model_id}")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Load CLIP model from config
                clip_model_name = "openai/clip-vit-base-patch32"
                self.clip = CLIPModel.from_pretrained(clip_model_name).to(self.device)
                logging.info(f"Successfully loaded CLIP model: {clip_model_name}")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None
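
    # Note: 4-bit loading via BitsAndBytesConfig requires a CUDA-capable GPU and the
    # bitsandbytes package. A possible guard (an assumption, not part of the original
    # logic) would be:
    #
    #     if not torch.cuda.is_available():
    #         raise RuntimeError("LLaVAPhiModel requires a CUDA GPU for 4-bit loading")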

    def apply_lora_config(self, lora_params):
        """Apply LoRA configuration to the model - to be called during training."""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=lora_params.get("r", 16),
            lora_alpha=lora_params.get("lora_alpha", 32),
            lora_dropout=lora_params.get("lora_dropout", 0.05),
            target_modules=lora_params.get("target_modules", ["Wqkv", "out_proj"]),
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Convert model to PEFT/LoRA model
        self.model = get_peft_model(self.model, lora_config)
        logging.info("Applied LoRA configuration to the model")
        return self.model
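
    # Illustrative usage (an assumption: a separate training script would call this
    # after the base model is loaded; the lora_params dict below only repeats the
    # defaults used above and is an example, not a requirement):
    #
    #     model = LLaVAPhiModel()
    #     model.ensure_models_loaded()
    #     model.apply_lora_config({
    #         "r": 16,
    #         "lora_alpha": 32,
    #         "lora_dropout": 0.05,
    #         "target_modules": ["Wqkv", "out_proj"],
    #     })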

    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()

            # Prepare prompt based on whether we have an image
            has_image = image is not None

            # Process text input
            if has_image:
                # For image+text input
                prompt = f"human: <image>\n{message}\ngpt:"
                # Check if model has vision encoding capability
                if not hasattr(self.model, "encode_image") and not hasattr(self.model, "get_vision_tower"):
                    logging.warning("Model doesn't have standard image encoding methods")
                    has_image = False
                    prompt = f"human: {message}\ngpt:"
            else:
                # For text-only input
                prompt = f"human: {message}\ngpt:"

            # Include previous conversation context (last 5 turns)
            context = ""
            for turn in self.history[-5:]:
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
            full_prompt = context + prompt

            # Tokenize the input text
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # LLaVA-Phi specific image handling
            if has_image:
                try:
                    # Convert image to correct format
                    if isinstance(image, str):
                        image = Image.open(image)
                    elif isinstance(image, numpy.ndarray):
                        image = Image.fromarray(image)

                    # Ensure image is in RGB mode
                    if image.mode != 'RGB':
                        image = image.convert('RGB')

                    # Process the image with CLIP processor
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                    # NOTE: image_features are computed here but are not injected into
                    # the language model's inputs below, so generation is still
                    # conditioned on the text prompt only.

                    # Some LLaVA models have a prepare_inputs_for_generation method
                    if hasattr(self.model, "prepare_inputs_for_generation"):
                        logging.info("Using model's prepare_inputs_for_generation for image handling")

                    # Generate with image context
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                except Exception as e:
                    logging.error(f"Error handling image: {str(e)}")
                    # Fall back to text-only generation
                    logging.info("Falling back to text-only generation")
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=256,
                            min_length=20,
                            temperature=self.temperature,
                            do_sample=True,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            repetition_penalty=self.repetition_penalty,
                            no_repeat_ngram_size=3,
                            use_cache=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
            else:
                # Text-only generation
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=200,
                        min_length=20,
                        temperature=self.temperature,
                        do_sample=True,
                        top_p=self.top_p,
                        top_k=self.top_k,
                        repetition_penalty=self.repetition_penalty,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            # Decode and clean up the response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up response
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
        """Update generation parameters to control hallucination tendency."""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    def apply_config(self, config):
        """Apply settings from a config file."""
        model_params = config.get("model_params", {})
        self.model_id = model_params.get("model_name", self.model_id)
        self.max_length = model_params.get("max_length", 512)

        # Training settings are read here but not applied at inference time;
        # specific updates based on training_params could be added if needed.
        training_params = config.get("training_params", {})

        return f"Applied configuration. Model: {self.model_id}, Max Length: {self.max_length}"

def create_demo(config=None):
    try:
        # Initialize the model, then apply config file settings if provided
        model = LLaVAPhiModel()
        if config:
            model.apply_config(config)

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            # Generation parameter controls
            with gr.Accordion("Advanced Settings (Reduce Hallucinations)", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")

            # Debugging information box
            debug_info = gr.Textbox(label="Debug Info", interactive=False)

            # Show the active configuration
            if config:
                config_info = f"Model: {model.model_id}, Max Length: {model.max_length}"
                gr.Markdown(f"**Current Configuration:** {config_info}")

            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to do; leave the UI state unchanged
                    return message, chat_history, ""
                try:
                    response = model.generate_response(message, image)
                    chat_history.append((message, response))
                    debug_msg = "Response generated successfully"
                    return "", chat_history, debug_msg
                except Exception as e:
                    debug_msg = f"Error: {str(e)}"
                    return message, chat_history, debug_msg

            def clear_chat():
                model.clear_history()
                return None, None, "Chat history cleared"

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image, debug_info],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot, debug_info],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                [debug_info]
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    # Load config file
    import json

    try:
        with open("config.json", "r") as f:
            config = json.load(f)
        logging.info("Successfully loaded config file")
    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        config = None

    demo = create_demo(config)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )