SahithiR commited on
Commit
703bc1b
·
1 Parent(s): 20d8eca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -8
app.py CHANGED
@@ -18,14 +18,46 @@ from torchmetrics import Accuracy
18
  from torch.nn import functional as F
19
  import matplotlib.pyplot as plt
20
 
21
- #model = Cifar10SearchDataset()
22
- #trainer = pl.Trainer(accelerator="auto",max_epochs=24)
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- #trainer.fit(model)
25
- #trainer.test(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
- #model = MyLightningModule()
29
- #trainer = pl.Trainer(accelerator="auto",max_epochs=24)
30
- #trainer = pl.Trainer()
31
- #trainer.fit(model, train_loader, test_loader)
 
18
  from torch.nn import functional as F
19
  import matplotlib.pyplot as plt
20
 
21
+ import gradio as gr
22
+ import torch
23
+ from PIL import Image
24
+ from Dataset.testalbumentation import TestAlbumentation
25
+ from Model.Lit_cifar_module import LitCifar
26
+ from utils import *
27
+
28
+ model = LitCifar().cpu()
29
+ model.load_state_dict(torch.load('final_dict.pth', map_location=torch.device('cpu')))
30
+ model.eval()
31
+
32
+ classes = ('plane', 'car', 'bird', 'cat',
33
+ 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
34
+ global_classes = 5
35
 
36
+ def inference(input_image, transparency, target_layer, num_top_classes1, gradcam_image_display = False):
37
+ im = input_image
38
+ test_transform = TestAlbumentation()
39
+ im1 = test_transform(im)
40
+ im1 = im1.unsqueeze(0).cpu()
41
+ out0 = model(im1)
42
+ out = out0.detach().numpy()
43
+ confidences = {classes[i] : float(out[0][i]) for i in range(10)}
44
+ val = torch.argmax(out0).detach().numpy().tolist()
45
+ targ = [val]
46
+ input_image_np,visualization=gradcame(net, 0, targ, im1, target_layer, transparency)
47
+ return confidences, visualization
48
+
49
+ interface = gr.Interface(inference,
50
+ inputs = [gr.Image(shape=(32,32), type="pil", label = "Input image"),
51
+ gr.Slider(0,1, value = 0.5, label="opacity"),
52
+ gr.Slider(-2,-1, value = -2, step = 1, label="gradcam layer"),
53
+ gr.Slider(0,9, value = 0, step = 1, label="no. of top classes to display"),
54
+ gr.Checkbox(default=False, label="Show Gradcam Image")],
55
+ outputs = [gr.Label(num_top_classes=global_classes),
56
+ gr.Image(shape=(32,32), label = "Output")],
57
+ title = "Gradcam output of network trained on cifar10",
58
+ examples = [["cat.jpg", 0.5, -1], ["dog.jpg",0.5,-1]],
59
+ )
60
 
61
 
62
+ # Launch the Gradio interface
63
+ interface.launch()