import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np

logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.history = []
        # Heavy models are loaded lazily on first use (see ensure_models_loaded).
        self.model = None
        self.clip = None
        self.projection = None
    def ensure_models_loaded(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This model requires a GPU.")

        if self.model is None:
            from transformers import BitsAndBytesConfig
            # 8-bit weight quantization; the compute-dtype / double-quant options
            # only apply to 4-bit loading, so they are not set here.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model on GPU")
        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")

            # Linear layer mapping CLIP image features into the language model's
            # embedding space.
            embed_dim = self.model.config.hidden_size
            clip_dim = self.clip.config.projection_dim
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
    # The remaining methods of the class (process_image, generate_response,
    # clear_history, update_generation_params, etc.) are omitted here for brevity.
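
    # --- Illustrative sketch (not the original implementations) -------------
    # A minimal sketch of the omitted methods, reconstructed only so the Gradio
    # callbacks below have something concrete to call: CLIP encodes the image,
    # the linear projection maps its features into the language model's
    # embedding space, and the projected embedding is prepended to the text
    # embeddings before generation. The prompt template, decoding defaults and
    # use of @spaces.GPU are assumptions, not the author's exact code.
    def update_generation_params(self, temperature, top_p, top_k, repetition_penalty):
        # Store the sampling parameters that generate_response will use.
        self.generation_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }
        return f"Updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"

    def clear_history(self):
        # Reset the running conversation history.
        self.history = []

    @spaces.GPU  # assumed: request a GPU for this call on ZeroGPU Spaces
    def generate_response(self, message, image=None):
        self.ensure_models_loaded()
        params = getattr(self, "generation_params", {
            "temperature": 0.3, "top_p": 0.92, "top_k": 50, "repetition_penalty": 1.2,
        })

        # Assumed prompt format; the real checkpoint may expect a different template.
        prompt = f"human: {message}\ngpt:"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        inputs_embeds = self.model.get_input_embeddings()(inputs["input_ids"])
        attention_mask = inputs["attention_mask"]

        if image is not None:
            # Encode the image with CLIP and project it into the LM embedding space.
            pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].to(self.device)
            with torch.no_grad():
                image_features = self.clip.get_image_features(pixel_values=pixel_values)
            image_embeds = self.projection(image_features).unsqueeze(1).to(inputs_embeds.dtype)
            # Prepend the single image-token embedding and extend the attention mask.
            inputs_embeds = torch.cat([image_embeds, inputs_embeds], dim=1)
            attention_mask = torch.cat(
                [torch.ones(image_embeds.shape[:2], dtype=attention_mask.dtype, device=self.device),
                 attention_mask],
                dim=1,
            )

        with torch.no_grad():
            output_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=256,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                **params,
            )

        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        self.history.append((message, response))
        return response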

def create_demo():
    try:
        model = LLaVAPhiModel()

        demo = gr.Blocks(css="footer {visibility: hidden}")
        with demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False,
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")

            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to send: keep the textbox empty and the history unchanged.
                    return "", chat_history
                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                model.clear_history()
                return None, None

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                None,
            )

        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)