Refine Gradio App.
ilyi committed · Commit da43dfa · 1 Parent(s): 02e992d
Files changed:
- app.py +28 -28
- src/__init__.py +3 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/config.cpython-311.pyc +0 -0
- src/__pycache__/predict.cpython-311.pyc +0 -0
- src/__pycache__/utils.cpython-311.pyc +0 -0
- src/assets/header.md +8 -0
- src/config.py +21 -0
- src/models/__init__.py +1 -0
- src/models/__pycache__/__init__.cpython-311.pyc +0 -0
- src/models/__pycache__/mobilenet.cpython-311.pyc +0 -0
- src/models/mobilenet.py +17 -0
- {models → src/models}/mobilenetv3_large_100_224.pt +0 -0
- src/predict.py +32 -0
- src/utils.py +4 -0
app.py
CHANGED
@@ -1,34 +1,34 @@
 import gradio as gr
-import torch
-from torchvision import transforms
-from torchvision import models
-from PIL import Image
-import requests
-
-# Load a pre-trained model
-model = models.resnet50(pretrained=True)
-model.eval()
-
-# Preprocess the input image
-def preprocess(image):
-    transform = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-    ])
-    return transform(image).unsqueeze(0)
-
-
-def predict(image):
-    image = Image.fromarray(image.astype('uint8'), 'RGB')
-    preprocessed_img = preprocess(image)
-    outputs = model(preprocessed_img)
-    _, predicted = torch.max(outputs, 1)
-    return predicted.item()
+from src.config import Config
+from src.predict import predict
+from gradio.components import Image, Label, Button
 
 # Create a Gradio interface
 image = gr.inputs.Image()
 label = gr.outputs.Label(num_top_classes=1)
 
-gr.Interface(fn=predict, inputs=image, outputs=label).launch()
+with gr.Blocks(theme='gradio/monochrome') as demo:
+    with gr.Row():
+        with open('src/assets/header.md', 'r', encoding='utf-8') as f:
+            header = gr.Markdown(f.read())
+
+    with gr.Row(align="center", align_items="center"):
+        input_image = Image(
+            label="Image",
+            interactive=True
+
+        )
+        output = Label(
+            num_top_classes=5,
+            label="Prediction"
+        )
+
+
+    with gr.Row():
+        predict_button = Button(value="Predict Animal", label="Predict", info="Click to make a prediction.")
+        predict_button.click(fn=predict, inputs=input_image, outputs=output)
+
+
+
+
+demo.launch()
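Worth flagging in the new app.py: the module-level `image = gr.inputs.Image()` and `label = gr.outputs.Label(...)` lines are leftovers from the old `gr.Interface` version and are never used by the Blocks layout, and the `align`/`align_items` kwargs on `gr.Row` and the `label`/`info` kwargs on `Button` are not part of the documented Gradio signatures. A minimal sketch of the same layout with those removed, assuming Gradio 3.x/4.x behavior:

# Sketch only: the commit's Blocks layout minus the unused gr.inputs/
# gr.outputs leftovers and the undocumented keyword arguments.
import gradio as gr

from src.predict import predict

with gr.Blocks(theme='gradio/monochrome') as demo:
    with gr.Row():
        with open('src/assets/header.md', 'r', encoding='utf-8') as f:
            gr.Markdown(f.read())

    with gr.Row():
        input_image = gr.Image(label="Image", interactive=True)
        output = gr.Label(num_top_classes=5, label="Prediction")

    with gr.Row():
        predict_button = gr.Button(value="Predict Animal")
        predict_button.click(fn=predict, inputs=input_image, outputs=output)

demo.launch()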
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .config import Config
+from .utils import *
+from .predict import predict
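Since `from .utils import *` re-exports every public name in src/utils.py (including its own `Config` import), an explicit `__all__` there would keep the star-import surface tight. A one-line sketch, not part of this commit:

# Hypothetical addition to src/utils.py: restrict what `import *` exposes.
__all__ = ['id2label']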
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (322 Bytes).

src/__pycache__/config.cpython-311.pyc
ADDED
Binary file (2.27 kB).

src/__pycache__/predict.cpython-311.pyc
ADDED
Binary file (2.6 kB).

src/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (447 Bytes).
src/assets/header.md
ADDED
@@ -0,0 +1,8 @@
+# 🦁 Wildlife Classification 🐘
+
+This simple application uses a pre-trained model to classify images of wildlife animals. This project is driven by the ambition to protect wildlife and to raise awareness of animal mistreatment.
+### > *Only freedom is species-appropriate.* (R. M. Lehmann)
+
+---
+Further information can be found in the GitHub repository [here](https://github.com/ilyii/wildlife).
+
src/config.py
ADDED
@@ -0,0 +1,21 @@
+import torch
+import torchvision
+
+from src.models import MobileNetV3Large
+
+
+class Config:
+    # - - - DATA - - - #
+    classes = [
+        'african buffalo', 'african elephant', 'alligator', 'alligators', 'alpine ibex', 'anaconda', 'antelope', 'badger', 'bat', 'bear', 'beavers', 'bee', 'beetle', 'bighorn sheep', 'bison', 'black bear', 'blue whale', 'boar', 'bottlenose dolphin', 'bowhead whale', 'butterfly', 'capybara', 'cat', 'caterpillar', 'catfish', 'chamois', 'cheetah', 'chicken', 'chimpanzee', 'clownfish', 'cockroach', 'common dolphin', 'cow', 'coyote', 'crab', 'crayfish', 'crocodile', 'crocodiles', 'crow', 'deer', 'dog', 'dolphin', 'donkey', 'dragonfly', 'duck', 'eagle', 'elephant', 'elk', 'false killer whale', 'flamingo', 'fly', 'fox', 'frogs', 'gazelle', 'geese', 'giraffe', 'goat', 'golden eagle', 'goldfish', 'goose', 'gorilla', 'grasshopper', 'gray whale', 'hamster', 'hare', 'hedgehog', 'himalayan tahr', 'hippopotamus', 'hornbill', 'horse', 'hummingbird', 'humpback dolphin', 'humpback whale', 'hyena', 'irrawaddy dolphin', 'jaguar', 'jellyfish', 'kangaroo', 'koala', 'ladybugs', 'leopard', 'lion', 'lizard', 'lobster', 'lynx', 'macaw', 'marmot', 'minke whale', 'monkey', 'moose', 'mosquito', 'moth', 'mountain goat', 'mouse', 'octopus', 'okapi', 'orangutan', 'orca', 'ostrich', 'otter', 'owl', 'ox', 'oyster', 'panda', 'panther', 'parrot', 'pelecaniformes', 'penguin', 'pig', 'pigeon', 'pika', 'poison dart frog', 'porcupine', 'possum', 'prairie dog', 'rabbit', 'raccoon', 'rat', 'red panda', 'reindeer', 'rhino', 'rhinoceros', 'right whale', 'rocky mountain elk', 'salmon', 'sandpiper', 'seahorse', 'seal', 'sei whale', 'shark', 'sheep', 'sloth', 'snails', 'snake', 'snow leopard', 'sparrow', 'sperm whale', 'spotted hyena', 'squid', 'squirrel', 'starfish', 'swan', 'tapir', 'tiger', 'toads', 'toucan', 'turkey', 'turtle', 'warthog', 'water snakes', 'whale', 'wildebeest', 'wolf', 'wolverine', 'wombat', 'woodpecker', 'zebra'
+    ]
+
+    num_classes = len(classes)
+
+    # - - - PARAMS - - - #
+    modelpath = './src/models/mobilenetv3_large_100_224.pt'
+    imgsize = 224
+    model = MobileNetV3Large(modelpath, num_classes)
+
+
+
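Note that `Config.model` is built in the class body, so the checkpoint is read the moment `src.config` is imported. If deferred loading were preferred, one option is a cached accessor; a sketch with a hypothetical `get_model` helper that is not part of this commit:

# Sketch with a hypothetical get_model helper: defer checkpoint loading
# from import time to first use.
from functools import lru_cache

from src.models import MobileNetV3Large


class Config:
    # Species list shortened here; see src/config.py above for the full one.
    classes = ['african buffalo', 'african elephant', 'alligator']
    num_classes = len(classes)
    modelpath = './src/models/mobilenetv3_large_100_224.pt'
    imgsize = 224

    @staticmethod
    @lru_cache(maxsize=1)
    def get_model():
        # First call builds and caches the network; later calls reuse it.
        return MobileNetV3Large(Config.modelpath, Config.num_classes)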
src/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .mobilenet import *

src/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (234 Bytes).

src/models/__pycache__/mobilenet.cpython-311.pyc
ADDED
Binary file (1.75 kB).
src/models/mobilenet.py
ADDED
@@ -0,0 +1,17 @@
+import torch
+import torchvision
+from torch import nn
+
+
+class MobileNetV3Large(nn.Module):
+
+    def __init__(self, ckpt, num_classes) -> None:
+        super().__init__()
+        self.model = torchvision.models.mobilenet_v3_large(pretrained=True)
+        self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)
+        self.model.load_state_dict(torch.load(ckpt))
+        self.model.to('cpu')
+        self.model.eval()
+
+    def forward(self, x):
+        return self.model(x)

{models → src/models}/mobilenetv3_large_100_224.pt
RENAMED
File without changes
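One portability note on the checkpoint load above: `torch.load(ckpt)` restores tensors onto the device they were saved from, so a checkpoint saved on GPU can fail on a CPU-only Space before `self.model.to('cpu')` ever runs. Passing `map_location` avoids that; a sketch, assuming the file is a plain state dict:

# Sketch: deserialize the checkpoint straight onto CPU so a GPU-saved
# state dict loads on CPU-only hosts.
state_dict = torch.load(ckpt, map_location='cpu')
self.model.load_state_dict(state_dict)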
src/predict.py
ADDED
@@ -0,0 +1,32 @@
+import torch
+import gradio as gr
+from torchvision import transforms
+from torchvision import models
+from PIL import Image
+
+from src.config import Config
+from src.utils import id2label
+
+
+def predict(image):
+    if image is None:
+        raise gr.Error("No image found. Please upload an image to predict.")
+
+    transform = transforms.Compose([
+        transforms.Resize(int(Config.imgsize * 1.143)),
+        transforms.CenterCrop(Config.imgsize),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    model = Config.model
+    try:
+        image = Image.fromarray(image.astype('uint8'), 'RGB')
+    except:
+        raise gr.Error("Image could not be converted to RGB. Please try another image.")
+    preprocessed_img = transform(image).unsqueeze(0)
+    outputs = model(preprocessed_img)
+    outputs = torch.softmax(outputs, dim=1)
+    preds = torch.topk(outputs, 5)
+    labels = {id2label(preds.indices[0][i].item()).capitalize(): preds.values[0][i].item() for i in range(5)}
+    return labels
+
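`predict` runs the forward pass with autograd enabled; wrapping it in `torch.no_grad()` yields the same outputs without building a gradient graph. A minimal sketch of that one scoring step:

# Sketch: same scoring step as in predict(), but without gradient tracking.
with torch.no_grad():
    outputs = model(preprocessed_img)
outputs = torch.softmax(outputs, dim=1)
preds = torch.topk(outputs, 5)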
src/utils.py
ADDED
@@ -0,0 +1,4 @@
+from src.config import Config
+
+def id2label(id:int):
+    return Config.classes[id]
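For reference, a quick hypothetical usage of the helper from a Python shell at the Space root (not part of the commit):

# Map a class index back to its human-readable label.
from src.utils import id2label

print(id2label(0))  # 'african buffalo', the first entry in Config.classes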