# Lava_phi_model / app.py
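"""Gradio demo for sagar007/Lava_phi, a LLaVA-Phi style vision-language model
that pairs a Phi causal language model with a CLIP vision encoder. The app
supports text-only chat as well as chat about an uploaded image."""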
import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from PIL import Image
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
def __init__(self, model_id="sagar007/Lava_phi"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {self.device}")
        # Initialize 4-bit quantization config. bitsandbytes quantization needs a
        # CUDA GPU, so fall back to an unquantized load when running on CPU.
        quantization_config = None
        if torch.cuda.is_available():
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
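        # NF4 with double quantization stores the language-model weights in ~4 bits,
        # cutting GPU memory to roughly a quarter of an fp16 load at a small quality cost.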
try:
# Load model directly from Hugging Face Hub
logging.info(f"Loading model from {model_id}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map="auto",
                # match the bnb fp16 compute dtype on GPU; use fp32 on CPU
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True
)
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
# Set up padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.config.pad_token_id = self.tokenizer.eos_token_id
# Load CLIP model and processor
logging.info("Loading CLIP model and processor...")
self.processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.clip = AutoModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
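            # The CLIP vision tower produces pooled image embeddings; the remote-code
            # Lava_phi model is assumed to project these into the Phi embedding space.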
# Store conversation history
self.history = []
except Exception as e:
logging.error(f"Error initializing model: {str(e)}")
raise
def process_image(self, image):
"""Process image through CLIP"""
with torch.no_grad():
image_inputs = self.processor(images=image, return_tensors="pt")
image_features = self.clip.get_image_features(
pixel_values=image_inputs.pixel_values.to(self.device)
)
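            # For openai/clip-vit-base-patch32 this is a [batch_size, 512] tensor
            # of pooled, projected image embeddings.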
return image_features
def generate_response(self, message, image=None):
try:
if image is not None:
# Get image features
image_features = self.process_image(image)
# Format prompt
prompt = f"human: <image>\n{message}\ngpt:"
# Add context from history
context = ""
for turn in self.history[-3:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
# Prepare text inputs
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Add image features to inputs
inputs["image_features"] = image_features
# Generate response
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=256,
min_length=20,
temperature=0.7,
do_sample=True,
top_p=0.9,
top_k=40,
repetition_penalty=1.5,
no_repeat_ngram_size=3,
use_cache=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
else:
# Text-only response
prompt = f"human: {message}\ngpt:"
context = ""
for turn in self.history[-3:]:
context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
full_prompt = context + prompt
inputs = self.tokenizer(
full_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=150,
min_length=20,
temperature=0.6,
do_sample=True,
top_p=0.85,
top_k=30,
repetition_penalty=1.8,
no_repeat_ngram_size=4,
use_cache=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# Decode response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
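            # For a causal LM, outputs[0] contains the prompt followed by the new
            # tokens, so the cleanup below keeps only the text after the last "gpt:".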
# Clean up response
if "gpt:" in response:
response = response.split("gpt:")[-1].strip()
if "human:" in response:
response = response.split("human:")[0].strip()
if "<image>" in response:
response = response.replace("<image>", "").strip()
# Update history
self.history.append((message, response))
return response
except Exception as e:
logging.error(f"Error generating response: {str(e)}")
logging.error(f"Full traceback:", exc_info=True)
return f"Error: {str(e)}"
def clear_history(self):
self.history = []
return None
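# A minimal sketch of using the model class directly (outside the Gradio UI);
# "cat.jpg" is a hypothetical local image path:
#
#   model = LLaVAPhiModel()
#   print(model.generate_response("What is in this picture?", image=Image.open("cat.jpg")))
#   print(model.generate_response("Summarize our chat so far."))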
def create_demo():
# Initialize model
model = LLaVAPhiModel()
with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown(
"""
# LLaVA-Phi Demo
Chat with a vision-language model that can understand both text and images.
"""
)
chatbot = gr.Chatbot(height=400)
with gr.Row():
with gr.Column(scale=0.7):
msg = gr.Textbox(
show_label=False,
placeholder="Enter text and/or upload an image",
container=False
)
with gr.Column(scale=0.15, min_width=0):
clear = gr.Button("Clear")
with gr.Column(scale=0.15, min_width=0):
submit = gr.Button("Submit", variant="primary")
image = gr.Image(type="pil", label="Upload Image (Optional)")
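        # The Chatbot component here uses the legacy history format: a list of
        # (user_message, bot_response) tuples, which respond() appends to below.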
def respond(message, chat_history, image):
if not message and image is None:
                return message, chat_history
response = model.generate_response(message, image)
chat_history.append((message, response))
return "", chat_history
def clear_chat():
model.clear_history()
return None, None
submit.click(
respond,
[msg, chatbot, image],
[msg, chatbot],
)
clear.click(
clear_chat,
None,
[chatbot, image],
)
msg.submit(
respond,
[msg, chatbot, image],
[msg, chatbot],
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True
)
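# Launch with `python app.py`; the UI is served on http://0.0.0.0:7860 and,
# because share=True, Gradio also prints a temporary public share link.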