# Lava_phi_model / app.py
# Author: sagar007 · commit 8ec9ef4 ("Update app.py", verified) · 9.24 kB
import logging

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
# Send INFO-and-above records from every logger to the root handler.
logging.basicConfig(level="INFO")
class LLaVAPhiModel:
    """Conversational wrapper around the LLaVA-Phi causal LM and a CLIP image encoder.

    The tokenizer and CLIP processor are created in ``__init__``. The 4-bit
    language model and the CLIP model are created lazily by
    ``ensure_models_loaded``, which is called from the ``@spaces.GPU`` methods.

    NOTE(review): ``history`` lives on the instance, and ``create_demo`` builds
    a single instance for the whole app. Concurrent visitors therefore
    presumably share one conversation history; consider per-session state.
    """

    # Number of previous (message, response) turns replayed before each new prompt.
    _HISTORY_TURNS = 3
    # Prompts longer than this many tokens are truncated (from the left; see __init__).
    _MAX_PROMPT_TOKENS = 512
    # Sampling settings used whenever an image was supplied, even if encoding it failed.
    # NOTE(review): for decoder-only models ``min_length`` also counts prompt tokens,
    # so it rarely has any effect here; ``min_new_tokens`` may have been intended.
    _IMAGE_GENERATION_KWARGS = dict(
        max_new_tokens=256,
        min_length=20,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.5,
        no_repeat_ngram_size=3,
        use_cache=True,
    )
    # Sampling settings for text-only turns.
    _TEXT_GENERATION_KWARGS = dict(
        max_new_tokens=150,
        min_length=20,
        temperature=0.6,
        do_sample=True,
        top_p=0.85,
        top_k=30,
        repetition_penalty=1.8,
        no_repeat_ngram_size=4,
        use_cache=True,
    )

    def __init__(self, model_id="sagar007/Lava_phi"):
        """Create the tokenizer and CLIP processor; defer loading the models.

        Args:
            model_id: Hugging Face repo id of the causal language model.
        """
        self.device = "cuda"  # Always use cuda with ZeroGPU
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        # Initialize tokenizer (can be done outside GPU context)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Every prompt ends with the current turn ("...\ngpt:"). Right-side
        # truncation (the usual default) would drop exactly that part once the
        # replayed history pushes the prompt past the token limit, so keep the end.
        self.tokenizer.truncation_side = "left"
        # Initialize processor (can be done outside GPU context)
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # (message, response) pairs from earlier turns
        self.history = []
        # Lazy loading of models - will be initialized in GPU context
        self.model = None
        self.clip = None

    @spaces.GPU
    def ensure_models_loaded(self):
        """Load the language model and the CLIP model on first use.

        Each model is created only while its attribute is still ``None``, so
        repeated calls skip the loading work.

        NOTE(review): on ZeroGPU, attributes assigned inside a ``@spaces.GPU``
        call may not persist between calls. Confirm whether the models should
        be loaded at import time instead.
        """
        if self.model is None:
            # Load main model: 4-bit NF4 bitsandbytes quantisation, fp16 compute.
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
        if self.clip is None:
            # Load CLIP model
            self.clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)

    @spaces.GPU
    def process_image(self, image):
        """Encode *image* with CLIP and return its image-feature tensor.

        Args:
            image: a file path, a numpy array, or any image object the CLIP
                processor accepts directly (the UI supplies PIL images).

        Returns:
            The tensor produced by ``self.clip.get_image_features``.

        Raises:
            Any error raised while opening or encoding the image, after logging it.
        """
        try:
            self.ensure_models_loaded()
            # Normalise paths and arrays to PIL images; other inputs pass through.
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
            return image_features
        except Exception as e:
            logging.error(f"Error processing image: {str(e)}")
            raise

    def _build_prompt(self, message, image_slot=None):
        """Return the recent history followed by the current turn, ending in "gpt:".

        Args:
            message: user text for the current turn.
            image_slot: ``None`` for a text-only turn. Otherwise the text placed
                on its own line before the message ("<image>", or "" when the
                supplied image could not be encoded).
        """
        context = "".join(
            f"human: {user}\ngpt: {reply}\n"
            for user, reply in self.history[-self._HISTORY_TURNS:]
        )
        if image_slot is None:
            turn = f"human: {message}\ngpt:"
        else:
            turn = f"human: {image_slot}\n{message}\ngpt:"
        return context + turn

    def _generate(self, prompt, generation_kwargs, image_features=None):
        """Tokenize *prompt*, sample from the model, and return the decoded ``outputs[0]``.

        Args:
            prompt: full prompt text, as built by ``_build_prompt``.
            generation_kwargs: sampling settings forwarded to ``generate``.
            image_features: optional CLIP features to pass along to the model.
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._MAX_PROMPT_TOKENS,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        if image_features is not None:
            # NOTE(review): this relies on the remote-code model's generate()
            # accepting an ``image_features`` keyword; confirm.
            inputs["image_features"] = image_features
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_kwargs,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _clean_response(text):
        """Extract the assistant reply from decoded model output.

        Keeps the text after the last "gpt:" marker, cuts it at the first
        "human:" marker, and removes "<image>" placeholders. Each step strips
        whitespace only when its marker is present.
        NOTE(review): if the model itself writes a further "gpt:" turn, that
        later segment is what gets kept.
        """
        if "gpt:" in text:
            text = text.split("gpt:")[-1].strip()
        if "human:" in text:
            text = text.split("human:")[0].strip()
        if "<image>" in text:
            text = text.replace("<image>", "").strip()
        return text

    @spaces.GPU(duration=120)  # Set longer duration for generation
    def generate_response(self, message, image=None):
        """Generate a reply to *message*, optionally conditioned on *image*.

        If the image cannot be encoded, a note describing the failure is
        prepended to the message and generation continues text-only. Image
        turns still use the image sampling settings in that case. The (possibly
        annotated) message and the reply are appended to ``self.history``.

        Args:
            message: user text for this turn.
            image: optional image (see ``process_image`` for accepted forms).

        Returns:
            The cleaned reply, or ``"Error: <details>"`` if anything fails.
            Errors are logged, not raised.
        """
        try:
            self.ensure_models_loaded()

            image_features = None
            if image is None:
                prompt = self._build_prompt(message)
                generation_kwargs = self._TEXT_GENERATION_KWARGS
            else:
                try:
                    image_features = self.process_image(image)
                except Exception as e:
                    logging.error(f"Failed to process image: {str(e)}")
                    message = f"Note: Failed to process image. Continuing with text only. Error: {str(e)}\n{message}"
                image_slot = "<image>" if image_features is not None else ""
                prompt = self._build_prompt(message, image_slot)
                generation_kwargs = self._IMAGE_GENERATION_KWARGS

            response = self._clean_response(
                self._generate(prompt, generation_kwargs, image_features)
            )
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        """Forget all previous turns; returns ``None``."""
        self.history = []
        return None
def create_demo():
    """Build the Gradio chat UI backed by a single ``LLaVAPhiModel``.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.

    Raises:
        Any exception from model or UI construction, after logging it with
        its traceback.
    """
    try:
        model = LLaVAPhiModel()
        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
# LLaVA-Phi Demo (ZeroGPU)
Chat with a vision-language model that can understand both text and images.
"""
            )
            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                # Integer scales in the original 0.7 : 0.15 : 0.15 proportion;
                # newer Gradio versions expect ``scale`` to be an int.
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")
            image = gr.Image(type="pil", label="Upload Image (Optional)")

            def respond(message, chat_history, image):
                """Run one turn; return (new textbox value, updated chat history)."""
                chat_history = chat_history or []
                # Two outputs ([msg, chatbot]) are wired up, so every path must
                # return two values.
                if not message and image is None:
                    return "", chat_history
                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                """Reset the model's history and clear the chatbot and image inputs."""
                model.clear_history()
                return None, None

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
        return demo
    except Exception as e:
        logging.exception(f"Error creating demo: {str(e)}")
        raise
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 with share=True.
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)