Spaces:
Sleeping
Sleeping
import os | |
import re | |
from datetime import datetime | |
from PIL import Image | |
import numpy as np | |
import json | |
from torchvision import models | |
import torch.nn.functional as F | |
from torch import nn | |
import torch | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
from huggingface_hub import HfApi | |
import gradio as gr | |
# Access token for the Hugging Face Hub; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def load_checkpoint(filepath):
    """Rebuild the fine-tuned MaxViT-T model from a saved checkpoint.

    The checkpoint is read as a dict with the keys ``"classifier"`` (the
    module that replaces ``model.classifier[5]``), ``"state_dict"`` and
    ``"class_to_idx"``.

    Arguments
    ---------
    filepath: string, path of the saved PyTorch checkpoint

    Returns
    -------
    The torchvision MaxViT-T model with the checkpoint's classifier layer,
    weights and ``class_to_idx`` mapping attached. Tensors are mapped to CPU.
    """
    # Build the architecture from the ImageNet-pretrained weights.
    model = models.maxvit_t(weights="IMAGENET1K_V1")

    # The checkpoint contains a whole pickled module ("classifier"), not only
    # tensors, so it cannot be read with weights_only=True, which is the
    # default from PyTorch 2.6 on. Request full unpickling explicitly.
    # SECURITY: unpickling can execute arbitrary code; only load checkpoint
    # files that come from a trusted source.
    checkpoint = torch.load(
        filepath,
        map_location=torch.device("cpu"),
        weights_only=False,
    )

    # Swap the pretrained output layer (index 5 of the classifier) for the
    # one stored in the checkpoint, then restore the trained weights.
    model.classifier[5] = checkpoint["classifier"]
    model.load_state_dict(checkpoint["state_dict"])

    # Class label -> output index mapping, used to decode predictions.
    model.class_to_idx = checkpoint["class_to_idx"]
    return model
class Network(nn.Module):
    """Feedforward classifier with an arbitrary stack of hidden layers."""

    def __init__(self, input_size, hidden_layers, output_size=102, drop_p=0.2):
        """Builds a feedforward network with arbitrary hidden layers.

        Arguments
        ---------
        input_size: integer, size of the input layer
        hidden_layers: list of integers, the sizes of the hidden layers
            (must contain at least one entry)
        output_size: integer, size of the output layer
        drop_p: float, dropout probability
        """
        super().__init__()
        self.drop_p = drop_p

        # Consecutive (in, out) sizes: input -> hidden[0] -> hidden[1] -> ...
        layer_dims = [input_size, *hidden_layers]
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])
            ]
        )
        self.output = nn.Linear(hidden_layers[-1], output_size)

        print(
            f"\nNumber of layers: {len(self.hidden_layers)}"
            f"\nNumber of units in layers:{hidden_layers}"
        )

    def forward(self, x):
        """Forward pass through the network, returns log-probabilities.

        Dropout is applied only while the module is in training mode.
        """
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
            # F.dropout defaults to training=True, which would keep dropping
            # units after model.eval(); tie it to the module's mode instead.
            # The functional form is kept (rather than an nn.Dropout
            # attribute) so instances restored by unpickling, which skip
            # __init__, need no new attribute.
            x = F.dropout(x, p=self.drop_p, training=self.training)
        return F.log_softmax(self.output(x), dim=1)
# Restore the trained model once at import time; predict() takes it as the
# default value of its `model` parameter.
model = load_checkpoint(filepath="flower_inference_model.pth")
def _archive_image(pil_image):
    """Save the image under ``inference/`` and upload it to the dataset repo."""
    os.makedirs("inference", exist_ok=True)
    # File name is the current timestamp with every run of non-word
    # characters collapsed to "_".
    stamp = re.sub(r"\W+", "_", str(datetime.now()))
    local_path = os.path.join("inference", stamp + ".jpg")
    # Keep the EXIF data of the uploaded image in the saved JPEG.
    pil_image.save(local_path, quality=95, exif=pil_image.getexif())
    # NOTE(review): the path in the repo carries no ".jpg" extension while the
    # local file does; confirm whether that is intended before changing it.
    HfApi().upload_file(
        path_or_fileobj=local_path,
        path_in_repo=stamp,
        repo_id="DanielPFlorian/flower-image-classifier",
        repo_type="dataset",
        token=HF_TOKEN,
    )


def process_image(pil_image):
    """Scales, crops, and normalizes a PIL image for a PyTorch model.

    As a side effect the unmodified image is saved to ``inference/`` and
    uploaded to the Hugging Face dataset repo before preprocessing.

    Arguments
    ---------
    pil_image: PIL image to be processed (left unmodified)

    Returns
    -------
    Numpy float array of shape (3, 224, 224), channel-normalized with the
    mean / standard deviation constants below.
    """
    _archive_image(pil_image)

    # convert() returns a new image, so the caller's image is not resized in
    # place, and RGBA / grayscale / palette inputs become 3-channel RGB as the
    # normalization below requires.
    working = pil_image.convert("RGB")

    # Shrink so the shortest side is 256 while preserving the aspect ratio.
    # thumbnail() never enlarges, so smaller images keep their size.
    width, height = working.size
    if width > height:
        bounds = (10000, 256)
    elif height > width:
        bounds = (256, 10000)
    else:
        bounds = (256, 256)
    working.thumbnail(bounds)

    # Crop the central 224x224 region.
    width, height = working.size
    left = (width - 224) // 2
    top = (height - 224) // 2
    cropped = working.crop((left, top, left + 224, top + 224))

    # Scale color channels to [0, 1], then normalize each channel.
    np_image = np.array(cropped) / 255
    mean = np.array([0.485, 0.456, 0.406])  # Mean
    std = np.array([0.229, 0.224, 0.225])  # Standard deviation
    np_image = (np_image - mean) / std

    # Move color channels to the first dimension: (H, W, C) -> (C, H, W).
    return np_image.transpose((2, 0, 1))
# Category to name mapping: class label (string key) -> display name.
# JSON text is UTF-8, so do not depend on the platform's default encoding.
with open("cat_to_name.json", "r", encoding="utf-8") as f:
    cat_to_name = json.load(f)
def predict(pil_image, model=model, category_names=cat_to_name, topk=5):
    """Predict the top classes of an image and plot them next to the image.

    Arguments
    ---------
    pil_image: PIL image to classify
    model: model to be used for prediction; must expose ``class_to_idx``
    category_names: optional mapping from class label (as string) to a
        display name; when empty or None the raw class labels are shown
    topk: number of top predicted classes to return

    Returns
    -------
    Matplotlib figure with the image on the left and a horizontal bar chart
    of the top-k class percentages on the right.
    """
    # Local import: a standalone Figure avoids pyplot's global figure
    # registry. With plt.subplots() every call would leave a figure open, and
    # plt.tight_layout() / plt.show() act on whichever figure is "current",
    # which is unsafe when several requests are served by one process.
    from matplotlib.figure import Figure

    # Preprocess and convert to a float tensor with a batch size of 1.
    batch = torch.as_tensor(process_image(pil_image))
    batch = batch.view((1, 3, 224, 224)).float()

    # Evaluation mode, gradients off: inference only.
    model.eval()
    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks run.
        # The output is treated as log-probabilities of the classes.
        log_ps = model(batch)
        ps = torch.exp(log_ps)
        top_probs, idx = ps.topk(topk, dim=1)

    # Convert tensors to lists; [0] drops the batch dimension.
    top_probs, idx = top_probs.tolist()[0], idx.tolist()[0]
    percentages = [round(prob * 100.00, 2) for prob in top_probs]

    # Invert class_label -> index so predicted indexes map back to labels.
    idx_to_class = {val: key for key, val in model.class_to_idx.items()}
    top_labels = [idx_to_class[i] for i in idx]
    if category_names:
        top_labels = [category_names[str(lab)] for lab in top_labels]

    # Plot: input image on the left, class percentages on the right.
    positions = np.arange(len(top_labels))
    fig = Figure()
    ax1, ax2 = fig.subplots(ncols=2)
    ax1.imshow(pil_image)
    ax1.axis("off")
    ax2.barh(positions, percentages)
    # Make the bar chart's drawing area square.
    asp = np.diff(ax2.get_xlim())[0] / np.diff(ax2.get_ylim())[0]
    ax2.set_aspect(asp)
    ax2.set_yticks(positions)
    ax2.set_yticklabels(top_labels)
    ax2.invert_yaxis()  # highest probability at the top
    ax2.xaxis.set_major_formatter(ticker.PercentFormatter())
    fig.tight_layout()
    ax2.set_title("Class Probability")
    return fig
# Gradio interface: one uploaded image in, the prediction plot out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a flower image"),
    outputs=gr.Plot(label="Plot"),
    title="What kind of flower is this?",
)
demo.launch()