theschoolofai commited on
Commit
92c266a
·
1 Parent(s): bdd18b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -10,16 +10,16 @@ import gradio as gr
10
  model = ResNet18()
11
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
12
 
13
- def inference(input_img, transparency):
14
  transform = transforms.ToTensor()
15
  input_img = transform(input_img)
16
  input_img = input_img
17
  input_img = input_img.unsqueeze(0)
18
  outputs = model(input_img)
19
  _, prediction = torch.max(outputs, 1)
20
- target_layers = [model.layer2[-2]]
21
  cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
22
- grayscale_cam = cam(input_tensor=input_img, targets=targets)
23
  grayscale_cam = grayscale_cam[0, :]
24
  img = input_img.squeeze(0)
25
  img = inv_normalize(img)
@@ -28,5 +28,5 @@ def inference(input_img, transparency):
28
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
29
  return classes[prediction[0].item()], visualization
30
 
31
- demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1)], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
32
  demo.launch()
 
10
  model = ResNet18()
11
  model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
12
 
13
+ def inference(input_img, transparency, target_layer_number):
14
  transform = transforms.ToTensor()
15
  input_img = transform(input_img)
16
  input_img = input_img
17
  input_img = input_img.unsqueeze(0)
18
  outputs = model(input_img)
19
  _, prediction = torch.max(outputs, 1)
20
+ target_layers = [model.layer2[target_layer_number]]
21
  cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
22
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
23
  grayscale_cam = grayscale_cam[0, :]
24
  img = input_img.squeeze(0)
25
  img = inv_normalize(img)
 
28
  visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
29
  return classes[prediction[0].item()], visualization
30
 
31
+ demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-5, -1, value = -2, label="Which Layer?")], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
32
  demo.launch()