ilyi commited on
Commit
da43dfa
·
1 Parent(s): 02e992d

Refine Gradio App.

Browse files
app.py CHANGED
@@ -1,34 +1,34 @@
1
  import gradio as gr
2
- import torch
3
- from torchvision import transforms
4
- from torchvision import models
5
- from PIL import Image
6
- import requests
7
-
8
- # Load a pre-trained model
9
- model = models.resnet50(pretrained=True)
10
- model.eval()
11
-
12
- # Preprocess the input image
13
- def preprocess(image):
14
- transform = transforms.Compose([
15
- transforms.Resize(256),
16
- transforms.CenterCrop(224),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
- ])
20
- return transform(image).unsqueeze(0)
21
-
22
-
23
- def predict(image):
24
- image = Image.fromarray(image.astype('uint8'), 'RGB')
25
- preprocessed_img = preprocess(image)
26
- outputs = model(preprocessed_img)
27
- _, predicted = torch.max(outputs, 1)
28
- return predicted.item()
29
 
30
  # Create a Gradio interface
31
  image = gr.inputs.Image()
32
  label = gr.outputs.Label(num_top_classes=1)
33
 
34
- gr.Interface(predict, image, label, capture_session=True).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from src.config import Config
3
+ from src.predict import predict
4
+ from gradio.components import Image, Label, Button
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Create a Gradio interface
7
  image = gr.inputs.Image()
8
  label = gr.outputs.Label(num_top_classes=1)
9
 
10
+ with gr.Blocks(theme='gradio/monochrome') as demo:
11
+ with gr.Row():
12
+ with open('src/assets/header.md', 'r', encoding='utf-8') as f:
13
+ header = gr.Markdown(f.read())
14
+
15
+ with gr.Row(align="center", align_items="center"):
16
+ input_image = Image(
17
+ label="Image",
18
+ interactive=True
19
+
20
+ )
21
+ output = Label(
22
+ num_top_classes=5,
23
+ label="Prediction"
24
+ )
25
+
26
+
27
+ with gr.Row():
28
+ predict_button = Button(value="Predict Animal", label="Predict", info="Click to make a prediction.")
29
+ predict_button.click(fn=predict, inputs=input_image, outputs=output)
30
+
31
+
32
+
33
+
34
+ demo.launch()
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .config import Config
2
+ from .utils import *
3
+ from .predict import predict
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (322 Bytes). View file
 
src/__pycache__/config.cpython-311.pyc ADDED
Binary file (2.27 kB). View file
 
src/__pycache__/predict.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (447 Bytes). View file
 
src/assets/header.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # 🦁 Wildlife Classification 🐘
2
+
3
+ This simple application uses a pre-trained model to classify images of wildlife animals. This project is driven by the ambition to protect wildlife and to raise awareness of animal mistreatment.
4
+ ### > *Species-appropriate is only the freedom.* (R. M. Lehmann)
5
+
6
+ ---
7
+ Further information can be found in the github repository [here](https://github.com/ilyii/wildlife).
8
+
src/config.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+
4
+ from src.models import MobileNetV3Large
5
+
6
+
7
+ class Config:
8
+ # - - - DATA - - - #
9
+ classes = [
10
+ 'african buffalo', 'african elephant', 'alligator', 'alligators', 'alpine ibex', 'anaconda', 'antelope', 'badger', 'bat', 'bear', 'beavers', 'bee', 'beetle', 'bighorn sheep', 'bison', 'black bear', 'blue whale', 'boar', 'bottlenose dolphin', 'bowhead whale', 'butterfly', 'capybara', 'cat', 'caterpillar', 'catfish', 'chamois', 'cheetah', 'chicken', 'chimpanzee', 'clownfish', 'cockroach', 'common dolphin', 'cow', 'coyote', 'crab', 'crayfish', 'crocodile', 'crocodiles', 'crow', 'deer', 'dog', 'dolphin', 'donkey', 'dragonfly', 'duck', 'eagle', 'elephant', 'elk', 'false killer whale', 'flamingo', 'fly', 'fox', 'frogs', 'gazelle', 'geese', 'giraffe', 'goat', 'golden eagle', 'goldfish', 'goose', 'gorilla', 'grasshopper', 'gray whale', 'hamster', 'hare', 'hedgehog', 'himalayan tahr', 'hippopotamus', 'hornbill', 'horse', 'hummingbird', 'humpback dolphin', 'humpback whale', 'hyena', 'irrawaddy dolphin', 'jaguar', 'jellyfish', 'kangaroo', 'koala', 'ladybugs', 'leopard', 'lion', 'lizard', 'lobster', 'lynx', 'macaw', 'marmot', 'minke whale', 'monkey', 'moose', 'mosquito', 'moth', 'mountain goat', 'mouse', 'octopus', 'okapi', 'orangutan', 'orca', 'ostrich', 'otter', 'owl', 'ox', 'oyster', 'panda', 'panther', 'parrot', 'pelecaniformes', 'penguin', 'pig', 'pigeon', 'pika', 'poison dart frog', 'porcupine', 'possum', 'prairie dog', 'rabbit', 'raccoon', 'rat', 'red panda', 'reindeer', 'rhino', 'rhinoceros', 'right whale', 'rocky mountain elk', 'salmon', 'sandpiper', 'seahorse', 'seal', 'sei whale', 'shark', 'sheep', 'sloth', 'snails', 'snake', 'snow leopard', 'sparrow', 'sperm whale', 'spotted hyena', 'squid', 'squirrel', 'starfish', 'swan', 'tapir', 'tiger', 'toads', 'toucan', 'turkey', 'turtle', 'warthog', 'water snakes', 'whale', 'wildebeest', 'wolf', 'wolverine', 'wombat', 'woodpecker', 'zebra'
11
+ ]
12
+
13
+ num_classes = len(classes)
14
+
15
+ # - - - PARAMS - - - #
16
+ modelpath = './src/models/mobilenetv3_large_100_224.pt'
17
+ imgsize = 224
18
+ model = MobileNetV3Large(modelpath, num_classes)
19
+
20
+
21
+
src/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .mobilenet import *
src/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (234 Bytes). View file
 
src/models/__pycache__/mobilenet.cpython-311.pyc ADDED
Binary file (1.75 kB). View file
 
src/models/mobilenet.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+
5
+
6
+ class MobileNetV3Large(nn.Module):
7
+
8
+ def __init__(self, ckpt, num_classes) -> None:
9
+ super().__init__()
10
+ self.model = torchvision.models.mobilenet_v3_large(pretrained=True)
11
+ self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)
12
+ self.model.load_state_dict(torch.load(ckpt))
13
+ self.model.to('cpu')
14
+ self.model.eval()
15
+
16
+ def forward(self, x):
17
+ return self.model(x)
{models → src/models}/mobilenetv3_large_100_224.pt RENAMED
File without changes
src/predict.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from torchvision import transforms
4
+ from torchvision import models
5
+ from PIL import Image
6
+
7
+ from src.config import Config
8
+ from src.utils import id2label
9
+
10
+
11
+ def predict(image):
12
+ if image is None:
13
+ raise gr.Error("No image found. Please upload an image to predict.")
14
+
15
+ transform = transforms.Compose([
16
+ transforms.Resize(int(Config.imgsize * 1.143)),
17
+ transforms.CenterCrop(Config.imgsize),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
20
+ ])
21
+ model = Config.model
22
+ try:
23
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
24
+ except:
25
+ raise gr.Error("Image could not be converted to RGB. Please try another image.")
26
+ preprocessed_img = transform(image).unsqueeze(0)
27
+ outputs = model(preprocessed_img)
28
+ outputs = torch.softmax(outputs, dim=1)
29
+ preds = torch.topk(outputs, 5)
30
+ labels = {id2label(preds.indices[0][i].item()).capitalize(): preds.values[0][i].item() for i in range(5)}
31
+ return labels
32
+
src/utils.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from src.config import Config
2
+
3
+ def id2label(id:int):
4
+ return Config.classes[id]