import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import spaces
import numpy as np
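
# LLaVA-style setup: CLIP image features are linearly projected into the input
# embedding space of the Phi causal LM "sagar007/Lava_phi" and prepended to the
# text embeddings before generation.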
# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.history = []
        self.model = None
        self.clip = None
        # Add a linear projection layer to align CLIP features with text embeddings
        self.projection = None
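
    # Heavy models are loaded lazily inside the @spaces.GPU-decorated methods below,
    # so the weights are only instantiated once a GPU-backed call actually runs.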
    @spaces.GPU
    def ensure_models_loaded(self):
        if self.model is None:
            from transformers import BitsAndBytesConfig
            # 8-bit weight quantization; the compute-dtype/double-quant options are
            # 4-bit-only settings, so only load_in_8bit is passed here.
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True
            )
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
            logging.info("Successfully loaded main model")
        if self.clip is None:
            self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
            logging.info("Successfully loaded CLIP model")
            # Initialize projection layer (CLIP features: 512-dim, model embedding size: e.g., 2048 for Phi)
            embed_dim = self.model.config.hidden_size  # e.g., 2048 for Phi-1.5
            clip_dim = self.clip.config.projection_dim  # 512 for CLIP
            self.projection = torch.nn.Linear(clip_dim, embed_dim).to(self.device)
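            # NOTE: this projection is created with freshly initialized (random) weights;
            # a trained LLaVA-style projector would normally load these weights from the
            # fine-tuned checkpoint alongside the language model.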

    @spaces.GPU
    def process_image(self, image):
        try:
            self.ensure_models_loaded()
            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            with torch.no_grad():
                image_inputs = self.processor(images=image, return_tensors="pt")
                image_features = self.clip.get_image_features(
                    pixel_values=image_inputs.pixel_values.to(self.device)
                )
                # Project image features to text embedding space
                projected_features = self.projection(image_features)
            logging.info("Successfully processed image through CLIP")
            return projected_features
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None

    @spaces.GPU(duration=120)
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()
            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None
                if not has_image:
                    message = "Note: Image processing is not available - continuing with text only.\n" + message
                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                if has_image:
                    # Convert input_ids to embeddings
                    embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
                    # Concatenate image features with text embeddings; cast the image
                    # features to the embedding dtype so torch.cat does not fail on
                    # mixed precision
                    image_features_expanded = image_features.unsqueeze(1).to(embeddings.dtype)  # [batch, 1, embed_dim]
                    combined_embeddings = torch.cat([image_features_expanded, embeddings], dim=1)
                    inputs["inputs_embeds"] = combined_embeddings
                    # Update attention mask to account for the extra image token
                    # (match the mask's dtype to avoid a dtype mismatch in torch.cat)
                    inputs["attention_mask"] = torch.cat(
                        [torch.ones(inputs["attention_mask"].shape[0], 1,
                                    dtype=inputs["attention_mask"].dtype, device=self.device),
                         inputs["attention_mask"]],
                        dim=1
                    )
                    # Remove input_ids since we're using inputs_embeds
                    del inputs["input_ids"]
            else:
                prompt = f"human: {message}\ngpt:"
                context = ""
                for turn in self.history[-5:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt
                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=1024
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    min_length=20,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.92,
                    top_k=50,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()
            self.history.append((message, response))
            return response
        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None

def create_demo():
    model = LLaVAPhiModel()
    # Rest of your Gradio setup remains the same
    # ... (omitted for brevity)
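    # Minimal placeholder sketch only: the original Space's Gradio layout is omitted
    # above, so the block below wires up a basic Blocks UI (standard gr.Blocks,
    # gr.Chatbot, gr.Image, gr.Textbox, gr.Button APIs) purely so that create_demo()
    # returns something launchable. Swap in the real UI definition here.
    with gr.Blocks() as demo:
        gr.Markdown("# LLaVA-Phi")
        chatbot = gr.Chatbot(height=400)
        with gr.Row():
            image_input = gr.Image(type="pil", label="Image (optional)")
            with gr.Column():
                text_input = gr.Textbox(label="Message")
                send_btn = gr.Button("Send")
                clear_btn = gr.Button("Clear history")

        def on_send(message, image, chat_history):
            reply = model.generate_response(message, image)
            return (chat_history or []) + [(message, reply)], ""

        def on_clear():
            model.clear_history()
            return []

        send_btn.click(on_send, inputs=[text_input, image_input, chatbot],
                       outputs=[chatbot, text_input])
        clear_btn.click(on_clear, inputs=None, outputs=[chatbot])
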
    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)