Spaces:
Runtime error
Runtime error
Commit
·
81dfb50
1
Parent(s):
92c266a
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,11 @@ import gradio as gr
|
|
10 |
model = ResNet18()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
def inference(input_img, transparency, target_layer_number):
|
14 |
transform = transforms.ToTensor()
|
15 |
input_img = transform(input_img)
|
@@ -28,5 +33,5 @@ def inference(input_img, transparency, target_layer_number):
|
|
28 |
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
29 |
return classes[prediction[0].item()], visualization
|
30 |
|
31 |
-
demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, label="Which Layer?")], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
|
32 |
demo.launch()
|
|
|
10 |
model = ResNet18()
|
11 |
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
|
12 |
|
13 |
+
inv_normalize = transforms.Normalize(
|
14 |
+
mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
|
15 |
+
std=[1/0.23, 1/0.23, 1/0.23]
|
16 |
+
)
|
17 |
+
|
18 |
def inference(input_img, transparency, target_layer_number):
|
19 |
transform = transforms.ToTensor()
|
20 |
input_img = transform(input_img)
|
|
|
33 |
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
34 |
return classes[prediction[0].item()], visualization
|
35 |
|
36 |
+
demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, step=1, label="Which Layer?")], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
|
37 |
demo.launch()
|