drhead commited on
Commit
ff9a3d4
·
verified ·
1 Parent(s): 3b3f560

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -11,6 +11,8 @@ from torchvision.transforms import transforms
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
 
 
14
 
15
  class Fit(torch.nn.Module):
16
  def __init__(
@@ -198,7 +200,8 @@ def hook_forward(module, input, output):
198
  def hook_backward(module, grad_in, grad_out):
199
  gradients['value'] = grad_out[0]
200
 
201
- def cam_inference(target_tag, threshold):
 
202
  print(f"target_tag: {target_tag}")
203
  global input_image, sorted_tag_score, target_tag_index, gradients, activations
204
  img = input_image
@@ -268,7 +271,7 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
268
  w, h = image_pil.size
269
 
270
  # Resize CAM to match image
271
- cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
272
 
273
  # Normalize CAM to [0, 1]
274
  cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
@@ -335,7 +338,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
335
 
336
  label_box.select(
337
  fn=cam_inference,
338
- inputs=[label_box, threshold_slider],
339
  outputs=[image_input]
340
  )
341
 
 
11
  from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
14
+ import numpy as np
15
+ import matplotlib.cm as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
 
200
  def hook_backward(module, grad_in, grad_out):
201
  gradients['value'] = grad_out[0]
202
 
203
+ def cam_inference(threshold, evt: gr.SelectData):
204
+ target_tag = evt.value
205
  print(f"target_tag: {target_tag}")
206
  global input_image, sorted_tag_score, target_tag_index, gradients, activations
207
  img = input_image
 
271
  w, h = image_pil.size
272
 
273
  # Resize CAM to match image
274
+ cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.Resampling.BILINEAR))
275
 
276
  # Normalize CAM to [0, 1]
277
  cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
 
338
 
339
  label_box.select(
340
  fn=cam_inference,
341
+ inputs=[threshold_slider],
342
  outputs=[image_input]
343
  )
344