drhead committed
Commit de116ae · verified · 1 Parent(s): 0e775d1

attempt to make tag vis work

Files changed (1): app.py (+126, -5)
app.py CHANGED
@@ -12,8 +12,6 @@ from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
 from huggingface_hub import hf_hub_download
 
-torch.set_grad_enabled(False)
-
 class Fit(torch.nn.Module):
     def __init__(
         self,
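
Note on this hunk: the global `torch.set_grad_enabled(False)` has to go because the CAM code below needs a backward pass to produce gradients. A minimal sketch of an alternative (hypothetical, not what this commit does) that keeps ordinary inference grad-free by default and opts in only for the CAM pass:

    import torch

    torch.set_grad_enabled(False)  # inference stays grad-free by default

    def cam_forward(model, tensor):
        # torch.enable_grad() overrides the global switch for this block only
        with torch.enable_grad():
            return model(tensor.requires_grad_(True))
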
@@ -155,11 +153,14 @@ for idx, tag in enumerate(allowed_tags):
     allowed_tags[idx] = tag.replace("_", " ")
 
 sorted_tag_score = {}
+input_image = None
+
 
 @spaces.GPU(duration=5)
 def run_classifier(image, threshold):
-    global sorted_tag_score
-    img = image.convert('RGBA')
+    global sorted_tag_score, input_image
+    input_image = image.convert('RGBA')
+    img = input_image
     tensor = transform(img).unsqueeze(0)
 
     with torch.no_grad():
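
This hunk caches the converted upload in a module-level `input_image` so a later label click can draw the CAM overlay on the original image. Module globals are shared by every visitor to a Space, so concurrent sessions can overwrite each other's cached image; a hedged sketch of a per-session alternative using `gr.State` (names here are illustrative, not from the commit):

    import gradio as gr

    def classify(image, threshold, cached):
        cached = image.convert('RGBA')   # keep a per-session copy
        # ... run the tagger on `cached` ...
        return "tags go here", cached

    with gr.Blocks() as demo:
        cached_image = gr.State(value=None)      # one slot per browser session
        image_input = gr.Image(type="pil")
        threshold_slider = gr.Slider(0.0, 1.0, value=0.4)
        tag_string = gr.Textbox()
        image_input.upload(
            fn=classify,
            inputs=[image_input, threshold_slider, cached_image],
            outputs=[tag_string, cached_image],  # returning a value updates the State
        )
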
@@ -180,10 +181,124 @@ def create_tags(threshold):
     return text_no_impl, filtered_tag_score
 
 def clear_image():
-    global sorted_tag_score
+    global sorted_tag_score, input_image
+    input_image = None
     sorted_tag_score = {}
     return "", {}
 
+target_tag_index = None
+
+# Store hooks and intermediate values
+gradients = {}
+activations = {}
+
+def hook_forward(module, input, output):
+    activations['value'] = output
+
+def hook_backward(module, grad_in, grad_out):
+    gradients['value'] = grad_out[0]
+
+def cam_inference(target_tag, threshold):
+    global input_image, sorted_tag_score, target_tag_index, gradients, activations
+    img = input_image
+    tensor = transform(img).unsqueeze(0)
+
+    gradients = {}
+    activations = {}
+    cam = None
+    target_tag_index = None
+
+    if target_tag:
+        if target_tag not in allowed_tags:
+            print(f"Warning: Target tag '{target_tag}' not found in allowed tags.")
+            target_tag = None
+        else:
+            target_tag_index = allowed_tags.index(target_tag)
+            handle_forward = model.norm.register_forward_hook(hook_forward)
+            handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+    probits = model(tensor)[0].cpu()
+
+
+    if target_tag is not None and target_tag_index is not None:
+        model.zero_grad()
+        target_score = probits[target_tag_index]
+        target_score.backward(retain_graph=True)
+
+        grads = gradients.get('value')
+        acts = activations.get('value')
+
+        if grads is not None and acts is not None:
+            patch_grads = grads
+            patch_acts = acts
+
+            weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+            cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+            cam_1d = torch.relu(cam_1d)
+
+            cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+
+        handle_forward.remove()
+        handle_backward.remove()
+    gradients = {}
+    activations = {}
+
+    return create_cam_visualization_pil(cam, vis_threshold=threshold)
+
+def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
+    """
+    Overlays CAM on image and returns a PIL image.
+
+    Args:
+        image_pil: PIL Image (RGB)
+        cam: 2D numpy array (activation map)
+        alpha: float, blending factor
+        vis_threshold: float, minimum normalized CAM value to show color
+
+    Returns:
+        PIL.Image.Image with overlay
+    """
+    if cam is None:
+        print("CAM is None, skipping visualization.")
+        return image_pil
+    global input_image
+    # Convert to RGB (in case RGBA or others)
+    image_pil = input_image
+    w, h = image_pil.size
+
+    # Resize CAM to match image
+    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
+
+    # Normalize CAM to [0, 1]
+    cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
+
+    # Apply threshold mask
+    mask = cam_norm >= vis_threshold
+
+    # Create heatmap using matplotlib colormap
+    colormap = cm.get_cmap('jet')
+    heatmap_rgba = colormap(cam_norm)  # shape: (H, W, 4), values in [0, 1]
+    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
+
+    # Convert heatmap to PIL image
+    heatmap_pil = Image.fromarray(heatmap_rgb).convert("RGB")
+
+    # Convert images to NumPy for blending
+    base_np = np.array(image_pil).astype(np.float32)
+    heat_np = np.array(heatmap_pil).astype(np.float32)
+
+    # Blend only where mask is True
+    blended_np = base_np.copy()
+    blended_np[mask] = base_np[mask] * (1 - alpha) + heat_np[mask] * alpha
+    blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
+
+    # Convert back to PIL image
+    blended_img = Image.fromarray(blended_np)
+    return blended_img
+
+
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
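
The added block is Grad-CAM adapted to a ViT: forward/backward hooks on `model.norm` capture the final token activations and their gradients, the gradients are averaged over the patch dimension to give one weight per channel, and the ReLU'd weighted sum per token becomes the heat map. A hedged sketch of the shape flow, assuming the backbone emits a 27×27 token grid (729 tokens) and an illustrative embedding width of 1152 (the real value comes from the model):

    import torch

    E = 1152                         # assumed embedding width
    acts  = torch.randn(1, 729, E)   # what hook_forward captures from model.norm
    grads = torch.randn(1, 729, E)   # what hook_backward captures

    weights = grads.mean(dim=1).squeeze(0)                       # (E,) channel weights
    cam_1d = torch.einsum('pe,e->p', acts.squeeze(0), weights)   # (729,) per-token score
    cam = torch.relu(cam_1d).reshape(27, 27)                     # 27x27 patch grid

One caveat in `create_cam_visualization_pil`: the `cam is None` early return references `image_pil` before it is assigned, so that path raises NameError; returning `input_image` there looks like the intent.
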
 
@@ -219,5 +334,11 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
         outputs=[tag_string, label_box]
     )
 
+    label_box.select(
+        fn=cam_inference,
+        inputs=[threshold_slider],
+        outputs=[image_input]
+    )
+
 if __name__ == "__main__":
     demo.launch()
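
On the wiring in the last hunk: `label_box.select` passes only `threshold_slider` through `inputs`, while `cam_inference(target_tag, threshold)` expects two arguments, so `target_tag` would go unfilled. In Gradio, the clicked entry of a Label normally arrives through a type-annotated `gr.SelectData` parameter rather than through `inputs`; a hedged sketch of that signature (hypothetical, not what this commit ships):

    import gradio as gr

    def cam_inference(threshold, evt: gr.SelectData):
        # Gradio fills `evt` because of the gr.SelectData annotation;
        # for a Label, evt.value is the text of the clicked tag
        target_tag = evt.value
        ...
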