import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np

# Setup logging
logging.basicConfig(level=logging.INFO)

class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []
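        # Heavy weights are loaded lazily inside the @spaces.GPU-decorated methods below,
        # since with the `spaces` ZeroGPU decorator CUDA is only available during those calls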
        self.model = None
        self.clip = None
        
        # Linear projection (created lazily in ensure_models_loaded) to map CLIP image
        # features into the language model's text embedding space
        self.projection = None

    @spaces.GPU
    def ensure_models_loaded(self):
        if self.model is None:
            from transformers import BitsAndBytesConfig
            # 8-bit quantization; the compute-dtype / double-quant options only exist for
            # 4-bit loading (bnb_4bit_*), so plain load_in_8bit is all that is needed here
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model")

        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            
            # Initialize projection layer (CLIP features: 512-dim, model embedding size: e.g., 2048 for Phi)
            embed_dim = self.model.config.hidden_size  # e.g., 2048 for Phi-1.5
            clip_dim = self.clip.config.projection_dim  # 512 for CLIP
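            # Note: this layer is freshly (randomly) initialized here; for the image features
            # to carry real signal it would need trained weights from the checkpoint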
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)

    @spaces.GPU
    def process_image(self, image):
        try:
            self.ensure_models_loaded()
            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None
            
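            # Normalize whatever Gradio hands us (file path or numpy array) to an RGB PIL image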
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            
            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
                # Project image features to text embedding space
                projected_features = self.projection(image_features)
                logging.info("Successfully processed image through CLIP")
                return projected_features
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()
            
            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None
                if not has_image:
                    message = "Note: Image processing is not available - continuing with text only.\n" + message
                
                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                
                inputs = self.tokenizer(
                    full_prompt, 
                    return_tensors="pt", 
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                if has_image:
                    # Convert input_ids to embeddings
                    embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
                    # Concatenate image features with text embeddings
                    # Cast to the text embeddings' dtype (the fp32 projection output would
                    # otherwise clash with the bf16 embeddings downstream)
                    image_features_expanded = image_features.unsqueeze(1).to(embeddings.dtype)  # Shape: [batch, 1, embed_dim]
                    combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
                    inputs["inputs_embeds"] = combined_embeddings
                    # Update attention mask to account for the extra image token,
                    # keeping the mask's original integer dtype
                    inputs["attention_mask"] = torch.cat(
                        [torch.ones(
                            (inputs["attention_mask"].shape[0], 1),
                            dtype=inputs["attention_mask"].dtype,
                            device=self.device,
                        ),
                         inputs["attention_mask"]],
                        dim=1
                    )
                    # Remove input_ids since we're using inputs_embeds
                    del inputs["input_ids"]
            else:
                prompt = f"human: {message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                
                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
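            # Low temperature plus repetition penalty keeps answers focused and reduces looping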
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    min_length=20,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()
            
            self.history.append((message, response))
            return response
            
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

def create_demo():
    model = LLaVAPhiModel()
    # The original Gradio layout was omitted ("Rest of your Gradio setup remains the same");
    # the interface below is a minimal placeholder so the script runs end to end.
    demo = gr.Interface(
        fn=model.generate_response,
        inputs=[gr.Textbox(label="Message"), gr.Image(label="Image (optional)")],
        outputs=gr.Textbox(label="Response"),
    )
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)