# Lava_phi_model / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModel
from PIL import Image
import numpy as np
import logging
import spaces
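# `spaces` provides the ZeroGPU decorators (@spaces.GPU) used to request a GPU
# for the duration of each decorated call on Hugging Face Spaces.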
# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
def __init__(self, model_id="sagar007/Lava_phi"):
self.device = "cuda" # Always use cuda with ZeroGPU
self.model_id = model_id
logging.info("Initializing LLaVA-Phi model...")
# Initialize tokenizer (can be done outside GPU context)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
try:
# Initialize processor (can be done outside GPU context)
            self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
except Exception as e:
logging.warning(f"Failed to load CLIP processor: {str(e)}")
# Fallback to basic tokenizer if needed
self.processor = None
# Store conversation history
self.history = []
# Lazy loading of models - will be initialized in GPU context
self.model = None
self.clip = None
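    # Methods decorated with @spaces.GPU run inside a short-lived ZeroGPU
    # allocation, so the heavy weights are loaded lazily on first use.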
@spaces.GPU
def ensure_models_loaded(self):
"""Ensure models are loaded in GPU context"""
if self.model is None:
# Load main model
from transformers import BitsAndBytesConfig
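            # Quantize the Phi backbone to 4-bit NF4 (with double quantization)
            # so it fits in the ZeroGPU memory budget; compute runs in float16.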
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id,
quantization_config=quantization_config,
device_map="auto",
                torch_dtype=torch.float16,  # match the 4-bit compute dtype above
trust_remote_code=True
)
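            # The Phi tokenizer ships without a pad token, so EOS doubles as
            # padding during generation (see __init__ above).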
self.model.config.pad_token_id = self.tokenizer.eos_token_id
        if self.clip is None:
            # Load the CLIP vision encoder used to produce image features
            try:
                self.clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            except Exception as e:
                logging.warning(f"Failed to load CLIP model: {str(e)}")
                self.clip = None
@spaces.GPU
def process_image(self, image):
"""Process image through CLIP if available, otherwise return None"""
try:
# Ensure models are loaded
self.ensure_models_loaded()
# If CLIP isn't available, return None
if self.clip is None or self.processor is None:
logging.warning("CLIP model or processor not available - skipping image processing")
return None
# Convert image to correct format
if isinstance(image, str):
image = Image.open(image)
            elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
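            # Encode the image with CLIP; get_image_features returns the pooled,
            # projected image embedding (shape [1, 512] for clip-vit-base-patch32).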
with torch.no_grad():
try:
image_inputs = self.processor(images=image, return_tensors="pt")
image_features = self.clip.get_image_features(
pixel_values=image_inputs.pixel_values.to(self.device)
)
return image_features
except Exception as e:
logging.error(f"Error during image processing: {str(e)}")
return None
except Exception as e:
logging.error(f"Error in process_image: {str(e)}")
return None
@spaces.GPU(duration=120) # Set longer duration for generation
def generate_response(self, message, image=None):
try:
# Ensure models are loaded
self.ensure_models_loaded()
if image is not None:
image_features = self.process_image(image)
has_image = image_features is not None
if not has_image:
message = "Note: Image processing is not available - continuing with text only.\n" + message
prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
context = ""
for turn in self.history[-3:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
if has_image:
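                    # Assumes the custom remote-code model's generate() accepts
                    # pre-computed CLIP features via an extra `image_features` kwarg.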
inputs["image_features"] = image_features
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
min_length=20,
temperature=0.7,
do_sample=True,
top_p=0.9,
top_k=40,
repetition_penalty=1.5,
no_repeat_ngram_size=3,
use_cache=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
else:
prompt = f"human: {message}\ngpt:"
context = ""
for turn in self.history[-3:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=150,
min_length=20,
temperature=0.6,
do_sample=True,
top_p=0.85,
top_k=30,
repetition_penalty=1.8,
no_repeat_ngram_size=4,
use_cache=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
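            # Decode and strip the prompt scaffolding ("gpt:", "human:", "<image>")
            # so only the model's reply is returned.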
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
if "gpt:" in response:
response = response.split("gpt:")[-1].strip()
if "human:" in response:
response = response.split("human:")[0].strip()
if "<image>" in response:
response = response.replace("<image>", "").strip()
self.history.append((message, response))
return response
except Exception as e:
logging.error(f"Error generating response: {str(e)}")
logging.error(f"Full traceback:", exc_info=True)
return f"Error: {str(e)}"
def clear_history(self):
self.history = []
return None
def create_demo():
try:
model = LLaVAPhiModel()
with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown(
"""
# LLaVA-Phi Demo (ZeroGPU)
Chat with a vision-language model that can understand both text and images.
"""
)
chatbot = gr.Chatbot(height=400)
with gr.Row():
with gr.Column(scale=0.7):
msg = gr.Textbox(
show_label=False,
placeholder="Enter text and/or upload an image",
container=False
)
with gr.Column(scale=0.15, min_width=0):
clear = gr.Button("Clear")
with gr.Column(scale=0.15, min_width=0):
submit = gr.Button("Submit", variant="primary")
image = gr.Image(type="pil", label="Upload Image (Optional)")
def respond(message, chat_history, image):
                if not message and image is None:
                    return "", chat_history
response = model.generate_response(message, image)
chat_history.append((message, response))
return "", chat_history
def clear_chat():
model.clear_history()
return None, None
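            # Wire both the Submit button and the Enter key to respond(), and the
            # Clear button to reset the conversation state.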
submit.click(
respond,
[msg, chatbot, image],
[msg, chatbot],
)
clear.click(
clear_chat,
None,
[chatbot, image],
)
msg.submit(
respond,
[msg, chatbot, image],
[msg, chatbot],
)
return demo
except Exception as e:
logging.error(f"Error creating demo: {str(e)}")
raise
if __name__ == "__main__":
demo = create_demo()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)