osbm commited on
Commit
6255c29
·
verified ·
1 Parent(s): 7662ff2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ from torchvision import models
7
+
8
+ def predict(image):
9
+ print(type(image))
10
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
11
+ # Load model
12
+ model = models.resnet50(pretrained=True)
13
+ num_ftrs = model.fc.in_features
14
+ model.fc = nn.Linear(num_ftrs, 1)
15
+ model.load_state_dict(torch.load("best_f1.pth"))
16
+ model.eval()
17
+
18
+ # Preprocess image
19
+ valid_transform = transforms.Compose([
20
+ # transforms.ToPILImage(), # Convert the image to a PIL Image
21
+ transforms.Resize((224, 224)), # Resize the image to final_size x final_size
22
+ transforms.ToTensor(), # Convert the image to a PyTorch tensor
23
+ transforms.Normalize( # Normalize the image
24
+ mean=[0.485, 0.456, 0.406],
25
+ std=[0.229, 0.224, 0.225]
26
+ )
27
+ ])
28
+
29
+ input_batch = valid_transform(image).unsqueeze(0)
30
+ # Make prediction
31
+ with torch.no_grad():
32
+ output = model(input_batch)
33
+ output = torch.sigmoid(output).squeeze().item()
34
+ if output > 0.5:
35
+ predicted = 1
36
+ else:
37
+ predicted = 0
38
+
39
+ int2label = {0: "cat", 1: "dog"}
40
+ return int2label[predicted]
41
+
42
+ demo = gr.Interface(
43
+ predict,
44
+ inputs="image",
45
+ outputs="label",
46
+ title="Cats vs Dogs",
47
+ description="This model predicts whether an image contains a cat or a dog.",
48
+ examples = ["assets/7.jpg", "assets/44.jpg", "assets/82.jpg", "assets/83.jpg"]
49
+ )
50
+
51
+ demo.launch()