import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np

logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
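        # Some causal LM tokenizers ship without a pad token; fall back to EOS
        # so padding and batched generation work.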
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []
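        # Heavy weights are loaded lazily in ensure_models_loaded(): on ZeroGPU
        # Spaces the GPU is typically only attached inside @spaces.GPU-decorated
        # calls, so __init__ only sets up the tokenizer and CLIP processor.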
        self.model = None
        self.clip = None
        self.projection = None

    @spaces.GPU
    def ensure_models_loaded(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. This model requires a GPU.")
        if self.model is None:
            from transformers import BitsAndBytesConfig
            # 8-bit loading; compute-dtype and double-quantization options only
            # apply to 4-bit quantization, so they are not passed here.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model on GPU")

        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            embed_dim = self.model.config.hidden_size
            clip_dim = self.clip.config.projection_dim
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)

    # Remaining methods (process_image, generate_response, clear_history,
    # update_generation_params) are unchanged and omitted here for brevity.
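    # Expected interface for the omitted methods, inferred from the demo below
    # (a sketch of the signatures only, not the actual implementations):
    #
    #   def generate_response(self, message, image=None) -> str: ...
    #   def clear_history(self) -> None: ...
    #   def update_generation_params(self, temperature, top_p, top_k,
    #                                repetition_penalty): ...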

def create_demo():
    try:
        model = LLaVAPhiModel()
        
        demo = gr.Blocks(css="footer {visibility: hidden}")
        with demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )
            
            chatbot = gr.Chatbot(height=400)
            with gr.Row():
                with gr.Column(scale=0.7):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.15, min_width=0):
                    submit = gr.Button("Submit", variant="primary")
            
            image = gr.Image(type="pil", label="Upload Image (Optional)")
            
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Adjust these parameters to control hallucination tendency")
                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
                update_params = gr.Button("Update Parameters")
            
            def respond(message, chat_history, image):
                # The handler's outputs are [msg, chatbot], so always return two values.
                if not message and image is None:
                    return "", chat_history
                
                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history
            
            def clear_chat():
                model.clear_history()
                return None, None
            
            def update_params_fn(temp, top_p, top_k, rep_penalty):
                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
            
            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            
            update_params.click(
                update_params_fn,
                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
                None
            )
        
        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
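# Run locally with `python app.py`; the UI is served on port 7860 and, with
# share=True, Gradio also prints a temporary public link.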