KeerthiVM committed
Commit 2e04d58 · 1 Parent(s): 2b8678a
Files changed (2)
  1. SkinGPT.py +3 -149
  2. app.py +64 -50
SkinGPT.py CHANGED
@@ -1,29 +1,21 @@
-import torch
 from torch import nn
 from torchvision import transforms
-from PIL import Image
 from transformers import LlamaForCausalLM, LlamaTokenizer, BertModel, BertConfig
 from eva_vit import create_eva_vit_g
 import requests
 from io import BytesIO
 import os
 from huggingface_hub import hf_hub_download
-from transformers import BitsAndBytesConfig
-from accelerate import init_empty_weights
 import torch
-from torch.cuda.amp import autocast
-import warnings
 MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 token = os.getenv("HF_TOKEN")
 import streamlit as st
-import torch.nn.functional as F
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 class Blip2QFormer(nn.Module):
     def __init__(self, num_query_tokens=32, vision_width=1408):
         super().__init__()
         self.num_query_tokens = num_query_tokens
-        # Load pre-trained Q-Former config
         self.bert_config = BertConfig(
             vocab_size=30522,
             hidden_size=768,
@@ -48,8 +40,6 @@ class Blip2QFormer(nn.Module):
             torch.zeros(1, num_query_tokens, self.bert_config.hidden_size)
         )
         self.vision_proj = nn.Linear(vision_width, self.bert_config.hidden_size)
-
-        # Initialize weights
         self._init_weights()
 
     def _init_weights(self):
@@ -67,16 +57,7 @@ class Blip2QFormer(nn.Module):
         msg = self.load_state_dict(state_dict, strict=False)
 
     def forward(self, visual_features):
-
-        print(
-            f"Visual features stats - min: {visual_features.min().item():.4f}, max: {visual_features.max().item():.4f}")
-
-        # Project visual features
         visual_embeds = self.vision_proj(visual_features.float())
-        print(f"Projected embeds stats - min: {visual_embeds.min().item():.4f}, max: {visual_embeds.max().item():.4f}")
-        # visual_embeds = self.vision_proj(visual_features.float())
-
-        # Expand query tokens
         query_tokens = self.query_tokens.expand(visual_embeds.shape[0], -1, -1)
         combined_input = torch.cat([query_tokens, visual_embeds], dim=1)
         attention_mask = torch.ones(
@@ -84,14 +65,11 @@ class Blip2QFormer(nn.Module):
             dtype=torch.long,
             device=combined_input.device
         )
-
-        # Forward through BERT
         outputs = self.bert(
             attention_mask=attention_mask,
             inputs_embeds=combined_input,
             return_dict=True
         )
-
         return outputs.last_hidden_state[:, :self.num_query_tokens]
 
 
@@ -100,17 +78,14 @@ class SkinGPT4(nn.Module):
     def __init__(self, vit_checkpoint_path,
                  q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
         super().__init__()
-        # Image encoder parameters from paper
         self.device = device
-        # self.dtype = torch.float16
         self.dtype = MODEL_DTYPE
         self.H, self.W, self.C = 224, 224, 3
-        self.P = 14  # Patch size
-        self.D = 1408  # ViT embedding dimension
+        self.P = 14
+        self.D = 1408
         self.num_query_tokens = 32
 
         self.vit = self._init_vit(vit_checkpoint_path).to(self.dtype)
-        print("Loaded ViT")
         self.ln_vision = nn.LayerNorm(self.D).to(self.dtype)
 
         self.q_former = Blip2QFormer(
@@ -120,27 +95,17 @@ class SkinGPT4(nn.Module):
         self.q_former.load_from_pretrained(q_former_model)
         for param in self.q_former.parameters():
             param.requires_grad = False
-
-        print("Loaded QFormer")
-
-
         self.llama = self._init_llama()
-
         self.llama_proj = nn.Linear(
             self.q_former.bert_config.hidden_size,
             self.llama.config.hidden_size
         ).to(self.dtype)
-
-        print(f"Q-Former output dim: {self.q_former.bert_config.hidden_size}")
-        print(f"LLaMA input dim: {self.llama.config.hidden_size}")
-
         for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
             for param in module.parameters():
                 param.requires_grad = False
             module.eval()
 
     def _init_vit(self, vit_checkpoint_path):
-        """Initialize EVA-ViT-G with paper specifications"""
         vit = create_eva_vit_g(
             img_size=(self.H, self.W),
             patch_size=self.P,
@@ -156,24 +121,17 @@ class SkinGPT4(nn.Module):
         if not hasattr(vit, 'norm'):
             vit.norm = nn.LayerNorm(self.D)
         checkpoint = torch.load(vit_checkpoint_path, map_location='cpu')
-        # 3. Filter weights for ViT components only
         vit_weights = {k.replace("vit.", ""): v
                        for k, v in checkpoint.items()
                        if k.startswith("vit.")}
-
-        # 4. Load weights while ignoring classifier head
         vit.load_state_dict(vit_weights, strict=False)
-
-
         return vit.eval()
 
     def _init_llama(self):
-        """Initialize frozen LLaMA-2-13b-chat with proper error handling"""
         try:
             device_map = {
                 "": 0 if torch.cuda.is_available() else "cpu"
             }
-            # First try loading with device_map="auto"
             model = LlamaForCausalLM.from_pretrained(
                 "meta-llama/Llama-2-13b-chat-hf",
                 token=token,
@@ -181,9 +139,7 @@ class SkinGPT4(nn.Module):
                 device_map=device_map,
                 low_cpu_mem_usage=True
             )
-
             return model.eval()
-
         except Exception as e:
             raise ImportError(
                 f"Failed to load LLaMA model. Please ensure:\n"
@@ -194,143 +150,61 @@ class SkinGPT4(nn.Module):
             )
 
     def encode_image(self, x):
-        """Convert image to patch embeddings following Eq. (1)"""
-        # x: (B, C, H, W)
         x = x.to(self.dtype)
         if x.dim() == 3:
-            x = x.unsqueeze(0)  # Add batch dimension if missing
+            x = x.unsqueeze(0)
         if x.dim() != 4:
             raise ValueError(f"Input must be 4D tensor (got {x.dim()}D)")
-
-        B, C, H, W = x.shape
-        N = (H * W) // (self.P ** 2)
-
         x = self.vit.patch_embed(x)
-
         num_patches = x.shape[1]
         pos_embed = self.vit.pos_embed[:, 1:num_patches + 1, :]
         x = x + pos_embed
-
-        # Add class token
         class_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
         x = torch.cat([class_token, x], dim=1)
         for blk in self.vit.blocks:
             x = blk(x)
         x = self.vit.norm(x)
         vit_features = self.ln_vision(x)
-        print(f"vit features (first 5): {vit_features[0, 0, :5]}")
-
-        # Q-Former forward pass
         with torch.no_grad():
             qformer_output = self.q_former(vit_features.float())
-        print(f"Q-Former output (first 5): {qformer_output[0, 0, :5]}")
         image_embeds = self.llama_proj(qformer_output.to(self.dtype))
-
-
         return image_embeds
 
     def generate(self, images, user_input=None, max_new_tokens=300):
-
         image_embeds = self.encode_image(images)
-
-        print(f"Aligned features : {image_embeds}")
-        # print(f"\n Images embeddings shape : {image_embeds.shape} \n Llama config hidden size : {self.llama.config.hidden_size}")
-
-        print(
-            f"\n[VALIDATION] Visual embeds - Mean: {image_embeds.mean().item():.4f}, Std: {image_embeds.std().item():.4f}")
-
         if image_embeds.shape[-1] != self.llama.config.hidden_size:
             raise ValueError(
                 f"Feature dimension mismatch. "
                 f"Q-Former output: {image_embeds.shape[-1]}, "
                 f"LLaMA expected: {self.llama.config.hidden_size}"
             )
-
-
-        # prompt = """"### Instruction: <Img><IMAGE></Img>
-        # Could you describe the skin condition in this image?
-        # Please provide a detailed analysis including possible diagnoses.
-        # ### Response:
-        # """
-
-        # prompt = """### Skin Diagnosis Analysis ###
-        # <IMAGE>
-        # Could you describe the skin condition in this image?
-        # Please provide a detailed analysis including possible diagnoses.
-        # ### Response:"""
-
         prompt = """### Instruction:
         <IMAGE>
         Could you describe the skin condition in this image?
         ### Response:"""
-
-
-        # print(f"\n[DEBUG] Raw Prompt:\n{prompt}")
-
         self.tokenizer = LlamaTokenizer.from_pretrained(
             "meta-llama/Llama-2-13b-chat-hf",
             token=token,
             padding_side="right"
         )
-        # self.tokenizer.add_special_tokens({'additional_special_tokens': ['<Img>', '</Img>', '<ImageHere>']})
         num_added = self.tokenizer.add_special_tokens({
             'additional_special_tokens': ['<IMAGE>']
         })
-        # num_added = self.tokenizer.add_special_tokens({
-        #     'additional_special_tokens': ['<Img>', '</Img>', '<IMAGE>']
-        # })
-
         if num_added == 0:
             raise ValueError("Failed to add <IMAGE> token!")
-
         self.llama.resize_token_embeddings(len(self.tokenizer))
-
         inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
-
-        # print(f"\n[DEBUG] Tokenized input IDs:\n{inputs.input_ids}")
-        # print(f"[DEBUG] Special token positions: {self.tokenizer.all_special_tokens}")
-
-        # Prepare embeddings
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
         visual_embeds = image_embeds.mean(dim=1)
-
-        # image_token_id = self.tokenizer.convert_tokens_to_ids("<ImageHere>")
        image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
         replace_positions = (inputs.input_ids == image_token_id).nonzero()
-
         if len(replace_positions) == 0:
             raise ValueError("No <IMAGE> tokens found in prompt!")
-
         if len(replace_positions[0]) == 0:
             raise ValueError("Image token not found in prompt")
-
-        # print(f"\n[DEBUG] Image token found at position: {replace_positions}")
-
-
-        print(f"\n[DEBUG] Before replacement:")
-        # print(f"Text embeddings shape: {input_embeddings.shape}")
-        # print(f"Visual embeddings shape: {visual_embeds.shape}")
-        # print(f"Image token at {replace_positions[0][1].item()}:")
-        print(f"Image token embedding (before):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
-
         for pos in replace_positions:
             input_embeddings[0, pos[1]] = visual_embeds[0]
 
-        print(f"\n[DEBUG] After replacement:")
-        print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
-
-        # outputs = self.llama.generate(
-        #     inputs_embeds=input_embeddings,
-        #     max_new_tokens=max_new_tokens,
-        #     temperature=0.7,
-        #     top_k=40,
-        #     top_p=0.9,
-        #     repetition_penalty=1.1,
-        #     do_sample=True,
-        #     pad_token_id = self.tokenizer.eos_token_id,
-        #     eos_token_id = self.tokenizer.eos_token_id
-        # )
-
         outputs = self.llama.generate(
             inputs_embeds=input_embeddings,
             max_new_tokens=max_new_tokens,
@@ -340,27 +214,16 @@ class SkinGPT4(nn.Module):
             repetition_penalty=1.1,
             do_sample=True,
         )
-
-
         full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        print(f"Full Output from llama : {full_output}")
         response = full_output.split("### Response:")[-1].strip()
-        # print(f"Response from llama : {full_output}")
-
         return response
 
 
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
         self.device = torch.device(device)
-        self.conversation_history = []
-
         with st.spinner("Loading AI models (this may take several minutes)..."):
             self.model = self._load_model()
-        # print(f"Q-Former output shape: {self.model.q_former(torch.randn(1, 197, 1408)).shape}")
-        # print(f"Projection layer: {self.model.llama_proj}")
-
-        # Image transformations
         self.transform = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor(),
@@ -378,18 +241,9 @@ class SkinGPTClassifier:
 
     def predict(self, image):
         image = image.convert('RGB')
-        print(f"Original image mode: {image.mode}, size: {image.size}")
-
         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
-
-        print(f"Tensor shape: {image_tensor.shape}")
-        print(f"Tensor min/max: {image_tensor.min().item():.4f}/{image_tensor.max().item():.4f}")
-        print(f"Tensor mean: {image_tensor.mean().item():.4f}")
-
         with torch.no_grad():
             diagnosis = self.model.generate(image_tensor)
-
         return {
             "diagnosis": diagnosis,
-            "visual_features": None  # Can return features if needed
         }
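For readers skimming the diff: the part of SkinGPT.py that this commit leaves in place hinges on one mechanism, splicing the projected visual embedding into the LLaMA input sequence at the position of a sentinel <IMAGE> token and then generating from inputs_embeds. Below is a minimal, self-contained sketch of that splice with toy dimensions; the embedding table, projection layer, and token id are illustrative stand-ins, not values from the commit.

import torch
from torch import nn

# Toy stand-ins (illustrative sizes, not the commit's): a small LM embedding
# table, a visual-to-LM projection, and a dedicated sentinel token id.
vocab_size, lm_dim, vision_dim = 100, 32, 64
IMAGE_TOKEN_ID = 99
embed_tokens = nn.Embedding(vocab_size, lm_dim)
visual_proj = nn.Linear(vision_dim, lm_dim)

# A fake "tokenized prompt" containing the sentinel, plus one pooled visual feature
# (stands in for image_embeds.mean(dim=1) in SkinGPT4.generate).
input_ids = torch.tensor([[1, 5, IMAGE_TOKEN_ID, 7, 2]])
pooled_visual = torch.randn(1, vision_dim)

with torch.no_grad():
    # 1) Embed the text tokens exactly as the LM would.
    inputs_embeds = embed_tokens(input_ids)              # (1, seq_len, lm_dim)
    # 2) Overwrite every sentinel position with the projected visual vector.
    visual_embed = visual_proj(pooled_visual)            # (1, lm_dim)
    for batch_idx, seq_idx in (input_ids == IMAGE_TOKEN_ID).nonzero():
        inputs_embeds[batch_idx, seq_idx] = visual_embed[batch_idx]

# 3) A real model would now call llama.generate(inputs_embeds=inputs_embeds, ...).
print(inputs_embeds.shape)  # torch.Size([1, 5, 32])

Because the replacement happens at the embedding level rather than the token level, the projection must map exactly to llama.config.hidden_size, which is what the dimension check at the top of generate enforces.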
app.py CHANGED
@@ -1,32 +1,18 @@
 import torch
 import random
 import numpy as np
-
 torch.manual_seed(42)
 random.seed(42)
 np.random.seed(42)
-
-
 import streamlit as st
 import io
-from fpdf import FPDF
-import nest_asyncio
-nest_asyncio.apply()
-device='cuda' if torch.cuda.is_available() else 'cpu'
-
-
-st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
-
 from PIL import Image
 import os
 from transformers import logging
-
-import torch
 from SkinGPT import SkinGPTClassifier
-
-
-
-
+from fpdf import FPDF
+import nest_asyncio
+nest_asyncio.apply()
 torch.set_default_dtype(torch.float32)  # Main computations in float32
 MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 import warnings
@@ -41,7 +27,11 @@ import warnings
 warnings.filterwarnings("ignore")
 
 
-# @st.cache_resource
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+st.set_page_config(page_title="SkinGPT", page_icon="🧬", layout="centered")
+
+
+@st.cache_resource(show_spinner=False)
 def get_classifier():
     classifier = SkinGPTClassifier()
     for module in [classifier.model.vit,
@@ -53,12 +43,17 @@ def get_classifier():
 
     return classifier
 
-classifier = get_classifier()
+if 'app_models' not in st.session_state:
+    st.session_state.app_models = get_classifier()
+
+classifier = st.session_state.app_models
 
 # === Session Init ===
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if "current_image" not in st.session_state:
+    st.session_state.current_image = None
 
 # === PDF Export ===
 def export_chat_to_pdf(messages):
@@ -77,37 +72,56 @@ def export_chat_to_pdf(messages):
 
 st.title("🧬 DermBOT — Skin AI Assistant")
 st.caption(f"🧠 Using model: SkinGPT")
-uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
-if "conversation" not in st.session_state:
-    st.session_state.conversation = []
-if uploaded_file:
-    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
+uploaded_file = st.file_uploader(
+    "Upload a skin image",
+    type=["jpg", "jpeg", "png"],
+    key="file_uploader"
+)
+
+if uploaded_file is not None and uploaded_file != st.session_state.current_image:
+    st.session_state.messages = []
+    st.session_state.current_image = uploaded_file
+
     image = Image.open(uploaded_file).convert("RGB")
-    if not st.session_state.conversation:
-        with st.spinner("Analyzing image..."):
-            result = classifier.predict(image)
-            if "error" in result:
-                st.error(result["error"])
+    st.image(image, caption="Uploaded image", use_column_width=True)
+    with st.spinner("Analyzing the image..."):
+        result = classifier.predict(image)
+
+    st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
+
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# === Chat Interface ===
+if prompt := st.chat_input("Ask a follow-up question..."):
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    with st.chat_message("assistant"):
+        with st.spinner("Thinking..."):
+            if len(st.session_state.messages) > 1:
+                conversation_context = "\n".join(
+                    f"{m['role']}: {m['content']}"
+                    for m in st.session_state.messages[:-1]  # Exclude current prompt
+                )
+                augmented_prompt = (
+                    f"Conversation history:\n{conversation_context}\n\n"
+                    f"Current question: {prompt}"
+                )
+                result = classifier.predict(image)
             else:
-                st.session_state.conversation.append(("assistant", result))
-                with st.chat_message("assistant"):
-                    st.markdown(result["diagnosis"])
-    else:
-        # Follow-up questions
-        if user_query := st.chat_input("Ask a follow-up question..."):
-            st.session_state.conversation.append(("user", user_query))
-            with st.chat_message("user"):
-                st.markdown(user_query)
-
-            # Generate response with context
-            context = "\n".join([f"{role}: {msg}" for role, msg in st.session_state.conversation])
-            response = classifier.generate(image, user_input=context)
-
-            st.session_state.conversation.append(("assistant", response))
-            with st.chat_message("assistant"):
-                st.markdown(response)
-
-    # === PDF Button ===
-    if st.button("📄 Download Chat as PDF"):
+                result = classifier.predict(image)
+
+            st.markdown(result["diagnosis"])
+            st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
+
+if st.session_state.messages and st.button("📄 Download Chat as PDF"):
     pdf_file = export_chat_to_pdf(st.session_state.messages)
-    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")
+    st.download_button(
+        "Download PDF",
+        data=pdf_file,
+        file_name="skingpt_chat_history.pdf",
+        mime="application/pdf"
+    )
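The app.py rewrite follows the usual Streamlit pattern for heavyweight models: build the model once behind st.cache_resource, park the handle in st.session_state, and let chat turns accumulate in a messages list across script reruns. A stripped-down sketch of that pattern is below; HeavyModel is a placeholder for SkinGPTClassifier and is not code from the commit.

import streamlit as st

class HeavyModel:
    """Placeholder for an expensive-to-construct model such as SkinGPTClassifier."""
    def predict(self, text: str) -> dict:
        return {"diagnosis": f"echo: {text}"}

@st.cache_resource(show_spinner=False)   # constructed once per server process
def load_model() -> HeavyModel:
    return HeavyModel()

# Per-session handle: reruns of the script reuse the cached object instead of reloading it.
if "app_models" not in st.session_state:
    st.session_state.app_models = load_model()
if "messages" not in st.session_state:
    st.session_state.messages = []

model = st.session_state.app_models

# Replay the conversation so far, then handle the next turn.
for m in st.session_state.messages:
    with st.chat_message(m["role"]):
        st.markdown(m["content"])

if prompt := st.chat_input("Ask a question..."):
    with st.chat_message("user"):
        st.markdown(prompt)
    result = model.predict(prompt)
    with st.chat_message("assistant"):
        st.markdown(result["diagnosis"])
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})

One difference from the commit worth noting: as of this revision, the follow-up branch in app.py builds augmented_prompt from the conversation history but still calls classifier.predict(image) without it, so the context string is assembled and then unused.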