theschoolofai commited on
Commit
6457fca
·
1 Parent(s): 81dfb50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import gradio as gr
5
  from PIL import Image
6
  from pytorch_grad_cam import GradCAM
 
7
  from resnet import ResNet18
8
  import gradio as gr
9
 
@@ -14,6 +15,8 @@ inv_normalize = transforms.Normalize(
14
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
15
  std=[1/0.23, 1/0.23, 1/0.23]
16
  )
 
 
17
 
18
  def inference(input_img, transparency, target_layer_number):
19
  transform = transforms.ToTensor()
@@ -33,5 +36,5 @@ def inference(input_img, transparency, target_layer_number):
33
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
34
  return classes[prediction[0].item()], visualization
35
 
36
- demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, step=1, label="Which Layer?")], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
37
  demo.launch()
 
4
  import gradio as gr
5
  from PIL import Image
6
  from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
  from resnet import ResNet18
9
  import gradio as gr
10
 
 
15
  mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
16
  std=[1/0.23, 1/0.23, 1/0.23]
17
  )
18
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
19
+ 'dog', 'frog', 'horse', 'ship', 'truck')
20
 
21
  def inference(input_img, transparency, target_layer_number):
22
  transform = transforms.ToTensor()
 
36
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
37
  return classes[prediction[0].item()], visualization
38
 
39
+ demo = gr.Interface(inference, [gr.Image(shape=(32, 32), label="Input Image"), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, step=1, label="Which Layer?")], ["text", gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)])
40
  demo.launch()