Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -18,14 +18,46 @@ from torchmetrics import Accuracy
|
|
18 |
from torch.nn import functional as F
|
19 |
import matplotlib.pyplot as plt
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
-
#
|
29 |
-
|
30 |
-
#trainer = pl.Trainer()
|
31 |
-
#trainer.fit(model, train_loader, test_loader)
|
|
|
18 |
from torch.nn import functional as F
|
19 |
import matplotlib.pyplot as plt
|
20 |
|
21 |
+
import gradio as gr
|
22 |
+
import torch
|
23 |
+
from PIL import Image
|
24 |
+
from Dataset.testalbumentation import TestAlbumentation
|
25 |
+
from Model.Lit_cifar_module import LitCifar
|
26 |
+
from utils import *
|
27 |
+
|
28 |
+
model = LitCifar().cpu()
|
29 |
+
model.load_state_dict(torch.load('final_dict.pth', map_location=torch.device('cpu')))
|
30 |
+
model.eval()
|
31 |
+
|
32 |
+
classes = ('plane', 'car', 'bird', 'cat',
|
33 |
+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
34 |
+
global_classes = 5
|
35 |
|
36 |
+
def inference(input_image, transparency, target_layer, num_top_classes1, gradcam_image_display = False):
|
37 |
+
im = input_image
|
38 |
+
test_transform = TestAlbumentation()
|
39 |
+
im1 = test_transform(im)
|
40 |
+
im1 = im1.unsqueeze(0).cpu()
|
41 |
+
out0 = model(im1)
|
42 |
+
out = out0.detach().numpy()
|
43 |
+
confidences = {classes[i] : float(out[0][i]) for i in range(10)}
|
44 |
+
val = torch.argmax(out0).detach().numpy().tolist()
|
45 |
+
targ = [val]
|
46 |
+
input_image_np,visualization=gradcame(net, 0, targ, im1, target_layer, transparency)
|
47 |
+
return confidences, visualization
|
48 |
+
|
49 |
+
interface = gr.Interface(inference,
|
50 |
+
inputs = [gr.Image(shape=(32,32), type="pil", label = "Input image"),
|
51 |
+
gr.Slider(0,1, value = 0.5, label="opacity"),
|
52 |
+
gr.Slider(-2,-1, value = -2, step = 1, label="gradcam layer"),
|
53 |
+
gr.Slider(0,9, value = 0, step = 1, label="no. of top classes to display"),
|
54 |
+
gr.Checkbox(default=False, label="Show Gradcam Image")],
|
55 |
+
outputs = [gr.Label(num_top_classes=global_classes),
|
56 |
+
gr.Image(shape=(32,32), label = "Output")],
|
57 |
+
title = "Gradcam output of network trained on cifar10",
|
58 |
+
examples = [["cat.jpg", 0.5, -1], ["dog.jpg",0.5,-1]],
|
59 |
+
)
|
60 |
|
61 |
|
62 |
+
# Launch the Gradio interface
|
63 |
+
interface.launch()
|
|
|
|