Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -2,24 +2,19 @@ import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from torchvision.datasets import MNIST
-from
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import DataLoader, random_split, Dataset
 import torch
 import albumentations as A
 from albumentations.pytorch import ToTensorV2
-
 from torchvision import transforms
 import numpy as np
 import torch
 from torchvision import datasets
-from torch.utils.data import Dataset, DataLoader
 from torchvision.transforms import ToTensor
 from torchmetrics import Accuracy
 from torch.nn import functional as F
 import matplotlib.pyplot as plt
-
 import gradio as gr
-import torch
 from PIL import Image
 from Dataset.testalbumentation import TestAlbumentation
 from Model.Lit_cifar_module import LitCifar
@@ -34,16 +29,16 @@ classes = ('plane', 'car', 'bird', 'cat',
 global_classes = 5
 
 def inference(input_image, transparency, target_layer, num_top_classes1, gradcam_image_display = False):
-
+    image = input_image
     test_transform = TestAlbumentation()
-
-
-    out0 = model(
+    image1 = test_transform(image)
+    image1 = image1.unsqueeze(0).cpu()
+    out0 = model(image1)
     out = out0.detach().numpy()
     confidences = {classes[i] : float(out[0][i]) for i in range(10)}
     val = torch.argmax(out0).detach().numpy().tolist()
-
-    input_image_np,visualization=gradcame(
+    target = [val]
+    input_image_np,visualization=gradcame(model, 0, target, image1, target_layer, transparency)
     return confidences, visualization
 
 interface = gr.Interface(inference,
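The added lines route the Gradio input through TestAlbumentation before batching it for the forward pass. That class comes from Dataset/testalbumentation.py and is not part of this diff; as a rough sketch only (the CIFAR-10 normalization statistics and the class layout are assumptions, not taken from the repo), a test-time Albumentations wrapper along these lines would satisfy the way it is called here:

# Hypothetical sketch of Dataset/testalbumentation.py -- not the repo's actual code.
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

class TestAlbumentation:
    """Test-time transform: normalize with (assumed) CIFAR-10 stats and return a CHW tensor."""

    def __init__(self):
        self.transform = A.Compose([
            A.Normalize(mean=(0.4914, 0.4822, 0.4465),
                        std=(0.2470, 0.2435, 0.2616)),
            ToTensorV2(),
        ])

    def __call__(self, image):
        # Albumentations expects an HWC numpy array, which is what Gradio passes in.
        return self.transform(image=np.array(image))["image"]

Anything shaped like this returns a torch tensor, which is why the new code can call .unsqueeze(0) on the result before handing it to the model.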
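The other helper the new call depends on, gradcame(model, 0, target, image1, target_layer, transparency), is also defined elsewhere in the Space, so the sketch below is an assumption about how it might be built on the pytorch-grad-cam package rather than the repo's actual implementation. The meaning of the literal 0 in the call is not shown by the diff (it is treated as a CUDA flag here purely as a guess), and target_layer is assumed to already be a layer module rather than an index.

# Hypothetical sketch of the gradcame() helper used above -- not the repo's actual code.
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def gradcame(model, use_cuda, target, input_tensor, target_layer, transparency):
    # use_cuda mirrors the 0 passed in the diff; its real meaning is a guess.
    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(c) for c in target])[0]

    # Bring the normalized input back to an HWC float image in [0, 1] so the
    # heatmap can be blended on top of it.
    rgb_img = input_tensor[0].permute(1, 2, 0).cpu().numpy()
    rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min() + 1e-8)

    visualization = show_cam_on_image(rgb_img, grayscale_cam,
                                      use_rgb=True, image_weight=transparency)
    return rgb_img, visualization

Under these assumptions, transparency maps to show_cam_on_image's image_weight, which controls how strongly the original image shows through the Grad-CAM overlay returned to the Gradio interface.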