Spaces:
Runtime error
Runtime error
Commit
·
92c266a
1
Parent(s):
bdd18b5
Update app.py
Browse files
app.py
CHANGED
@@ -10,16 +10,16 @@ import gradio as gr
|
|
10 |
model = ResNet18()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
13 |
-
def inference(input_img, transparency):
|
14 |
transform = transforms.ToTensor()
|
15 |
input_img = transform(input_img)
|
16 |
input_img = input_img
|
17 |
input_img = input_img.unsqueeze(0)
|
18 |
outputs = model(input_img)
|
19 |
_, prediction = torch.max(outputs, 1)
|
20 |
-
target_layers = [model.layer2[
|
21 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
22 |
-
grayscale_cam = cam(input_tensor=input_img, targets=
|
23 |
grayscale_cam = grayscale_cam[0, :]
|
24 |
img = input_img.squeeze(0)
|
25 |
img = inv_normalize(img)
|
@@ -28,5 +28,5 @@ def inference(input_img, transparency):
|
|
28 |
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
29 |
return classes[prediction[0].item()], visualization
|
30 |
|
31 |
-
demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1)], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
|
32 |
demo.launch()
|
|
|
10 |
model = ResNet18()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
13 |
+
def inference(input_img, transparency, target_layer_number):
|
14 |
transform = transforms.ToTensor()
|
15 |
input_img = transform(input_img)
|
16 |
input_img = input_img
|
17 |
input_img = input_img.unsqueeze(0)
|
18 |
outputs = model(input_img)
|
19 |
_, prediction = torch.max(outputs, 1)
|
20 |
+
target_layers = [model.layer2[target_layer_number]]
|
21 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
|
22 |
+
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
23 |
grayscale_cam = grayscale_cam[0, :]
|
24 |
img = input_img.squeeze(0)
|
25 |
img = inv_normalize(img)
|
|
|
28 |
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
29 |
return classes[prediction[0].item()], visualization
|
30 |
|
31 |
+
demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, label="Which Layer?")], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
|
32 |
demo.launch()
|