anindya-hf-2002 commited on
Commit
eebbfac
·
verified ·
1 Parent(s): 8c549a7

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -42
app.py CHANGED
@@ -1,43 +1,41 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision import transforms
4
-
5
- from model import EfficientNet
6
- # Load the PyTorch model
7
- # model = torch.load('efficientnet_b2-epoch08-val_loss0_02_cat-vs-dog_clasifier.pt')
8
- model = EfficientNet.load_from_checkpoint("./efficientnet_b2-epoch49-val_loss0.02.ckpt")
9
- model.eval()
10
- # Define the image preprocessing
11
- transform = transforms.Compose([
12
- transforms.Resize(256),
13
- transforms.CenterCrop(224),
14
- transforms.ToTensor(),
15
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
16
- ])
17
-
18
- # Define the prediction function
19
- def predict(image):
20
- image = transform(image).unsqueeze(0)
21
- with torch.no_grad():
22
- output = model(image)
23
- print(output)
24
- probabilities = torch.softmax(output, dim=1)
25
- print(probabilities)
26
- _, predicted = torch.max(output.data, 1)
27
- prediction = 'Cat' if predicted == 0 else 'Dog'
28
- confidence = probabilities.max().item()
29
- result = {prediction: confidence}
30
- return result
31
-
32
- # Create the Gradio interface
33
- app = gr.Interface(
34
- fn=predict,
35
- inputs=gr.Image(type="pil"),
36
- outputs=gr.Label(num_top_classes=1),
37
- title="Dog vs Cat Classifier",
38
- description="Upload an image to classify whether it's a dog or a cat.",
39
- allow_flagging='never'
40
- )
41
-
42
- # Launch the app
43
  app.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+
5
+ from model import EfficientNet
6
+ # Load the PyTorch model
7
+ # model = torch.load('efficientnet_b2-epoch08-val_loss0_02_cat-vs-dog_clasifier.pt')
8
+ model = EfficientNet.load_from_checkpoint("./efficientnet_b2-epoch49-val_loss0.02.ckpt")
9
+ model.eval()
10
+ # Define the image preprocessing
11
+ transform = transforms.Compose([
12
+ transforms.Resize(256),
13
+ transforms.CenterCrop(224),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
16
+ ])
17
+
18
+ # Define the prediction function
19
+ def predict(image):
20
+ image = transform(image).unsqueeze(0)
21
+ with torch.no_grad():
22
+ output = model(image)
23
+ probabilities = torch.softmax(output, dim=1)
24
+ _, predicted = torch.max(output.data, 1)
25
+ prediction = 'Cat' if predicted == 0 else 'Dog'
26
+ confidence = probabilities.max().item()
27
+ result = {prediction: confidence}
28
+ return result
29
+
30
+ # Create the Gradio interface
31
+ app = gr.Interface(
32
+ fn=predict,
33
+ inputs=gr.Image(type="pil"),
34
+ outputs=gr.Label(num_top_classes=1),
35
+ title="Dog vs Cat Classifier",
36
+ description="Upload an image to classify whether it's a dog or a cat.",
37
+ allow_flagging='never'
38
+ )
39
+
40
+ # Launch the app
 
 
41
  app.launch()