File size: 6,884 Bytes
0bce035
 
 
 
4d817ae
 
 
 
 
 
 
 
 
 
0bce035
 
7e7870c
0bce035
4d817ae
da2b35e
 
5dc340d
4d817ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dc340d
 
2b228e6
5dc340d
 
80b6d82
0bce035
 
4d817ae
0bce035
 
 
 
73bc643
e3c1c26
4d817ae
 
 
 
 
 
 
 
f14216f
 
4d817ae
 
80b6d82
4d817ae
 
80b6d82
4d817ae
80b6d82
4d817ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d8f8c8
b33e657
5dc340d
 
 
 
80b6d82
4d817ae
 
 
 
 
 
 
 
80b6d82
4d817ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c342e41
4b15bf8
28fd108
4d817ae
 
 
 
 
 
 
 
 
 
 
 
c342e41
4d817ae
5dc340d
9d8f8c8
b6f2207
4d817ae
 
b6f2207
5dc340d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import re
from datetime import datetime

from PIL import Image
import numpy as np
import json
from torchvision import models
import torch.nn.functional as F
from torch import nn
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from huggingface_hub import HfApi
import gradio as gr


# Hugging Face access token used when uploading inference images to the Hub;
# None when the HF_TOKEN environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# First CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): not referenced by the code shown here; the model is loaded with
# map_location="cpu" and inputs are never moved, so inference runs on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def load_checkpoint(filepath):
    """Rebuild the fine-tuned MaxViT-T classifier from a saved checkpoint.

    The checkpoint is read as a dict with the keys ``"classifier"`` (a module
    that replaces ``model.classifier[5]``), ``"state_dict"`` (weights for the
    whole model) and ``"class_to_idx"`` (label -> output index mapping).

    Arguments
    ---------
    filepath: string, filepath of the saved PyTorch checkpoint

    Returns
    -------
    The rebuilt model (on CPU) with a ``class_to_idx`` attribute attached.
    """
    # No pretrained weights are needed: load_state_dict() below runs in strict
    # mode and overwrites every parameter, so requesting IMAGENET1K_V1 would
    # only cost a weights download at start-up. The architecture is the same.
    model = models.maxvit_t(weights=None)

    # The checkpoint stores a whole nn.Module (the custom classifier), not just
    # tensors, so it must be fully unpickled. torch>=2.6 defaults to
    # weights_only=True, which refuses such files.
    # SECURITY: unpickling executes code -- only load trusted checkpoints.
    checkpoint = torch.load(
        filepath, map_location=torch.device("cpu"), weights_only=False
    )

    # Swap in the trained classifier head before loading the weights, so the
    # state_dict keys/shapes for classifier[5] match.
    model.classifier[5] = checkpoint["classifier"]

    # Load model weights (strict: every key must match).
    model.load_state_dict(checkpoint["state_dict"])

    # Keep the class -> index mapping on the model for decoding predictions.
    model.class_to_idx = checkpoint["class_to_idx"]

    return model


class Network(nn.Module):
    def __init__(self, input_size, hidden_layers, output_size=102, drop_p=0.2):
        """Builds a feedforward network with arbitrary hidden layers.

        Arguments
        ---------
        input_size: integer, size of the input layer
        output_size: integer, size of the output layer
        hidden_layers: non-empty list of integers, the sizes of the hidden layers
        drop_p: float, dropout probability (applied only in training mode)

        Raises
        ------
        ValueError: if hidden_layers is empty
        """
        super().__init__()

        if not hidden_layers:
            raise ValueError("hidden_layers must contain at least one layer size")

        self.drop_p = drop_p

        # Input to the first hidden layer ...
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])

        # ... then one Linear per consecutive pair of hidden sizes.
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])

        self.output = nn.Linear(hidden_layers[-1], output_size)

        print(
            f"\nNumber of layers: {len(self.hidden_layers)}"
            f"\nNumber of units in layers:{hidden_layers}"
        )

    def forward(self, x):
        """Forward pass through the network; returns log-probabilities over classes."""
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
            # F.dropout defaults to training=True; tie it to the module's mode
            # so model.eval() actually disables dropout at inference time.
            x = F.dropout(x, p=self.drop_p, training=self.training)

        return F.log_softmax(self.output(x), dim=1)


# Build the inference model once, at import time; predict() binds this object
# as the default value of its `model` argument.
model = load_checkpoint(filepath="flower_inference_model.pth")


def _archive_inference_image(rgb_image, exif):
    """Save *rgb_image* under ``inference/`` and upload it to the Hub dataset.

    Best effort: any failure (missing token, network error, disk error) is
    logged and swallowed so that archiving never prevents a prediction.
    """
    import logging

    try:
        os.makedirs("inference", exist_ok=True)

        # Timestamp with every run of non-alphanumeric characters replaced by
        # "_", used as a unique-enough file name.
        image_name = re.sub(r"\W+", "_", str(datetime.now())) + ".jpg"
        local_path = os.path.join("inference", image_name)

        rgb_image.save(local_path, quality=95, exif=exif)
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=image_name,
            repo_id="DanielPFlorian/flower-image-classifier",
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except Exception as err:  # archiving is optional; never break inference
        logging.getLogger(__name__).warning(
            "Could not archive inference image: %s", err
        )


def process_image(pil_image):
    """Scale, crop and normalize a PIL image for the model; returns a Numpy array.

    As a side effect the image is archived locally and uploaded to the
    Hugging Face dataset repo (best effort, failures are only logged).

    Steps: convert to RGB, resize so the shorter side is 256 px (aspect ratio
    preserved), center-crop 224x224, scale to [0, 1], normalize with the
    channel mean/std below, and move the channels first.

    Arguments
    ---------
    pil_image: PIL image to be processed; it is not modified

    Returns
    -------
    numpy.ndarray of shape (3, 224, 224), dtype float64
    """
    # Work on an RGB copy. This drops alpha/palette/grayscale modes, which
    # would break the JPEG save and the 3-channel normalization. It also
    # leaves the caller's image untouched (predict() plots it afterwards).
    image = pil_image.convert("RGB")
    _archive_inference_image(image, pil_image.getexif())

    # Resize so the shortest side is exactly 256 px, preserving aspect ratio.
    # Unlike Image.thumbnail this also enlarges small images, so the 224x224
    # crop below never runs past the border (which would pad with black).
    w, h = image.size
    scale = 256 / min(w, h)
    new_size = (max(256, round(w * scale)), max(256, round(h * scale)))
    image = image.resize(new_size, Image.BICUBIC)

    # Crop the central 224x224 region.
    w, h = image.size
    left = (w - 224) // 2
    top = (h - 224) // 2
    image = image.crop((left, top, left + 224, top + 224))

    # Convert to a numpy array and scale color channels to [0, 1].
    np_image = np.asarray(image) / 255

    # Normalize each channel.
    mean = np.array([0.485, 0.456, 0.406])  # Mean
    std = np.array([0.229, 0.224, 0.225])  # Standard deviation
    np_image = (np_image - mean) / std

    # HWC -> CHW: move color channels to the first dimension.
    return np_image.transpose((2, 0, 1))


# Category to name mapping, loaded once at import time. predict() looks names
# up with str(label), so the JSON object's keys are the category labels as
# strings. JSON is UTF-8 by specification, so decode it explicitly rather
# than with the platform's default encoding.
with open("cat_to_name.json", "r", encoding="utf-8") as f:
    cat_to_name = json.load(f)


def _top_k_predictions(image_array, model, category_names, topk):
    """Return ``(labels, percentages)`` for the *topk* most likely classes.

    image_array: (3, 224, 224) array as produced by process_image().
    Labels are mapped through *category_names* (keyed by ``str(label)``) when
    it is truthy; percentages are rounded to two decimals, highest first.
    """
    # Add a batch dimension of 1 and cast to float32 for the model.
    batch = torch.as_tensor(image_array).unsqueeze(0).float()

    # Evaluation mode for inference; gradients are not needed.
    model.eval()
    with torch.no_grad():
        # The model outputs log-probabilities, so exp() yields probabilities.
        probs = torch.exp(model(batch))

    # topk() raises if k exceeds the number of classes, so clamp it.
    k = min(topk, probs.shape[1])
    top_probs, top_idx = probs.topk(k, dim=1)

    # [0] drops the batch dimension.
    percentages = [round(p * 100.0, 2) for p in top_probs[0].tolist()]

    # Invert class_to_idx (label -> index) to decode predicted indexes.
    idx_to_class = {idx: label for label, idx in model.class_to_idx.items()}
    labels = [idx_to_class[i] for i in top_idx[0].tolist()]

    if category_names:
        labels = [category_names[str(label)] for label in labels]

    return labels, percentages


def _plot_predictions(pil_image, labels, percentages):
    """Return a Figure with *pil_image* on the left and a probability bar chart on the right."""
    # Local import: matplotlib.figure is not bound at module level.
    from matplotlib.figure import Figure

    # Build a bare Figure instead of using pyplot. pyplot keeps every figure it
    # creates alive until plt.close(), so a long-running server that makes one
    # figure per request would leak them.
    fig = Figure()
    ax1, ax2 = fig.subplots(ncols=2)

    ax1.imshow(pil_image)
    ax1.axis("off")

    positions = np.arange(len(labels))
    ax2.barh(positions, percentages)
    # Aspect = x-range / y-range makes the bar chart's axes box square.
    ax2.set_aspect(np.diff(ax2.get_xlim())[0] / np.diff(ax2.get_ylim())[0])
    ax2.set_yticks(positions)
    ax2.set_yticklabels(labels)
    ax2.invert_yaxis()  # first (most likely) class at the top
    ax2.xaxis.set_major_formatter(ticker.PercentFormatter())
    ax2.set_title("Class Probability")

    # Lay out after the title is set so it is accounted for.
    fig.tight_layout()
    return fig


def predict(pil_image, model=model, category_names=cat_to_name, topk=5):
    """Predict the most likely classes of an image and plot their probabilities.

    Arguments
    ---------
    pil_image: PIL image to be classified
    model: model to be used for prediction; must have a ``class_to_idx`` dict
    category_names: dict mapping category labels (as strings) to display
        names; pass a falsy value to show the raw labels
    topk: number of top predicted classes to return (clamped to the number
        of classes)

    Returns
    -------
    matplotlib Figure showing the image and the top-k class probabilities
    in percent.

    Raises
    ------
    gr.Error: if no image was provided.
    """
    # Gradio calls the function with None when the image input is empty.
    if pil_image is None:
        raise gr.Error("Please upload a flower image.")

    image = process_image(pil_image)
    labels, percentages = _top_k_predictions(image, model, category_names, topk)
    return _plot_predictions(pil_image, labels, percentages)


# Gradio interface: upload a flower photo, get back a plot of the top
# predicted classes. It is bound to `demo`, the name Gradio's reload mode
# (`gradio app.py`) looks for by default, and launched immediately as before.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a flower image"),
    outputs=gr.Plot(label="Plot"),
    title="What kind of flower is this?",
)
demo.launch()