theschoolofai commited on
Commit
04459be
·
1 Parent(s): 578a79b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from pytorch_grad_cam import GradCAM
7
+ from resnet import ResNet18
8
+ import gradio as gr
9
+
10
+ model = ResNet18()
11
+ device = torch.device("cpu")
12
+ model.load_state_dict(torch.load("model.pth"), strict=False, map_location=device)
13
+
14
+ def inference(input_img, transparency):
15
+ transform = transforms.ToTensor()
16
+ input_img = transform(input_img)
17
+ input_img = input_img
18
+ input_img = input_img.unsqueeze(0)
19
+ outputs = model(input_img)
20
+ _, prediction = torch.max(outputs, 1)
21
+ target_layers = [model.layer2[-2]]
22
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
23
+ grayscale_cam = cam(input_tensor=input_img, targets=targets)
24
+ grayscale_cam = grayscale_cam[0, :]
25
+ img = input_img.squeeze(0)
26
+ img = inv_normalize(img)
27
+ rgb_img = np.transpose(img, (1, 2, 0))
28
+ rgb_img = rgb_img.numpy()
29
+ visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
30
+ return classes[prediction[0].item()], visualization
31
+
32
+ demo = gr.Interface(inference, [gr.Image(shape=(32, 32)), gr.Slider(0, 1)], ["text", gr.Image(shape=(32, 32)).style(width=128, height=128)])
33
+ demo.launch()