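# app.py (assumed filename) for the "sagar007/Lava_phi" Hugging Face Space:
# a Gradio chat demo pairing a Phi-based causal LM with a CLIP vision encoder.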
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []
        # Heavy models are loaded lazily inside a @spaces.GPU context (ZeroGPU pattern).
        self.model = None
        self.clip = None
        self.projection = None
    @spaces.GPU
    def ensure_models_loaded(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This model requires a GPU.")
        if self.model is None:
            from transformers import BitsAndBytesConfig
            # 8-bit quantization. BitsAndBytesConfig has no bnb_8bit_compute_dtype /
            # bnb_8bit_use_double_quant options (those exist only for 4-bit), so the
            # config is reduced to the supported flag.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,  # fp16 for non-quantized modules, matching LLM.int8 kernels
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model on GPU")
        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            # Linear layer mapping CLIP image embeddings into the LM's hidden space.
            embed_dim = self.model.config.hidden_size
            clip_dim = self.clip.config.projection_dim
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
    # The original upload omits the remaining methods (process_image,
    # generate_response, etc.); illustrative sketches follow.
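    # --- Sketches of the omitted methods (assumptions, not the original code) ---
    # Minimal bodies so the demo runs end to end; the upstream implementations
    # may differ, and the image-feature fusion step in particular is model-specific.

    def clear_history(self):
        """Reset the running conversation state."""
        self.history = []

    def update_generation_params(self, temperature, top_p, top_k, repetition_penalty):
        """Store sampling parameters consumed by generate_response (assumed API)."""
        self.generation_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }

    @spaces.GPU
    def process_image(self, image):
        """Encode a PIL image with CLIP and project it into the LM hidden space."""
        self.ensure_models_loaded()
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            clip_features = self.clip.get_image_features(**inputs)  # (1, clip_dim)
        return self.projection(clip_features)                       # (1, hidden_size)

    @spaces.GPU
    def generate_response(self, message, image=None):
        """Generate a reply; a text-only sketch that encodes the image but stands
        in a placeholder token for the model-specific prompt/feature fusion."""
        self.ensure_models_loaded()
        # Defaults mirror the UI slider defaults below.
        params = getattr(self, "generation_params", {
            "temperature": 0.3, "top_p": 0.92, "top_k": 50, "repetition_penalty": 1.2,
        })
        if image is not None:
            _ = self.process_image(image)  # projected image features (unused in this sketch)
            prompt = f"human: <image>\n{message}\ngpt:"
        else:
            prompt = f"human: {message}\ngpt:"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                **params,
            )
        response = self.tokenizer.decode(
            output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        self.history.append((message, response))
        return response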
def create_demo():
    try:
        model = LLaVAPhiModel()
        demo = gr.Blocks(css="footer {visibility: hidden}")
        with demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )
            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                # Integer scales (14/3/3 ≈ the original 0.7/0.15/0.15 split);
                # Gradio 4+ rejects fractional scale values.
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")
            image = gr.Image(type="pil", label="Upload Image (Optional)")
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")
            def respond(message, chat_history, image):
                if not message and image is None:
                    # Nothing to process; return two values to match the
                    # [msg, chatbot] outputs wired up below.
                    return "", chat_history
                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                model.clear_history()
                return None, None

            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                None
            )
        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
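# Likely runtime dependencies (assumed, not pinned by this file):
#   pip install torch transformers accelerate bitsandbytes gradio spaces pillow
# On Hugging Face Spaces the `spaces` package provides the @spaces.GPU decorator
# (ZeroGPU); share=True is ignored there and only matters for local runs.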