fix added

Files changed:
- SkinGPT.py  (+3, -149)
- app.py  (+64, -50)

SkinGPT.py  CHANGED
@@ -1,29 +1,21 @@
-import torch
 from torch import nn
 from torchvision import transforms
-from PIL import Image
 from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
 from eva_vit import create_eva_vit_g
 import requests
 from io import BytesIO
 import os
 from huggingface_hub import hf_hub_download
-from transformers import BitsAndBytesConfig
-from accelerate import init_empty_weights
 import torch
-from torch.cuda.amp import autocast
-import warnings
 MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 token = os.getenv("HF_TOKEN")
 import streamlit as st
-import torch.nn.functional as F
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 class Blip2QFormer(nn.Module):
     def __init__(self, num_query_tokens=32, vision_width=1408):
         super().__init__()
         self.num_query_tokens = num_query_tokens
-        # Load pre-trained Q-Former config
         self.bert_config = BertConfig(
             vocab_size=30522,
             hidden_size=768,
@@ -48,8 +40,6 @@ class Blip2QFormer(nn.Module):
             torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
         )
         self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
-
-        # Initialize weights
         self._init_weights()
 
     def _init_weights(self):
@@ -67,16 +57,7 @@ class Blip2QFormer(nn.Module):
         msg = self.load_state_dict(state_dict, strict=False)
 
     def forward(self, visual_features):
-
-        print(
-            f"Visual features stats - min: {visual_features.min().item():.4f}, max: {visual_features.max().item():.4f}")
-
-        # Project visual features
         visual_embeds = self.vision_proj(visual_features.float())
-        print(f"Projected embeds stats - min: {visual_embeds.min().item():.4f}, max: {visual_embeds.max().item():.4f}")
-        # visual_embeds = self.vision_proj(visual_features.float())
-
-        # Expand query tokens
         query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
         combined_input = torch.cat([query_tokens, visual_embeds], dim=1)
         attention_mask = torch.ones(
@@ -84,14 +65,11 @@ class Blip2QFormer(nn.Module):
             dtype=torch.long,
             device=combined_input.device
         )
-
-        # Forward through BERT
         outputs = self.bert(
             attention_mask=attention_mask,
             inputs_embeds=combined_input,
             return_dict=True
         )
-
         return outputs.last_hidden_state[:, :self.num_query_tokens]
 
 
@@ -100,17 +78,14 @@ class SkinGPT4(nn.Module):
     def __init__(self, vit_checkpoint_path,
                  q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
         super().__init__()
-        # Image encoder parameters from paper
         self.device = device
-        # self.dtype = torch.float16
         self.dtype = MODEL_DTYPE
         self.H, self.W, self.C = 224, 224, 3
-        self.P = 14
-        self.D = 1408
+        self.P = 14
+        self.D = 1408
         self.num_query_tokens = 32
 
         self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
-        print("Loaded ViT")
         self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
 
         self.q_former = Blip2QFormer(
@@ -120,27 +95,17 @@ class SkinGPT4(nn.Module):
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
-
-        print("Loaded QFormer")
-
-
         self.llama = self._init_llama()
-
         self.llama_proj = nn.Linear(
             self.q_former.bert_config.hidden_size,
             self.llama.config.hidden_size
         ).to(self.dtype)
-
-        print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
-        print(f"LLaMA input dim: {self.llama.config.hidden_size}")
-
         for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
             for param in module.parameters():
                 param.requires_grad = False
             module.eval()
 
     def _init_vit(self, vit_checkpoint_path):
-        """Initialize EVA-ViT-G with paper specifications"""
         vit = create_eva_vit_g(
             img_size=(self.H, self.W),
             patch_size=self.P,
@@ -156,24 +121,17 @@ class SkinGPT4(nn.Module):
         if not hasattr(vit, 'norm'):
             vit.norm = nn.LayerNorm(self.D)
         checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
-        # 3. Filter weights for ViT components only
         vit_weights = {k.replace("vit.", ""): v
                        for k, v in checkpoint.items()
                        if k.startswith("vit.")}
-
-        # 4. Load weights while ignoring classifier head
         vit.load_state_dict(vit_weights, strict=False)
-
-
         return vit.eval()
 
     def _init_llama(self):
-        """Initialize frozen LLaMA-2-13b-chat with proper error handling"""
        try:
             device_map = {
                 "": 0 if torch.cuda.is_available() else "cpu"
             }
-            # First try loading with device_map="auto"
             model = LlamaForCausalLM.from_pretrained(
                 "meta-llama/Llama-2-13b-chat-hf",
                 token=token,
@@ -181,9 +139,7 @@ class SkinGPT4(nn.Module):
                 device_map=device_map,
                 low_cpu_mem_usage=True
             )
-
             return model.eval()
-
         except Exception as e:
             raise ImportError(
                 f"Failed to load LLaMA model. Please ensure:\n"
@@ -194,143 +150,61 @@ class SkinGPT4(nn.Module):
             )
 
     def encode_image(self, x):
-        """Convert image to patch embeddings following Eq. (1)"""
-        # x: (B, C, H, W)
         x = x.to(self.dtype)
         if x.dim() == 3:
-            x = x.unsqueeze(0)
+            x = x.unsqueeze(0)
         if x.dim() != 4:
             raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
-
-        B, C, H, W = x.shape
-        N = (H * W) // (self.P ** 2)
-
         x = self.vit.patch_embed(x)
-
         num_patches = x.shape[1]
         pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
         x = x + pos_embed
-
-        # Add class token
         class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
         x = torch.cat([class_token, x], dim=1)
         for blk in self.vit.blocks:
             x = blk(x)
         x = self.vit.norm(x)
         vit_features = self.ln_vision(x)
-        print(f"vit features (first 5): {vit_features[0, 0, :5]}")
-
-        # Q-Former forward pass
         with torch.no_grad():
             qformer_output = self.q_former(vit_features.float())
-            print(f"Q-Former output (first 5): {qformer_output[0, 0, :5]}")
         image_embeds = self.llama_proj(qformer_output.to(self.dtype))
-
-
         return image_embeds
 
     def generate(self, images, user_input=None, max_new_tokens=300):
-
         image_embeds = self.encode_image(images)
-
-        print(f"Aligned features : {image_embeds}")
-        # print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
-
-        print(
-            f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")
-
         if image_embeds.shape[-1] != self.llama.config.hidden_size:
             raise ValueError(
                 f"Feature dimension mismatch. "
                 f"Q-Former output: {image_embeds.shape[-1]}, "
                 f"LLaMA expected: {self.llama.config.hidden_size}"
             )
-
-
-        # prompt = """"### Instruction: <Img><IMAGE></Img>
-        # Could you describe the skin condition in this image?
-        # Please provide a detailed analysis including possible diagnoses.
-        # ### Response:
-        # """
-
-        # prompt = """### Skin Diagnosis Analysis ###
-        # <IMAGE>
-        # Could you describe the skin condition in this image?
-        # Please provide a detailed analysis including possible diagnoses.
-        # ### Response:"""
-
         prompt = """### Instruction:
         <IMAGE>
         Could you describe the skin condition in this image?
         ### Response:"""
-
-
-        # print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
-
         self.tokenizer = LlamaTokenizer.from_pretrained(
             "meta-llama/Llama-2-13b-chat-hf",
             token=token,
             padding_side="right"
         )
-        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
         num_added = self.tokenizer.add_special_tokens({
             'additional_special_tokens': ['<IMAGE>']
         })
-        # num_added = self.tokenizer.add_special_tokens({
-        #     'additional_special_tokens': ['<Img>', '</Img>', '<IMAGE>']
-        # })
-
         if num_added == 0:
             raise ValueError("Failed to add <IMAGE> token!")
-
         self.llama.resize_token_embeddings(len(self.tokenizer))
-
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
-
-        # print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
-        # print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
-
-        # Prepare embeddings
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
         visual_embeds = image_embeds.mean(dim=1)
-
-        # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
         image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
         replace_positions = (inputs.input_ids == image_token_id).nonzero()
-
         if len(replace_positions) == 0:
             raise ValueError("No <IMAGE> tokens found in prompt!")
-
         if len(replace_positions[0]) == 0:
             raise ValueError("Image token not found in prompt")
-
-        # print(f"\n[DEBUG] Image token found at position: {replace_positions}")
-
-
-        print(f"\n[DEBUG] Before replacement:")
-        # print(f"Text embeddings shape: {input_embeddings.shape}")
-        # print(f"Visual embeddings shape: {visual_embeds.shape}")
-        # print(f"Image token at {replace_positions[0][1].item()}:")
-        print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
-
         for pos in replace_positions:
             input_embeddings[0, pos[1]] = visual_embeds[0]
 
-        print(f"\n[DEBUG] After replacement:")
-        print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_new_tokens,
-        #     temperature=0.7,
-        #     top_k=40,
-        #     top_p=0.9,
-        #     repetition_penalty=1.1,
-        #     do_sample=True,
-        #     pad_token_id = self.tokenizer.eos_token_id,
-        #     eos_token_id = self.tokenizer.eos_token_id
-        # )
-
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
             max_new_tokens=max_new_tokens,
@@ -340,27 +214,16 @@ class SkinGPT4(nn.Module):
             repetition_penalty=1.1,
             do_sample=True,
         )
-
-
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        print(f"Full Output from llama : {full_output}")
         response = full_output.split("### Response:")[-1].strip()
-        # print(f"Response from llama : {full_output}")
-
         return response
 
 
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
         self.device = torch.device(device)
-        self.conversation_history = []
-
         with st.spinner("Loading AI models (this may take several minutes)..."):
             self.model = self._load_model()
-            # print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
-            # print(f"Projection layer: {self.model.llama_proj}")
-
-        # Image transformations
         self.transform = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor(),
@@ -378,18 +241,9 @@ class SkinGPTClassifier:
 
     def predict(self, image):
         image = image.convert('RGB')
-        print(f"Original image mode: {image.mode}, size: {image.size}")
-
         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
-
-        print(f"Tensor shape: {image_tensor.shape}")
-        print(f"Tensor min/max: {image_tensor.min().item():.4f}/{image_tensor.max().item():.4f}")
-        print(f"Tensor mean: {image_tensor.mean().item():.4f}")
-
         with torch.no_grad():
             diagnosis = self.model.generate(image_tensor)
-
         return {
             "diagnosis": diagnosis,
-            "visual_features": None # Can return features if needed
         }
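For reference, the slimmed-down SkinGPT.py now exposes a single inference path: SkinGPTClassifier loads the frozen ViT, Q-Former, and LLaMA stack once, and predict() returns a dict whose "diagnosis" key holds the generated text. A minimal usage sketch, assuming the checkpoints resolved inside _load_model are available and HF_TOKEN is set for the gated LLaMA weights (the file name lesion.jpg is only illustrative):

from PIL import Image
from SkinGPT import SkinGPTClassifier

# Build the classifier once; this loads the frozen ViT + Q-Former + LLaMA stack.
classifier = SkinGPTClassifier()

# Run a single image through the pipeline and read back the generated description.
image = Image.open("lesion.jpg")  # hypothetical input path
result = classifier.predict(image)
print(result["diagnosis"])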
app.py  CHANGED

@@ -1,32 +1,18 @@
 import torch
 import random
 import numpy as np
-
 torch.manual_seed(42)
 random.seed(42)
 np.random.seed(42)
-
-
 import streamlit as st
 import io
-from fpdf import FPDF
-import nest_asyncio
-nest_asyncio.apply()
-device='cuda' if torch.cuda.is_available() else 'cpu'
-
-
-st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
-
 from PIL import Image
 import os
 from transformers import logging
-
-import torch
 from SkinGPT import SkinGPTClassifier
-
-
-
-
+from fpdf import FPDF
+import nest_asyncio
+nest_asyncio.apply()
 torch.set_default_dtype(torch.float32) # Main computations in float32
 MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 import warnings
@@ -41,7 +27,11 @@ import warnings
 warnings.filterwarnings("ignore")
 
 
-
+device='cuda' if torch.cuda.is_available() else 'cpu'
+st.set_page_config(page_title="SkinGPT", page_icon="🧬", layout="centered")
+
+
+@st.cache_resource(show_spinner=False)
 def get_classifier():
     classifier = SkinGPTClassifier()
     for module in [classifier.model.vit,
@@ -53,12 +43,17 @@ def get_classifier():
 
     return classifier
 
-
+if 'app_models' not in st.session_state:
+    st.session_state.app_models = get_classifier()
+
+classifier = st.session_state.app_models
 
 # === Session Init ===
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if "current_image" not in st.session_state:
+    st.session_state.current_image = None
 
 # === PDF Export ===
 def export_chat_to_pdf(messages):
@@ -77,37 +72,56 @@ def export_chat_to_pdf(messages):
 
 st.title("🧬 DermBOT – Skin AI Assistant")
 st.caption(f"🧠 Using model: SkinGPT")
-uploaded_file = st.file_uploader(
-
-
-
-
+uploaded_file = st.file_uploader(
+    "Upload a skin image",
+    type=["jpg", "jpeg", "png"],
+    key="file_uploader"
+)
+
+if uploaded_file is not None and uploaded_file != st.session_state.current_image:
+    st.session_state.messages = []
+    st.session_state.current_image = uploaded_file
+
     image = Image.open(uploaded_file).convert("RGB")
-
-
-
-
-
+    st.image(image, caption="Uploaded image", use_column_width=True)
+    with st.spinner("Analyzing the image..."):
+        result = classifier.predict(image)
+
+    st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
+
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# === Chat Interface ===
+if prompt := st.chat_input("Ask a follow-up question..."):
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    with st.chat_message("assistant"):
+        with st.spinner("Thinking..."):
+            if len(st.session_state.messages) > 1:
+                conversation_context = "\n".join(
+                    f"{m['role']}: {m['content']}"
+                    for m in st.session_state.messages[:-1] # Exclude current prompt
+                )
+                augmented_prompt = (
+                    f"Conversation history:\n{conversation_context}\n\n"
+                    f"Current question: {prompt}"
+                )
+                result = classifier.predict(image)
             else:
-
-
-
-
-
-
-    st.session_state.conversation.append(("user", user_query))
-    with st.chat_message("user"):
-        st.markdown(user_query)
-
-    # Generate response with context
-    context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
-    response = classifier.generate(image, user_input=context)
-
-    st.session_state.conversation.append(("assistant", response))
-    with st.chat_message("assistant"):
-        st.markdown(response)
-
-# === PDF Button ===
-if st.button("📄 Download Chat as PDF"):
+                result = classifier.predict(image)
+
+            st.markdown(result["diagnosis"])
+            st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
+
+if st.session_state.messages and st.button("📄 Download Chat as PDF"):
     pdf_file = export_chat_to_pdf(st.session_state.messages)
-    st.download_button(
+    st.download_button(
+        "Download PDF",
+        data=pdf_file,
+        file_name="skingpt_chat_history.pdf",
+        mime="application/pdf"
+    )
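One loose end visible in the new chat branch: augmented_prompt is assembled from the message history, but predict(image) is called without it, while SkinGPT4.generate already accepts a user_input argument (currently unused inside its body). A sketch of how the context could be threaded through, as a hypothetical follow-up rather than anything in this commit:

import torch
from SkinGPT import SkinGPTClassifier

class ContextAwareClassifier(SkinGPTClassifier):
    # Hypothetical subclass: forward the chat context to the underlying model call.
    def predict(self, image, user_input=None):
        image = image.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # generate() accepts user_input today but does not yet fold it into the
            # prompt; doing so would be the matching change on the SkinGPT.py side.
            diagnosis = self.model.generate(image_tensor, user_input=user_input)
        return {"diagnosis": diagnosis}

# app.py could then call, for example:
# result = classifier.predict(image, user_input=augmented_prompt)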