|
import os |
|
import argparse |
|
|
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import tensorflow as tf |
|
from keras.engine.training import Model |
|
|
|
from utils.glob import TARGET_IMG_SIZE |
|
from utils.glob import CLASS_LABELS |
|
import utils.data_manip as manip |
|
|
|
|
|
def classify(image_path: str, classifier_path: str, verbose: bool = False, return_original: bool = True) -> tuple:
    """
    Classify a single image file with a saved Keras model.

    The image is opened with PIL and passed, in order, through
    ``manip.remove_transparency``, ``manip.resize_crop`` (called with
    ``TARGET_IMG_SIZE`` for both size arguments) and ``manip.normalize_pixels``.
    A leading axis is then added with ``tf.expand_dims`` before the result is
    handed to ``model.predict``. The model is loaded from disk on every call.

    :param image_path: Path to the image to be classified.
    :param classifier_path: Path to the saved model, passed to ``tf.keras.models.load_model``.
    :param verbose: If true, ``model.predict`` is called with ``verbose=1``; otherwise with ``verbose=0``.
    :param return_original: If true, return the image object produced by ``Image.open``;
        otherwise return the preprocessed model input, i.e. the output of
        ``tf.expand_dims`` (it carries the extra leading axis and is not a PIL image).
    :return: A 2-tuple ``(image, label)``, where ``label`` is the ``CLASS_LABELS`` entry
        at the arg-max index (along axis 1) of the prediction for the first sample.
    """

    source_image = Image.open(image_path)

    # Preprocessing pipeline: each stage consumes the previous stage's output.
    opaque = manip.remove_transparency(source_image)
    sized = manip.resize_crop(opaque, TARGET_IMG_SIZE, TARGET_IMG_SIZE)
    normalized = manip.normalize_pixels(sized)
    batch = tf.expand_dims(normalized, axis=0)  # predict() is given a batch of one

    classifier: Model = tf.keras.models.load_model(classifier_path)
    keras_verbosity = 1 if verbose else 0
    scores = classifier.predict(batch, verbose=keras_verbosity)

    # Only one sample was submitted, so only the first arg-max entry is used.
    best_per_sample = tf.argmax(scores, axis=1).numpy()
    label = CLASS_LABELS[best_per_sample[0]]

    returned_image = source_image if return_original else batch
    return returned_image, label
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: classify one image and report the predicted
    # label either in a matplotlib window (--gui) or on stdout.
    ap = argparse.ArgumentParser()
    ap.add_argument('-f', '--file', required=True, help='the image to be classified')
    ap.add_argument('-c', '--classifier', default='models/clf-cnn', help='the machine learning model used for classification, defaults: models/clf-cnn')
    ap.add_argument('-g', '--gui', action='store_true', help='show classification result using GUI')
    ap.add_argument('-v', '--verbose-level', choices=['0', '1', '2'], default='0', help="verbose level, default: 0")
    args = vars(ap.parse_args())
    verbose_level = int(args['verbose_level'])

    img = os.path.abspath(args['file'])
    clf = os.path.abspath(args['classifier'])

    # Verbose levels: 0 prints the bare label, 1 prints a full sentence,
    # 2 additionally turns on verbose output inside classify().
    image, predicted_label = classify(img, clf, verbose_level >= 2)

    if args['gui']:
        # os.path.relpath raises ValueError on Windows when the image lives on
        # a different drive than the current working directory; fall back to
        # the absolute path so the result window is still shown.
        try:
            display_path = os.path.relpath(img)
        except ValueError:
            display_path = img

        fig, ax = plt.subplots(1, 1, num='Flower Image Classifier')
        ax.imshow(image)
        ax.set_title(
            f'{predicted_label}',
            fontsize=12,
            weight='bold'
        )
        # Caption with the image path, placed just below the axes
        # (coordinates are in axes units because of transform=ax.transAxes).
        ax.text(
            0.5, -0.08, f'{display_path}',
            horizontalalignment='center',
            verticalalignment='center_baseline',
            transform=ax.transAxes,
            fontsize=8,
        )
        ax.axis('off')
        plt.show()
    else:
        if verbose_level == 0:
            print(predicted_label)
        else:
            print(
                f'Image {os.path.basename(img)} is classified as "{predicted_label}" (model: "{os.path.basename(clf)}")'
            )
|
|