SahithiR commited on
Commit
e472c55
·
1 Parent(s): 1682e68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -30
app.py CHANGED
@@ -24,35 +24,40 @@ model = LitCifar().cpu()
24
  model.load_state_dict(torch.load('final_dict.pth', map_location=torch.device('cpu')))
25
  model.eval()
26
 
27
- classes = ('plane', 'car', 'bird', 'cat',
28
- 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
29
- global_classes = 5
30
 
31
- def inference(input_image, transparency, target_layer, num_top_classes1, gradcam_image_display = False):
32
- image = input_image
33
- test_transform = TestAlbumentation()
34
- image1 = test_transform(image)
35
- image1 = image1.unsqueeze(0).cpu()
36
- out0 = model(image1)
37
- out = out0.detach().numpy()
38
- confidences = {classes[i] : float(out[0][i]) for i in range(10)}
39
- val = torch.argmax(out0).detach().numpy().tolist()
40
- target = [val]
41
- input_image_np,visualization=gradcame(model, 0, target, image1, target_layer, transparency)
42
- return confidences, visualization
43
-
44
- interface = gr.Interface(inference,
45
- inputs = [gr.Image(shape=(32,32), type="pil", label = "Input image"),
46
- gr.Slider(0,1, value = 0.5, label="opacity"),
47
- gr.Slider(-2,-1, value = -2, step = 1, label="gradcam layer"),
48
- gr.Slider(0,9, value = 0, step = 1, label="no. of top classes to display"),
49
- gr.Checkbox(default=False, label="Show Gradcam Image")],
50
- outputs = [gr.Label(num_top_classes=global_classes),
51
- gr.Image(shape=(32,32), label = "Output")],
52
- title = "Gradcam output of network trained on cifar10",
53
- examples = [["cat.jpg", 0.5, -1], ["dog.jpg",0.5,-1]],
54
- )
55
 
56
-
57
- # Launch the Gradio interface
58
- interface.launch()
 
 
 
 
 
 
 
 
 
 
24
  model.load_state_dict(torch.load('final_dict.pth', map_location=torch.device('cpu')))
25
  model.eval()
26
 
27
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
28
+ 'dog', 'frog', 'horse', 'ship', 'truck')
 
29
 
30
+ def inference(input_img, transparency = 0.5, target_layer_number = -1):
31
+ transform = transforms.ToTensor()
32
+ org_img = input_img
33
+ input_img = transform(input_img)
34
+ input_img = input_img
35
+ input_img = input_img.unsqueeze(0)
36
+ outputs = model(input_img)
37
+ softmax = torch.nn.Softmax(dim=0)
38
+ o = softmax(outputs.flatten())
39
+ confidences = {classes[i]: float(o[i]) for i in range(10)}
40
+ _, prediction = torch.max(outputs, 1)
41
+ target_layers = [model.layer2[target_layer_number]]
42
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
43
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
44
+ grayscale_cam = grayscale_cam[0, :]
45
+ img = input_img.squeeze(0)
46
+ img = inv_normalize(img)
47
+ rgb_img = np.transpose(img, (1, 2, 0))
48
+ rgb_img = rgb_img.numpy()
49
+ visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
50
+ return confidences, visualization
 
 
 
51
 
52
+ title = "CIFAR10 trained on ResNet18 Model with GradCAM"
53
+ description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
54
+ examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.5, -1]]
55
+ demo = gr.Interface(
56
+ inference,
57
+ inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?")],
58
+ outputs = [gr.Label(num_top_classes=3), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
59
+ title = title,
60
+ description = description,
61
+ examples = examples,
62
+ )
63
+ demo.launch()