# FashionMNIST / app.py
# (Hugging Face page header residue: author rufimelo, commit 3243534 "sup")
import os
import random
import cv2
import gradio as gr
import joblib
import numpy as np
import torch
from models.cnn import Classifier
from models.feed_forward import FeedForwardClassifier
def _load_torch_model(model, path):
    """Load the state dict stored at *path* into *model* and return it ready for inference.

    ``map_location="cpu"`` lets a checkpoint that was saved from a GPU load on
    a CPU-only host. ``eval()`` switches layers such as dropout/batch-norm (if
    the architecture contains any) to inference behavior and returns the
    module itself.
    """
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model.eval()


# PyTorch models (loaded once at import time; requests only run inference).
CNN_PATH = "models/classifier_cnn.pth"
simple_classifier = _load_torch_model(Classifier(), CNN_PATH)
FF_PATH = "models/classifier.pth"
# 784 == 28 * 28; presumably the flattened input size -- confirm in models/feed_forward.py.
feed_forward_classifier = _load_torch_model(FeedForwardClassifier(784), FF_PATH)

# Classical models pickled with joblib. joblib.load unpickles arbitrary
# objects, so these paths must only ever point at trusted files. Each
# estimator class is imported explicitly (it is otherwise unused) to make the
# library dependency of the pickle below it visible here.
from sklearn.ensemble import RandomForestClassifier  # noqa: E402, F401

RF_PATH = "models/fashion_mnist_rf_model.pkl"
rf_clf = joblib.load(RF_PATH)

from sklearn.svm import SVC  # noqa: E402, F401

SVM_PATH = "models/fashion_mnist_svm_model.pkl"
svm_clf = joblib.load(SVM_PATH)
SVM_PATH_RBF = "models/fashion_mnist_svm_model_rbf.pkl"
svm_clf_rbf = joblib.load(SVM_PATH_RBF)

from sklearn.linear_model import LogisticRegression  # noqa: E402, F401

LR_PATH = "models/fashion_mnist_lr_model.pkl"
lr_clf = joblib.load(LR_PATH)

from sklearn.neighbors import KNeighborsClassifier  # noqa: E402, F401

KNN_PATH = "models/fashion_mnist_knn_model.pkl"
knn_clf = joblib.load(KNN_PATH)

from xgboost import XGBClassifier  # noqa: E402, F401

XGB_PATH = "models/fashion_mnist_xgb_model.pkl"
xgb_clf = joblib.load(XGB_PATH)
# Human-readable Fashion-MNIST class names. Position i holds the name of
# class id i; `classify` turns every model's integer prediction into text
# by indexing this list.
LABELS = [
    "T-shirt/top",  # 0
    "Trouser",  # 1
    "Pullover",  # 2
    "Dress",  # 3
    "Coat",  # 4
    "Sandal",  # 5
    "Shirt",  # 6
    "Sneaker",  # 7
    "Bag",  # 8
    "Ankle boot",  # 9
]
def _preprocess(path):
    """Read *path* as grayscale and return a (1, 1, 28, 28) float32 tensor in [0, 1].

    Raises:
        gr.Error: if no path was given (Gradio passes None when the input is
            empty) or OpenCV cannot decode the file.
    """
    if path is None:
        raise gr.Error("Please provide an image to classify.")
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        # cv2.imread reports unreadable/unsupported files by returning None
        # instead of raising; without this check cv2.resize fails cryptically.
        raise gr.Error(f"Could not read image file: {path}")
    gray = cv2.resize(gray, (28, 28))
    # IMREAD_GRAYSCALE yields 8-bit pixels, so dividing by 255 maps to [0, 1].
    scaled = gray / 255.0
    return torch.from_numpy(scaled.reshape(-1, 1, 28, 28)).float()


def _torch_prediction(model, batch):
    """Run *model* on a single-image *batch*; return (class index, confidence in %)."""
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)
    # The batch holds one image, so the flat argmax/max are that image's.
    return int(torch.argmax(probabilities).item()), torch.max(probabilities).item() * 100


def _format_line(model_name, label, confidence=None):
    """Format one aligned output row; the confidence suffix is added only when given."""
    line = f"{model_name:<35} {label:<15}"
    if confidence is not None:
        line += f" with {confidence:.2f}% confidence"
    return line


def classify(img: str) -> str:
    """Classify one image with every loaded model and report the results.

    Args:
        img: Path to the image file (Gradio ``type="filepath"``); may be None
            when the user submits without an image.

    Returns:
        One line per model, joined with newlines. The PyTorch models and
        XGBoost include a confidence percentage; the other estimators report
        only the predicted label.

    Raises:
        gr.Error: if no image was supplied or the file cannot be read.
    """
    batch = _preprocess(img)
    # The classical estimators take one flat row of 28 * 28 = 784 features.
    # Convert explicitly rather than handing them a torch tensor.
    features = batch.numpy().reshape(1, -1)

    lines = []
    for name, model in (
        ("CNN:", simple_classifier),
        ("Feed Forward:", feed_forward_classifier),
    ):
        index, confidence = _torch_prediction(model, batch)
        lines.append(_format_line(name, LABELS[index], confidence))

    xgb_index = xgb_clf.predict(features)[0]
    xgb_confidence = float(np.max(xgb_clf.predict_proba(features)[0])) * 100
    lines.append(_format_line("XGBoost:", LABELS[xgb_index], xgb_confidence))

    for name, clf in (
        ("Random Forest:", rf_clf),
        ("SVM with Linear kernel:", svm_clf),
        ("SVM with RBF kernel:", svm_clf_rbf),
        ("Logistic Regression:", lr_clf),
        ("KNN:", knn_clf),
    ):
        lines.append(_format_line(name, LABELS[clf.predict(features)[0]]))

    return "\n".join(lines)
# Example images shown under the input widget.
folder = "./images"
# Extensions accepted as examples; anything else in the folder (e.g. .DS_Store,
# .gitkeep, subdirectories) would otherwise become a broken example entry.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff")

examples = []
# A missing folder just means "no examples" instead of crashing the app at startup.
if os.path.isdir(folder):
    examples = [
        [os.path.join(folder, filename)]
        # Sorted first: os.listdir order is filesystem-dependent, so this keeps
        # the shuffle reproducible if the random module is ever seeded.
        for filename in sorted(os.listdir(folder))
        if filename.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, filename))
    ]
random.shuffle(examples)
# Gradio UI: a single image (passed to `classify` as a file path) in, one
# text box with every model's prediction out.
iface = gr.Interface(
    fn=classify,
    title="Fashion MNIST Classifier - TAECAC @ FEUP",
    description="Simple Proof of Concept.",
    inputs=gr.Image(label="Image", type="filepath"),
    outputs=gr.Textbox(label="Classification output"),
    examples=examples,
    examples_per_page=100,
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.indigo,
        secondary_hue=gr.themes.colors.gray,
        neutral_hue=gr.themes.colors.slate,
        # Passing a font list replaces the theme's default stack. Keep
        # sans-serif fallbacks so hosts without Avenir don't fall back to the
        # browser's default (often serif) font.
        font=["avenir", "ui-sans-serif", "system-ui", "sans-serif"],
    ),
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()