File size: 2,990 Bytes
af1bda1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import argparse

import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf
from keras.engine.training import Model

from utils.glob import TARGET_IMG_SIZE
from utils.glob import CLASS_LABELS
import utils.data_manip as manip


def classify(image_path: str, classifier_path: str, verbose: bool = False, return_original: bool = True) -> tuple:
    """
    Uses a trained machine learning model to classify an image loaded from disk.

    The image is preprocessed with utils.data_manip (transparency removal,
    resize/crop to TARGET_IMG_SIZE x TARGET_IMG_SIZE, pixel normalization)
    before prediction. The model is loaded with tf.keras.models.load_model on
    every call.

    :param image_path: Path to the image to be classified.
    :param classifier_path: Path to the classifier model to be used.
    :param verbose: Verbose output; forwarded to model.predict as verbose=1 (True) or verbose=0 (False).
    :param return_original: Whether to return the original image or the processed image.
    :return: An (image, label) tuple. The image is the original PIL.Image when
             return_original is True; otherwise it is the preprocessed image as
             returned by manip.normalize_pixels (without a batch axis). The label
             is the CLASS_LABELS entry of the highest-scoring class.
    """

    # Image.open() is lazy and may keep the file handle open (always for
    # multi-frame images); copy the pixel data out and close the file at once.
    with Image.open(image_path) as im_file:
        im_original = im_file.copy()

    im_processed = manip.remove_transparency(im_original)
    im_processed = manip.resize_crop(im_processed, TARGET_IMG_SIZE, TARGET_IMG_SIZE)
    im_processed = manip.normalize_pixels(im_processed)
    # The model consumes batches: add a leading batch axis for prediction only,
    # so that the returned processed image stays a plain (displayable) image.
    batch = tf.expand_dims(im_processed, axis=0)

    model: Model = tf.keras.models.load_model(classifier_path)
    pred = model.predict(batch, verbose=1 if verbose else 0)

    # One-image batch: arg-max over the class axis of the first (only) row.
    pred_class_idx = int(tf.argmax(pred, axis=1).numpy()[0])
    pred_class_label = CLASS_LABELS[pred_class_idx]

    if return_original:
        return im_original, pred_class_label
    return im_processed, pred_class_label


if __name__ == '__main__':
    # Command-line entry point: classify a single image and report the label
    # either on stdout or in a matplotlib window (--gui).
    ap = argparse.ArgumentParser()
    ap.add_argument('-f', '--file', required=True, help='the image to be classified')
    ap.add_argument('-c', '--classifier', default='models/clf-cnn', help='the machine learning model used for classification, defaults: models/clf-cnn')
    ap.add_argument('-g', '--gui', action='store_true', help='show classification result using GUI')
    # 0: print only the label; 1: print a descriptive sentence (non-GUI mode);
    # 2: additionally enable Keras output during prediction.
    ap.add_argument('-v', '--verbose-level', type=int, choices=[0, 1, 2], default=0, help="verbose level, default: 0")
    args = vars(ap.parse_args())
    verbose_level = args['verbose_level']

    img = os.path.abspath(args['file'])
    clf = os.path.abspath(args['classifier'])
    # Report bad paths as usage errors instead of deep PIL/Keras tracebacks.
    if not os.path.isfile(img):
        ap.error(f'image file not found: {args["file"]}')
    if not os.path.exists(clf):
        ap.error(f'classifier not found: {args["classifier"]}')

    image, predicted_label = classify(img, clf, verbose=verbose_level >= 2)

    if args['gui']:
        try:
            shown_path = os.path.relpath(img)
        except ValueError:
            # On Windows relpath fails when img is on a different drive than the CWD.
            shown_path = img

        _, ax = plt.subplots(1, 1, num='Flower Image Classifier')
        ax.imshow(image)
        ax.set_title(
            str(predicted_label),
            fontsize=12,
            weight='bold'
        )
        ax.text(
            0.5, -0.08, shown_path,
            horizontalalignment='center',
            verticalalignment='center_baseline',
            transform=ax.transAxes,
            fontsize=8,
        )
        ax.axis('off')
        plt.show()
    elif verbose_level == 0:
        print(predicted_label)
    else:
        print(
            f'Image {os.path.basename(img)} is classified as "{predicted_label}" (model: "{os.path.basename(clf)}")'
        )