ilyi commited on
Commit
a97d1bb
·
1 Parent(s): 9b4c713

Classification v1.

Browse files
Files changed (2) hide show
  1. app.py +34 -0
  2. models/mobilenetv3_large_100_224.pt +3 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from torchvision import models
5
+ from PIL import Image
6
+ import requests
7
+
8
+ # Load a pre-trained model
9
+ model = models.resnet50(pretrained=True)
10
+ model.eval()
11
+
12
+ # Preprocess the input image
13
+ def preprocess(image):
14
+ transform = transforms.Compose([
15
+ transforms.Resize(256),
16
+ transforms.CenterCrop(224),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+ ])
20
+ return transform(image).unsqueeze(0)
21
+
22
+
23
+ def predict(image):
24
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
25
+ preprocessed_img = preprocess(image)
26
+ outputs = model(preprocessed_img)
27
+ _, predicted = torch.max(outputs, 1)
28
+ return predicted.item()
29
+
30
+ # Create a Gradio interface
31
+ image = gr.inputs.Image()
32
+ label = gr.outputs.Label(num_top_classes=1)
33
+
34
+ gr.Interface(predict, image, label, capture_session=True).launch()
models/mobilenetv3_large_100_224.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b8037bebe37c941a2cc49bb9590a2a5b1cb66a2650752db66c0d430dd3ef8fd
3
+ size 17822949