File size: 2,990 Bytes
af1bda1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import os
import argparse
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from keras.engine.training import Model
from utils.glob import TARGET_IMG_SIZE
from utils.glob import CLASS_LABELS
import utils.data_manip as manip
def classify(image_path: str, classifier_path: str, verbose: bool = False, return_original: bool = True) -> tuple:
    """
    Uses a trained machine learning model to classify an image loaded from disk.

    The image is passed through ``manip.remove_transparency``,
    ``manip.resize_crop`` (to ``TARGET_IMG_SIZE`` x ``TARGET_IMG_SIZE``) and
    ``manip.normalize_pixels``, given a leading batch axis of size 1, and fed
    to the Keras model loaded from ``classifier_path``. The predicted class is
    the arg-max of the model output along axis 1.

    :param image_path: Path to the image to be classified.
    :param classifier_path: Path to the classifier model to be used (passed to
        ``tf.keras.models.load_model``; the model is loaded on every call).
    :param verbose: Verbose output; forwarded to ``Model.predict`` as
        ``verbose=1`` when True and ``verbose=0`` when False.
    :param return_original: Whether to return the original image or the processed image.
    :return: A ``(image, label)`` tuple. When ``return_original`` is True,
        ``image`` is the original ``PIL.Image.Image``. Otherwise it is the
        preprocessed model input *including* the batch axis added by
        ``tf.expand_dims`` (first dimension of size 1) — not a PIL image.
        ``label`` is the entry of ``CLASS_LABELS`` at the predicted index.
    """
    im_original = Image.open(image_path)
    # Image.open() is lazy and keeps the file open. Decoding now means the
    # handle is released (Pillow closes it after load() for single-frame
    # images) regardless of what the manip helpers below do with the image.
    im_original.load()

    im_processed = manip.remove_transparency(im_original)
    im_processed = manip.resize_crop(im_processed, TARGET_IMG_SIZE, TARGET_IMG_SIZE)
    im_processed = manip.normalize_pixels(im_processed)
    # Keras models predict on batches: wrap the single image in a batch of one.
    im_processed = tf.expand_dims(im_processed, axis=0)

    model: Model = tf.keras.models.load_model(classifier_path)
    pred = model.predict(im_processed, verbose=1 if verbose else 0)
    # One output row per batch item; take the best-scoring class of row 0.
    pred_class_idx = int(tf.argmax(pred, axis=1).numpy()[0])
    pred_class_label = CLASS_LABELS[pred_class_idx]

    if return_original:
        return im_original, pred_class_label
    return im_processed, pred_class_label
def main() -> None:
    """Command-line entry point: classify one image and report the result.

    With ``--gui`` the image is shown in a matplotlib window titled with the
    predicted label and captioned with the image path. Otherwise the label is
    printed: bare at verbose level 0, or as a descriptive sentence naming the
    image and model at levels 1 and 2. Only level 2 asks ``classify`` for
    verbose output.

    Exits with a usage error (via ``argparse``) if the image file or the
    classifier path does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('-f', '--file', required=True, help='the image to be classified')
    ap.add_argument(
        '-c', '--classifier', default='models/clf-cnn',
        help='the machine learning model used for classification, default: %(default)s',
    )
    ap.add_argument('-g', '--gui', action='store_true', help='show classification result using GUI')
    ap.add_argument('-v', '--verbose-level', choices=['0', '1', '2'], default='0', help="verbose level, default: 0")
    args = vars(ap.parse_args())

    verbose_level = int(args['verbose_level'])
    img = os.path.abspath(args['file'])
    clf = os.path.abspath(args['classifier'])

    # Fail fast with a clear usage message instead of a deep PIL/Keras traceback.
    if not os.path.isfile(img):
        ap.error(f"image file not found: {args['file']}")
    if not os.path.exists(clf):
        ap.error(f"classifier not found: {args['classifier']}")

    # Only the highest verbose level is forwarded to the model's predict call.
    image, predicted_label = classify(img, clf, verbose_level >= 2)

    if args['gui']:
        _, ax = plt.subplots(1, 1, num='Flower Image Classifier')
        ax.imshow(image)
        ax.set_title(predicted_label, fontsize=12, weight='bold')
        # Caption with the image path, centred just below the axes
        # (axes coordinates: x=0.5 is the middle, y<0 is below the plot).
        ax.text(
            0.5, -0.08, os.path.relpath(img),
            horizontalalignment='center',
            verticalalignment='center_baseline',
            transform=ax.transAxes,
            fontsize=8,
        )
        ax.axis('off')
        plt.show()
    elif verbose_level == 0:
        print(predicted_label)
    else:
        print(
            f'Image {os.path.basename(img)} is classified as "{predicted_label}" (model: "{os.path.basename(clf)}")'
        )


if __name__ == '__main__':
    main()
|