KeerthiVM committed
Commit 7f3d8d1 · 1 Parent(s): 2e04d58
Files changed (2):
  1. SkinGPT.py +41 -23
  2. app.py +6 -4
SkinGPT.py CHANGED

@@ -68,8 +68,10 @@ class Blip2QFormer(nn.Module):
         outputs = self.bert(
             attention_mask=attention_mask,
             inputs_embeds=combined_input,
+            output_attentions=True,
             return_dict=True
         )
+        plot_attention(outputs.attentions[-1][:, :, :32, 32:])
         return outputs.last_hidden_state[:, :self.num_query_tokens]
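Note: the new `output_attentions=True` path calls `plot_attention`, which is not defined anywhere in this diff. The slice `[:, :, :32, 32:]` keeps the 32 query tokens' attention over the image-patch positions, so a helper would receive a tensor of shape (batch, heads, 32, num_patches). A minimal sketch of such a helper, assuming matplotlib; the body is an assumption, not the author's implementation:

# Hypothetical helper (not in the diff): visualize Q-Former query-to-patch attention.
import matplotlib.pyplot as plt

def plot_attention(attn, save_path="qformer_attention.png"):
    # attn: (batch, heads, 32 query tokens, num_image_patches); average the heads
    # of the first batch element to get a (32, num_patches) map.
    attn_map = attn[0].mean(dim=0).detach().float().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(attn_map, aspect="auto", cmap="viridis")
    ax.set_xlabel("Image patch")
    ax.set_ylabel("Query token")
    fig.colorbar(im, ax=ax)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

Since this runs on every forward pass, it is the kind of side effect one would usually guard behind a debug flag.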
 
@@ -100,6 +102,15 @@ class SkinGPT4(nn.Module):
             self.q_former.bert_config.hidden_size,
             self.llama.config.hidden_size
         ).to(self.dtype)
+        self.tokenizer = LlamaTokenizer.from_pretrained(
+            "meta-llama/Llama-2-13b-chat-hf",
+            token=token,
+            padding_side="right"
+        )
+        self.tokenizer.add_special_tokens({
+            'additional_special_tokens': ['<IMAGE>']
+        })
+        self.llama.resize_token_embeddings(len(self.tokenizer))
         for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
             for param in module.parameters():
                 param.requires_grad = False
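Moving the tokenizer setup into `__init__` means the `<IMAGE>` token is registered and the embedding matrix resized once, instead of on every `generate` call. The old code's `num_added == 0` guard is dropped in the move; a lightweight check one might keep, assuming `model` is a constructed `SkinGPT4`:

# Illustrative sanity check (not in the diff): the special token must map to a
# real id after add_special_tokens + resize_token_embeddings.
image_token_id = model.tokenizer.convert_tokens_to_ids("<IMAGE>")
assert image_token_id != model.tokenizer.unk_token_id, "<IMAGE> token was not added"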
@@ -170,30 +181,26 @@ class SkinGPT4(nn.Module):
         image_embeds = self.llama_proj(qformer_output.to(self.dtype))
         return image_embeds
 
-    def generate(self, images, user_input=None, max_new_tokens=300):
-        image_embeds = self.encode_image(images)
+    def generate_from_embeddings(self, image_embeds, user_input=None, max_new_tokens=300):
         if image_embeds.shape[-1] != self.llama.config.hidden_size:
             raise ValueError(
                 f"Feature dimension mismatch. "
                 f"Q-Former output: {image_embeds.shape[-1]}, "
                 f"LLaMA expected: {self.llama.config.hidden_size}"
             )
-        prompt = """### Instruction:
-<IMAGE>
-Could you describe the skin condition in this image?
-### Response:"""
-        self.tokenizer = LlamaTokenizer.from_pretrained(
-            "meta-llama/Llama-2-13b-chat-hf",
-            token=token,
-            padding_side="right"
-        )
-        num_added = self.tokenizer.add_special_tokens({
-            'additional_special_tokens': ['<IMAGE>']
-        })
-        if num_added == 0:
-            raise ValueError("Failed to add <IMAGE> token!")
-        self.llama.resize_token_embeddings(len(self.tokenizer))
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
+
+        if user_input:
+            prompt = f"""### Instruction:
+<IMAGE>
+{user_input}
+### Response:"""
+        else:
+            prompt = """### Instruction:
+<IMAGE>
+Could you describe the skin condition in this image?
+### Response:"""
+
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(image_embeds.device)
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
         visual_embeds = image_embeds.mean(dim=1)
         image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
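The hunk's trailing context computes `visual_embeds` and `image_token_id`, but the splice itself falls outside the diff. A sketch of the step this implies, replacing the `<IMAGE>` token's embedding with the pooled visual embedding; the helper name and body are hypothetical:

# Hypothetical sketch of the splice implied by the surrounding context: swap the
# embedding at each <IMAGE> token position for the pooled visual embedding.
import torch

def splice_image_embedding(input_embeddings, input_ids, visual_embeds, image_token_id):
    # input_embeddings: (1, seq_len, hidden); visual_embeds: (1, hidden)
    batch_idx, seq_idx = (input_ids == image_token_id).nonzero(as_tuple=True)
    spliced = input_embeddings.clone()
    spliced[batch_idx, seq_idx] = visual_embeds.to(spliced.dtype)
    return spliced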
@@ -218,7 +225,10 @@ class SkinGPT4(nn.Module):
         response = full_output.split("### Response:")[-1].strip()
         return response
 
-
+    def generate(self, images, user_input=None, max_new_tokens=300):
+        image_embeds = self.encode_image(images)
+        return self.generate_from_embeddings(image_embeds, user_input, max_new_tokens)
+
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
         self.device = torch.device(device)
@@ -229,6 +239,7 @@ class SkinGPTClassifier:
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
+        self.current_image_embeddings = None
 
     def _load_model(self):
         model_path = hf_hub_download(
@@ -239,11 +250,18 @@ class SkinGPTClassifier:
         model = model.to(self.device)
         return model
 
-    def predict(self, image):
-        image = image.convert('RGB')
-        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+    def predict(self, image, user_input=None, reuse_embeddings=False):
+        if not reuse_embeddings or self.current_image_embeddings is None:
+            image = image.convert('RGB')
+            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+            with torch.no_grad():
+                self.current_image_embeddings = self.model.encode_image(image_tensor)
+
         with torch.no_grad():
-            diagnosis = self.model.generate(image_tensor)
+            diagnosis = self.model.generate_from_embeddings(
+                self.current_image_embeddings,
+                user_input=user_input
+            )
         return {
             "diagnosis": diagnosis,
         }
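Taken together, splitting `generate` into `encode_image` / `generate_from_embeddings` and caching `current_image_embeddings` means the ViT + Q-Former forward pass runs once per image rather than once per question. Illustrative usage, assuming `image` is a PIL image:

# Illustrative only: names are from the diff, the calling code is assumed.
classifier = SkinGPTClassifier()
# First question: encodes the image and caches the embeddings.
first = classifier.predict(image, reuse_embeddings=False)
# Follow-ups: skip the vision stack and reuse the cached embeddings.
follow_up = classifier.predict(image, user_input="Could this be eczema?",
                               reuse_embeddings=True)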
 
app.py CHANGED

@@ -81,11 +81,12 @@ uploaded_file = st.file_uploader(
 if uploaded_file is not None and uploaded_file != st.session_state.current_image:
     st.session_state.messages = []
     st.session_state.current_image = uploaded_file
+    classifier.current_image_embeddings = None
 
     image = Image.open(uploaded_file).convert("RGB")
     st.image(image, caption="Uploaded image", use_column_width=True)
     with st.spinner("Analyzing the image..."):
-        result = classifier.predict(image)
+        result = classifier.predict(image, reuse_embeddings=False)
 
     st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
 
@@ -101,18 +102,19 @@ if prompt := st.chat_input("Ask a follow-up question..."):
 
     with st.chat_message("assistant"):
         with st.spinner("Thinking..."):
+            image = Image.open(st.session_state.current_image).convert("RGB")
             if len(st.session_state.messages) > 1:
                 conversation_context = "\n".join(
                     f"{m['role']}: {m['content']}"
-                    for m in st.session_state.messages[:-1]  # Exclude current prompt
+                    for m in st.session_state.messages[:-1]
                 )
                 augmented_prompt = (
                     f"Conversation history:\n{conversation_context}\n\n"
                     f"Current question: {prompt}"
                 )
-                result = classifier.predict(image)
+                result = classifier.predict(image, user_input=augmented_prompt, reuse_embeddings=True)
             else:
-                result = classifier.predict(image)
+                result = classifier.predict(image, user_input=prompt, reuse_embeddings=False)
 
             st.markdown(result["diagnosis"])
             st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
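One edge the new chat path does not cover: `st.session_state.current_image` may still be `None` if a follow-up question arrives before any upload, in which case `Image.open` will fail. A defensive variant, an assumption rather than part of the diff:

# Hypothetical guard (not in the diff): bail out of the chat handler early
# when no image has been uploaded yet.
if st.session_state.current_image is None:
    st.warning("Please upload an image before asking questions.")
    st.stop()
image = Image.open(st.session_state.current_image).convert("RGB")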
 