import os
import argparse

import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
# NOTE(review): the module name was missing in the original import statement;
# `tensorflow.keras` matches the `tf.keras.models.load_model` call below -- confirm.
from tensorflow.keras import Model

from utils.glob import TARGET_IMG_SIZE
from utils.glob import CLASS_LABELS
import utils.data_manip as manip


def classify(image_path: str, classifier_path: str, verbose: bool = False, return_original: bool = True) -> tuple:
    """
    Uses a trained machine learning model to classify an image loaded from disk.

    The image is opened with PIL, passed through the project's preprocessing helpers
    (transparency removal, resize/crop to TARGET_IMG_SIZE x TARGET_IMG_SIZE, pixel
    normalization), given a leading batch axis and fed to the model loaded from
    ``classifier_path``. The label is looked up in CLASS_LABELS using the arg-max of the
    prediction along axis 1.

    :param image_path: Path to the image to be classified.
    :param classifier_path: Path to the classifier model to be used.
    :param verbose: Verbose output (forwarded to ``model.predict`` as verbosity 1 or 0).
    :param return_original: Whether to return the original image or the processed image.
    :return: A 2-tuple ``(image, label)``. ``image`` is the image as opened by PIL when
             ``return_original`` is true; otherwise it is the preprocessed model input, i.e. the
             result of ``tf.expand_dims(..., axis=0)`` (with a leading batch axis).
             ``label`` is the matching entry of CLASS_LABELS.
    """
    # The original statement referenced `im_original` before assigning it and never used
    # `image_path`; the image has to be loaded from disk first.
    im_original = Image.open(image_path)

    # Preprocessing pipeline; the helpers live in utils.data_manip.
    im_processed = manip.remove_transparency(im_original)
    im_processed = manip.resize_crop(im_processed, TARGET_IMG_SIZE, TARGET_IMG_SIZE)
    im_processed = manip.normalize_pixels(im_processed)
    # Add a batch axis so a single image can be handed to `predict`.
    im_processed = tf.expand_dims(im_processed, axis=0)

    model: Model = tf.keras.models.load_model(classifier_path)
    pred = model.predict(im_processed, verbose=1 if verbose else 0)

    # Only one image is in the batch, so take the first (and only) arg-max.
    pred_class_idx = tf.argmax(pred, axis=1).numpy()[0]
    pred_class_label = CLASS_LABELS[pred_class_idx]

    if return_original:
        return im_original, pred_class_label
    return im_processed, pred_class_label


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('-f', '--file', required=True,
                    help='the image to be classified')
    ap.add_argument('-c', '--classifier', default='models/clf-cnn',
                    help='the machine learning model used for classification, defaults: models/clf-cnn')
    ap.add_argument('-g', '--gui', action='store_true',
                    help='show classification result using GUI')
    ap.add_argument('-v', '--verbose-level', choices=['0', '1', '2'], default='0',
                    help="verbose level, default: 0")
    args = vars(ap.parse_args())

    verbose_level = int(args['verbose_level'])
    img = os.path.abspath(args['file'])
    clf = os.path.abspath(args['classifier'])

    # Model-level progress output is only enabled at the highest verbosity level.
    image, predicted_label = classify(img, clf, verbose_level >= 2)

    if args['gui']:
        fig, ax = plt.subplots(1, 1, num='Flower Image Classifier')
        ax.imshow(image)
        ax.set_title(
            f'{predicted_label}',
            fontsize=12,
            weight='bold'
        )
        # Caption with the image path, placed just below the axes (axes coordinates).
        ax.text(
            0.5, -0.08,
            f'{os.path.relpath(img)}',
            horizontalalignment='center',
            verticalalignment='center_baseline',
            transform=ax.transAxes,
            fontsize=8,
        )
        ax.axis('off')
        # Without this call the figure is built but never displayed when run as a script.
        plt.show()
    else:
        if verbose_level == 0:
            # Bare label only, convenient for piping into other tools.
            print(predicted_label)
        else:
            print(
                f'Image {os.path.basename(img)} is classified as "{predicted_label}" (model: "{os.path.basename(clf)}")'
            )