fix added
- SkinGPT.py +41 -23
- app.py +6 -4
SkinGPT.py
CHANGED
@@ -68,8 +68,10 @@ class Blip2QFormer(nn.Module):
         outputs = self.bert(
             attention_mask=attention_mask,
             inputs_embeds=combined_input,
+            output_attentions=True,
             return_dict=True
         )
+        plot_attention(outputs.attentions[-1][:, :, :32, 32:])
         return outputs.last_hidden_state[:, :self.num_query_tokens]
 
 
@@ -100,6 +102,15 @@ class SkinGPT4(nn.Module):
             self.q_former.bert_config.hidden_size,
             self.llama.config.hidden_size
         ).to(self.dtype)
+        self.tokenizer = LlamaTokenizer.from_pretrained(
+            "meta-llama/Llama-2-13b-chat-hf",
+            token=token,
+            padding_side="right"
+        )
+        self.tokenizer.add_special_tokens({
+            'additional_special_tokens': ['<IMAGE>']
+        })
+        self.llama.resize_token_embeddings(len(self.tokenizer))
         for module in [self.vit, self.ln_vision, self.q_former, self.llama_proj, self.llama]:
             for param in module.parameters():
                 param.requires_grad = False
@@ -170,30 +181,26 @@ class SkinGPT4(nn.Module):
         image_embeds = self.llama_proj(qformer_output.to(self.dtype))
         return image_embeds
 
-    def
-        image_embeds = self.encode_image(images)
+    def generate_from_embeddings(self, image_embeds, user_input=None, max_new_tokens=300):
         if image_embeds.shape[-1] != self.llama.config.hidden_size:
             raise ValueError(
                 f"Feature dimension mismatch. "
                 f"Q-Former output: {image_embeds.shape[-1]}, "
                 f"LLaMA expected: {self.llama.config.hidden_size}"
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-            raise ValueError("Failed to add <IMAGE> token!")
-        self.llama.resize_token_embeddings(len(self.tokenizer))
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(images.device)
+
+        if user_input:
+            prompt = f"""### Instruction:
+<IMAGE>
+{user_input}
+### Response:"""
+        else:
+            prompt = """### Instruction:
+<IMAGE>
+Could you describe the skin condition in this image?
+### Response:"""
+
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(image_embeds.device)
         input_embeddings = self.llama.model.embed_tokens(inputs.input_ids)
         visual_embeds = image_embeds.mean(dim=1)
         image_token_id = self.tokenizer.convert_tokens_to_ids("<IMAGE>")
@@ -218,7 +225,10 @@ class SkinGPT4(nn.Module):
         response = full_output.split("### Response:")[-1].strip()
         return response
 
-
+    def generate(self, images, user_input=None, max_new_tokens=300):
+        image_embeds = self.encode_image(images)
+        return self.generate_from_embeddings(image_embeds, user_input, max_new_tokens)
+
 class SkinGPTClassifier:
     def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
         self.device = torch.device(device)
@@ -229,6 +239,7 @@ class SkinGPTClassifier:
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
+        self.current_image_embeddings = None
 
     def _load_model(self):
         model_path = hf_hub_download(
@@ -239,11 +250,18 @@ class SkinGPTClassifier:
         model = model.to(self.device)
         return model
 
-    def predict(self, image):
-
-
+    def predict(self, image, user_input=None, reuse_embeddings=False):
+        if not reuse_embeddings or self.current_image_embeddings is None:
+            image = image.convert('RGB')
+            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+            with torch.no_grad():
+                self.current_image_embeddings = self.model.encode_image(image_tensor)
+
         with torch.no_grad():
-            diagnosis = self.model.
+            diagnosis = self.model.generate_from_embeddings(
+                self.current_image_embeddings,
+                user_input=user_input
+            )
         return {
             "diagnosis": diagnosis,
         }
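Note: the generation hunk above stops right after the <IMAGE> token id is looked up; the actual substitution of the projected visual embedding into the prompt embeddings happens outside the changed lines. For orientation only, here is a minimal sketch of one common way to do that substitution. The helper name and tensor shapes are assumptions for illustration, not code from this commit:

import torch

def splice_image_embedding(input_ids, input_embeddings, visual_embeds, image_token_id):
    # Illustrative only; not the splice used in SkinGPT.py (that code is outside this hunk).
    # input_ids: (1, seq_len); input_embeddings: (1, seq_len, hidden_size)
    # visual_embeds: (1, hidden_size), e.g. image_embeds.mean(dim=1) as in the hunk above
    embeds = input_embeddings.clone()
    batch_idx, seq_idx = (input_ids == image_token_id).nonzero(as_tuple=True)
    # Overwrite the <IMAGE> placeholder position(s) with the visual embedding
    embeds[batch_idx, seq_idx] = visual_embeds.to(embeds.dtype)
    return embeds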
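Taken together, the new generate_from_embeddings / generate split and the cached current_image_embeddings let follow-up questions reuse the encoded image instead of re-encoding it. A minimal sketch of the intended call pattern outside Streamlit (the image path is a placeholder):

from PIL import Image

classifier = SkinGPTClassifier()
image = Image.open("lesion_example.jpg")  # placeholder path

# First turn: encode the image once; the embeddings are cached on the classifier.
first = classifier.predict(image, reuse_embeddings=False)
print(first["diagnosis"])

# Follow-up turns: reuse the cached embeddings, only the text prompt changes.
follow_up = classifier.predict(image, user_input="Could this be eczema?", reuse_embeddings=True)
print(follow_up["diagnosis"])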
app.py
CHANGED
@@ -81,11 +81,12 @@ uploaded_file = st.file_uploader(
 if uploaded_file is not None and uploaded_file != st.session_state.current_image:
     st.session_state.messages = []
     st.session_state.current_image = uploaded_file
+    classifier.current_image_embeddings = None
 
     image = Image.open(uploaded_file).convert("RGB")
     st.image(image, caption="Uploaded image", use_column_width=True)
     with st.spinner("Analyzing the image..."):
-        result = classifier.predict(image)
+        result = classifier.predict(image, reuse_embeddings=False)
 
     st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})
 
@@ -101,18 +102,19 @@ if prompt := st.chat_input("Ask a follow-up question..."):
 
     with st.chat_message("assistant"):
         with st.spinner("Thinking..."):
+            image = Image.open(st.session_state.current_image).convert("RGB")
             if len(st.session_state.messages) > 1:
                 conversation_context = "\n".join(
                     f"{m['role']}: {m['content']}"
-                    for m in st.session_state.messages[:-1]
+                    for m in st.session_state.messages[:-1]
                 )
                 augmented_prompt = (
                     f"Conversation history:\n{conversation_context}\n\n"
                     f"Current question: {prompt}"
                 )
-                result = classifier.predict(image)
+                result = classifier.predict(image, user_input=augmented_prompt, reuse_embeddings=True)
             else:
-                result = classifier.predict(image)
+                result = classifier.predict(image, user_input=prompt, reuse_embeddings=False)
 
             st.markdown(result["diagnosis"])
             st.session_state.messages.append({"role": "assistant", "content": result["diagnosis"]})