Update app.py
app.py CHANGED
@@ -11,6 +11,8 @@ from torchvision.transforms import transforms
 from torchvision.transforms import InterpolationMode
 import torchvision.transforms.functional as TF
 from huggingface_hub import hf_hub_download
+import numpy as np
+import matplotlib.cm as cm
 
 class Fit(torch.nn.Module):
     def __init__(
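numpy and matplotlib.cm back the Grad-CAM overlay that the later hunks touch: numpy for the array math in create_cam_visualization_pil, and matplotlib.cm presumably for colorizing the heatmap. The colormap code itself is not part of this diff, so the following is only a sketch of the usual pattern, with overlay_cam as a hypothetical helper:

    import numpy as np
    import matplotlib.cm as cm
    from PIL import Image

    def overlay_cam(image_pil, cam_norm, alpha=0.6):
        # Map a [0, 1] CAM to RGB with a matplotlib colormap, then blend it
        # over the input image. Hypothetical helper, not taken from app.py.
        heat = (cm.jet(cam_norm)[..., :3] * 255).astype(np.uint8)
        heat_pil = Image.fromarray(heat).resize(image_pil.size, resample=Image.Resampling.BILINEAR)
        return Image.blend(image_pil.convert("RGB"), heat_pil, alpha)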
@@ -198,7 +200,8 @@ def hook_forward(module, input, output):
 def hook_backward(module, grad_in, grad_out):
     gradients['value'] = grad_out[0]
 
-def cam_inference(
+def cam_inference(threshold, evt: gr.SelectData):
+    target_tag = evt.value
     print(f"target_tag: {target_tag}")
     global input_image, sorted_tag_score, target_tag_index, gradients, activations
     img = input_image
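cam_inference now reads the clicked tag from the select event instead of a plain argument: annotating a parameter as gr.SelectData makes Gradio inject the event payload for .select() listeners, with evt.value holding the selected label and evt.index its position. A minimal handler in the same shape, with the body reduced to a print for illustration:

    import gradio as gr

    def cam_inference(threshold, evt: gr.SelectData):
        # `evt` is filled in by Gradio for .select() listeners; only `threshold`
        # needs to be listed under `inputs=` (see the last hunk below).
        target_tag = evt.value
        print(f"target_tag: {target_tag}, threshold: {threshold}")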
@@ -268,7 +271,7 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
     w, h = image_pil.size
 
     # Resize CAM to match image
-    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.BILINEAR))
+    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.Resampling.BILINEAR))
 
     # Normalize CAM to [0, 1]
     cam_norm = (cam_resized - cam_resized.min()) / (cam_resized.ptp() + 1e-8)
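Image.BILINEAR still resolves as an alias, but the Image.Resampling enum introduced in Pillow 9.1 is the current spelling, so the new form sidesteps the deprecation path. The unchanged lines around it then rescale the resized CAM to [0, 1]; a standalone sketch of that resize-and-normalize step, with made-up sizes:

    import numpy as np
    from PIL import Image

    cam = np.random.rand(24, 24).astype(np.float32)   # stand-in for a Grad-CAM map
    w, h = 448, 448                                    # stand-in image size
    cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.Resampling.BILINEAR))
    # np.ptp() is the function form of the .ptp() method used in app.py.
    cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)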
@@ -335,7 +338,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
 
     label_box.select(
         fn=cam_inference,
-        inputs=[
+        inputs=[threshold_slider],
         outputs=[image_input]
     )
 
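On the Blocks side, the slider is now forwarded through inputs, while the gr.SelectData argument is injected by the event itself and is not listed there. A sketch of the wiring in which the component constructors are assumptions (only the names image_input, threshold_slider, and label_box come from app.py):

    import gradio as gr

    with gr.Blocks() as demo:
        image_input = gr.Image(type="pil")
        threshold_slider = gr.Slider(0.0, 1.0, value=0.4, label="threshold")
        label_box = gr.Label(num_top_classes=10)
        label_box.select(
            fn=cam_inference,            # handler shown in the earlier sketch
            inputs=[threshold_slider],   # evt: gr.SelectData is passed implicitly
            outputs=[image_input],
        )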