drhead committed (verified)
Commit d3aa745 · Parent(s): 06cf327

Update app.py

Files changed (1):
  1. app.py +35 -31
app.py CHANGED
@@ -154,8 +154,6 @@ allowed_tags = list(tags.keys())
 for idx, tag in enumerate(allowed_tags):
     allowed_tags[idx] = tag.replace("_", " ")
 
-
-
 @spaces.GPU(duration=5)
 def run_classifier(image: Image.Image, threshold):
     img = image.convert('RGBA')
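For context, `spaces.GPU` is the Hugging Face ZeroGPU decorator: it attaches a GPU to the process only for the decorated call, with `duration` as the requested time budget in seconds. A minimal sketch of the pattern, with a stand-in model (this only runs on a ZeroGPU Space):

```python
# Hypothetical minimal ZeroGPU usage; `spaces.GPU` allocates a GPU
# only while the decorated function runs, then releases it.
import spaces
import torch

model = torch.nn.Linear(8, 2)  # stand-in for the real tagger

@spaces.GPU(duration=5)  # request a GPU slice of up to ~5 seconds
def predict(x: torch.Tensor) -> torch.Tensor:
    return model.cuda()(x.cuda()).cpu()
```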
@@ -186,9 +184,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
 
     gradients = {}
     activations = {}
-    cam = None
-    target_tag_index = None
-
 
     def hook_forward(module, input, output):
         activations['value'] = output
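This hunk trims dead initialization from `cam_inference`; the `gradients`/`activations` dicts it keeps are filled by the hook pair registered on `model.norm` in the next hunk. A self-contained sketch of that hook pattern on a toy `LayerNorm` (names and shapes are illustrative, not the app's model):

```python
# Forward hook stashes the layer's output; full backward hook stashes
# the gradient flowing out of the layer during backward().
import torch

gradients, activations = {}, {}
layer = torch.nn.LayerNorm(4)

def hook_forward(module, inputs, output):
    activations['value'] = output

def hook_backward(module, grad_input, grad_output):
    gradients['value'] = grad_output[0]

handle_f = layer.register_forward_hook(hook_forward)
handle_b = layer.register_full_backward_hook(hook_backward)

out = layer(torch.randn(1, 3, 4, requires_grad=True))
out.sum().backward()   # fires hook_backward

handle_f.remove()      # hooks must be removed afterwards, as in the diff
handle_b.remove()
```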
@@ -200,29 +195,24 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
     handle_forward = model.norm.register_forward_hook(hook_forward)
     handle_backward = model.norm.register_full_backward_hook(hook_backward)
 
-    probits = model(tensor)[0].cpu()
+    probits = model(tensor)[0]
 
     model.zero_grad()
-    target_score = probits[target_tag_index]
-    target_score.backward(retain_graph=True)
-
-    grads = gradients.get('value')
-    acts = activations.get('value')
+    probits[target_tag_index].backward(retain_graph=True)
 
-    patch_grads = grads
-    patch_acts = acts
-
-    weights = torch.mean(patch_grads, dim=1).squeeze(0)
-
-    cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
-    cam_1d = torch.relu(cam_1d)
-
-    cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+    with torch.no_grad():
+        patch_grads = gradients.get('value')
+        patch_acts = activations.get('value')
+
+        weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+        cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+        cam_1d = torch.relu(cam_1d)
+
+        cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
 
     handle_forward.remove()
     handle_backward.remove()
-    gradients = {}
-    activations = {}
 
     return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
 
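Taken together, the rewritten block is a standard Grad-CAM reduction: average the captured gradients over the patch dimension to get per-channel weights, weight the captured activations, ReLU, and reshape to the ViT's 27×27 patch grid. Dropping `.cpu()` avoids a device round-trip before `backward()`, and the `torch.no_grad()` block keeps the CAM arithmetic out of the autograd graph. A minimal sketch of the math on fake tensors (the 729-patch, 768-dim shapes are assumptions):

```python
# Grad-CAM-style reduction matching the diff's einsum, on random data.
import torch

patch_acts  = torch.randn(1, 729, 768)   # captured activations (1, P, E)
patch_grads = torch.randn(1, 729, 768)   # captured gradients   (1, P, E)

with torch.no_grad():                    # no autograd graph needed here
    weights = torch.mean(patch_grads, dim=1).squeeze(0)        # (E,)
    cam_1d  = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
    cam     = torch.relu(cam_1d).reshape(27, 27).cpu().numpy()
```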
@@ -245,26 +235,30 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
     size = max(w, h)
 
     # Normalize CAM to [0, 1]
-    cam_norm = (cam - cam.min()) / (np.ptp(cam) + 1e-8)
+    cam -= cam.min()
+    cam /= cam.max()
 
     # Create heatmap using matplotlib colormap
     colormap = cm.get_cmap('inferno')
-    cam_colored = colormap(cam_norm)[:, :, :3] # RGB
-    cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
+    cam_rgb = colormap(cam)[:, :, :3] # RGB
 
-    cam_rgba = np.dstack((cam_colored, cam_alpha)) # Shape: (H, W, 4)
+    # Create alpha channel
+    cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha # Alpha mask
+    cam_rgba = np.dstack((cam_rgb, cam_alpha)) # Shape: (H, W, 4)
 
     # Resize CAM to match image
-    cam_resized = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA").resize((216,216), resample=Image.Resampling.NEAREST).resize((size, size), resample=Image.Resampling.BICUBIC)
+    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
+    cam_pil = cam_pil.resize((216,216), resample=Image.Resampling.NEAREST)
 
-    cam_image = transforms.CenterCrop((h, w))(cam_resized)
+    # Model uses padded image as input, this matches attention map to input image aspect ratio
+    cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
+    cam_pil = transforms.CenterCrop((h, w))(cam_pil)
 
     # Composite over original
-    composite = Image.alpha_composite(image_pil, cam_image)
+    composite = Image.alpha_composite(image_pil, cam_pil)
 
     return composite
 
-
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
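The new visualization path normalizes the CAM in place, colors it with the inferno colormap, builds an alpha mask from the CAM threshold, upsamples in two stages (NEAREST first so patch boundaries stay crisp, then BICUBIC to the padded square), and center-crops the padding away. A runnable sketch of the same pipeline (sizes are assumptions; the sketch keeps a small epsilon, since the in-place `cam /= cam.max()` divides by zero on an all-zero CAM):

```python
import numpy as np
import matplotlib
from PIL import Image

cam = np.random.rand(27, 27).astype(np.float32)           # fake 27×27 CAM
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # guard flat maps

colormap = matplotlib.colormaps["inferno"]  # app calls cm.get_cmap('inferno')
rgb = colormap(cam)[:, :, :3]
alpha = (cam >= 0.4).astype(np.float32) * 0.6             # threshold, alpha
rgba = np.dstack((rgb, alpha))

overlay = Image.fromarray((rgba * 255).astype(np.uint8), mode="RGBA")
overlay = overlay.resize((216, 216), Image.Resampling.NEAREST)  # crisp patches
overlay = overlay.resize((512, 512), Image.Resampling.BICUBIC)  # padded square

# Undo the square padding: center-crop to a hypothetical 512×384 original.
w, h = 512, 384
left, top = (512 - w) // 2, (512 - h) // 2
overlay = overlay.crop((left, top, left + w, top + h))
```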
@@ -280,10 +274,20 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
     sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
     cam_state = gr.State()
     with gr.Row():
-        with gr.Column():
+        custom_css = """
+        .inferno-slider input[type=range] {
+            background: linear-gradient(to right,
+                #000004, #1b0c41, #4a0c6b, #781c6d,
+                #a52c60, #cf4446, #ed6925, #fb9b06,
+                #f7d13d, #fcffa4
+            ) !important;
+            background-size: 100% 100% !important;
+        }
+        """
+        with gr.Column(css=custom_css):
             image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
-            cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="CAM Threshold")
+            cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
             alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
         with gr.Column():
             tag_string = gr.Textbox(label="Tag String")
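One caveat on this UI hunk: Gradio accepts custom CSS on `gr.Blocks`, not on `gr.Column`, so the usual way to get the inferno gradient onto the slider track is to pass the stylesheet at the `Blocks` level and target the component through `elem_classes`. A minimal sketch (Gradio 4.x assumed, CSS shortened):

```python
import gradio as gr

# Custom CSS is registered on Blocks; `gr.Column` has no `css` parameter.
custom_css = """
.inferno-slider input[type=range] {
    background: linear-gradient(to right, #000004, #fcffa4) !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    cam_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.40,
                           label="CAM Threshold", elem_classes="inferno-slider")
```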