# Page header captured from the hosting site's file view (not part of the program):
# Phizzly · "update Gradio input" · b6f2207 · 6.88 kB
import os
import re
from datetime import datetime
from PIL import Image
import numpy as np
import json
from torchvision import models
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from huggingface_hub import HfApi
import gradio as gr
# Access token for the Hugging Face Hub, taken from the environment.
# Evaluates to None when the HF_TOKEN variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Torch device handle: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def load_checkpoint(filepath):
    """Rebuild the fine-tuned MaxViT-T classifier from a saved checkpoint.

    The checkpoint is read as a dict and three of its entries are used:

    * ``"classifier"``: installed as ``model.classifier[5]``
    * ``"state_dict"``: weights loaded into the whole model
    * ``"class_to_idx"``: attached to the model as ``model.class_to_idx``

    Arguments
    ---------
    filepath: string, path of the saved PyTorch checkpoint

    Returns
    -------
    The torchvision ``maxvit_t`` model with the checkpoint's classifier and
    weights loaded. All tensors are mapped to the CPU.
    """
    # Start from the ImageNet-pretrained MaxViT-T architecture.
    # NOTE(review): load_state_dict() below replaces these weights, so
    # weights=None would presumably avoid the download — confirm before changing.
    model = models.maxvit_t(weights="IMAGENET1K_V1")

    # The "classifier" entry is used as a module object, not a plain tensor
    # dict, so the checkpoint has to be fully unpickled. PyTorch 2.6 changed
    # the default of ``weights_only`` to True, which rejects such objects;
    # request full unpickling explicitly so the app works on old and new
    # releases alike.
    # SECURITY: full unpickling can execute arbitrary code. Only load
    # checkpoint files that come from a trusted source.
    checkpoint = torch.load(
        filepath,
        map_location=torch.device("cpu"),
        weights_only=False,
    )

    # Replace the pretrained output layer (classifier[5]) with the saved head,
    # then load the weights for the full model.
    model.classifier[5] = checkpoint["classifier"]
    model.load_state_dict(checkpoint["state_dict"])

    # Class label -> output index mapping, used later to decode predictions.
    model.class_to_idx = checkpoint["class_to_idx"]
    return model
class Network(nn.Module):
    """Feedforward classifier: ReLU hidden layers with dropout, log-softmax output."""

    def __init__(self, input_size, hidden_layers, output_size=102, drop_p=0.2):
        """Builds a feedforward network with arbitrary hidden layers.

        Arguments
        ---------
        input_size: integer, size of the input layer
        hidden_layers: list of integers, the sizes of the hidden layers
            (must contain at least one entry)
        output_size: integer, size of the output layer
        drop_p: float, dropout probability
        """
        super().__init__()
        self.drop_p = drop_p
        # Input to the first hidden layer
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        # Chain the remaining hidden layers: each one maps size h1 -> h2
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
        self.output = nn.Linear(hidden_layers[-1], output_size)
        print(
            f"\nNumber of layers: {len(self.hidden_layers)}"
            f"\nNumber of units in layers:{hidden_layers}"
        )

    def forward(self, x):
        """Forward pass through the network, returns log-probabilities.

        ``x`` is passed through every hidden layer (linear -> ReLU -> dropout),
        then the output layer, then ``log_softmax`` over dimension 1.
        """
        for each in self.hidden_layers:
            x = F.relu(each(x))
            # F.dropout defaults to training=True, which would keep dropping
            # units after model.eval() and make inference random. Tie it to
            # the module's mode so dropout is only active while training.
            x = F.dropout(x, p=self.drop_p, training=self.training)
        x = self.output(x)
        return F.log_softmax(x, dim=1)
# Build the inference model once, at import time, from the checkpoint file.
# The path is relative, so it is resolved against the current working directory.
# NOTE(review): presumably this has to stay below the Network class definition
# so the classifier stored in the checkpoint can be resolved — confirm.
model = load_checkpoint("flower_inference_model.pth")
def process_image(pil_image):
    """Archives a PIL image, then scales, crops, and normalizes it for a
    PyTorch model, returns a Numpy array

    Side effects: creates the ``inference`` directory if needed, writes a
    timestamp-named JPEG copy of the image into it, and uploads that file to
    the ``DanielPFlorian/flower-image-classifier`` dataset repo using
    ``HF_TOKEN``. The caller's image object is left unmodified.

    Arguments
    ---------
    pil_image: PIL image to be processed (any mode; it is converted to RGB)

    Returns
    -------
    Numpy float array of shape (3, 224, 224): the center crop, scaled to
    [0, 1] and normalized per channel with the mean/std listed below.
    """
    # Read the Exif data from the original upload so it can be written into
    # the archived copy.
    img_exif = pil_image.getexif()

    # convert() returns a new image, which gives two guarantees:
    #  * the in-place thumbnail() calls below never touch the caller's image
    #  * grayscale / palette / RGBA inputs become 3-channel, so both the JPEG
    #    save and the per-channel normalization work
    inp = pil_image.convert("RGB")

    # Create inference directory for prediction
    os.makedirs("inference", exist_ok=True)
    # Timestamp-based file name with non-word characters replaced by "_"
    image_path = re.sub(r"\W+", "_", str(datetime.now()))
    # Join to directory path
    inf_image = os.path.join("inference", image_path + ".jpg")

    # Archive the upload locally, then push it to the dataset repo.
    # NOTE(review): ``keep=True`` does not look like a documented Pillow JPEG
    # save option and is presumably ignored — confirm before removing.
    inp.save(inf_image, quality=95, keep=True, exif=img_exif)
    # NOTE(review): the repo path is the timestamp without the ".jpg" suffix —
    # confirm this is intended.
    HfApi().upload_file(
        path_or_fileobj=inf_image,
        path_in_repo=image_path,
        repo_id="DanielPFlorian/flower-image-classifier",
        repo_type="dataset",
        token=HF_TOKEN,
    )

    # Resize image so shortest side is 256 preserving aspect ratio
    w, h = inp.size
    if min(w, h) < 256:
        # thumbnail() only ever shrinks. Without an explicit upscale, a small
        # image would be cropped outside its bounds and padded with black.
        scale = 256 / min(w, h)
        inp = inp.resize((max(256, round(w * scale)), max(256, round(h * scale))))
    elif w > h:
        inp.thumbnail((10000, 256))
    elif h > w:
        inp.thumbnail((256, 10000))
    else:
        inp.thumbnail((256, 256))

    # Crop center 224x224
    w, h = inp.size
    left = (w - 224) // 2
    top = (h - 224) // 2
    right = (w + 224) // 2
    bottom = (h + 224) // 2
    image = inp.crop((left, top, right, bottom))

    # Convert pil image to numpy array and scale color channels to [0, 1]
    np_image = np.array(image) / 255

    # Normalize each color channel
    mean = np.array([0.485, 0.456, 0.406])  # Mean
    std = np.array([0.229, 0.224, 0.225])  # Standard deviation
    np_image = (np_image - mean) / std

    # Move color channels to first dimension: (H, W, C) -> (C, H, W)
    np_image = np_image.transpose((2, 0, 1))
    return np_image
# Category to name mapping, loaded once at import time from the working directory.
# JSON text is UTF-8; state the encoding explicitly so decoding does not
# depend on the platform's locale default.
with open("cat_to_name.json", "r", encoding="utf-8") as f:
    cat_to_name = json.load(f)
def predict(pil_image, model=model, category_names=cat_to_name, topk=5):
    """Predict the class (or classes) of an image using a trained deep learning model.

    Arguments
    ---------
    pil_image: PIL image to classify
    model: model to be used for prediction; its output is treated as
        log-probabilities and it must expose a ``class_to_idx`` mapping
    category_names: mapping from class label (as a string) to display name;
        when empty or None the raw class labels are shown
    topk: number of top predicted classes to return (clamped to the number
        of classes the model outputs)

    Returns
    -------
    A matplotlib Figure with the input image on the left and a horizontal
    bar chart of the top class probabilities (in percent) on the right.
    """
    # Local import: matplotlib is already a dependency of this file; this only
    # brings the Figure class into scope for this function.
    from matplotlib.figure import Figure

    # Process image function
    image = process_image(pil_image)
    # Convert image to float tensor with batch size of 1
    image = torch.as_tensor(image).view((1, 3, 224, 224)).float()

    # Set model to evaluation mode/ inference mode
    model.eval()
    # Turn off gradients to speed up this part
    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks
        # run. Outputs log probabilities of classes.
        log_ps = model(image)
        # Exponential of log probabilities for each class
        ps = torch.exp(log_ps)

    # Never ask for more predictions than there are classes
    topk = min(topk, ps.shape[1])
    # Get top k predictions. Returns probabilities and class indexes
    top_probs, idx = ps.topk(topk, dim=1)
    # Convert tensors to lists. Index[0] selects the single batch entry
    top_probs, idx = top_probs[0].tolist(), idx[0].tolist()
    # Convert top_probs to percentages
    percentages = [round(prob * 100.00, 2) for prob in top_probs]

    # Converts class_labels:indexes to indexes:class_labels
    idx_to_class = {val: key for key, val in model.class_to_idx.items()}
    # Get class labels from indexes
    top_labels = [idx_to_class[lab] for lab in idx]
    # Get names from labels
    if category_names:
        top_labels = [category_names[str(lab)] for lab in top_labels]

    # Plot functionality.
    # Build the Figure directly instead of through pyplot: pyplot keeps a
    # reference to every figure it creates, so a long-running server would
    # accumulate one open figure per request. For the same reason there is no
    # plt.show() call; the caller renders the returned figure.
    fig = Figure()
    ax1, ax2 = fig.subplots(ncols=2)
    ax1.imshow(pil_image)
    ax1.axis("off")

    positions = np.arange(len(top_labels))
    ax2.barh(positions, percentages)
    # Make the bar chart's axes box square: aspect = x-range / y-range
    asp = np.diff(ax2.get_xlim())[0] / np.diff(ax2.get_ylim())[0]
    ax2.set_aspect(asp)
    ax2.set_yticks(positions)
    ax2.set_yticklabels(top_labels)
    # Highest probability at the top
    ax2.invert_yaxis()
    ax2.xaxis.set_major_formatter(ticker.PercentFormatter())
    # Set the title before tight_layout so the layout accounts for it
    ax2.set_title("Class Probability")
    fig.tight_layout()
    return fig
# Gradio interface: one PIL image upload in, one plot out, served by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a flower image"),
    outputs=gr.Plot(label="Plot"),
    title="What kind of flower is this?",
)
demo.launch()