File size: 3,928 Bytes
35c6f04 6e13731 35c6f04 095c1a9 24fa736 095c1a9 35c6f04 ea7e208 35c6f04 790e8db 35c6f04 ea7e208 35c6f04 fa14ef7 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 e2e7758 35c6f04 ed3066d 35c6f04 c7ddf50 35c6f04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from model import TinyCNN
from timeit import default_timer as timer
from typing import Tuple, Dict
import torch
import torchvision
from torchvision import transforms
from torch import nn
# Setup class names
# Read one class name per line from class_names.txt. Blank or whitespace-only
# lines are skipped so they cannot become an empty extra class: the model's
# output_shape is len(class_names), so an extra entry would no longer match
# the saved weights.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f if line.strip()]
### 2. Model and transforms preparation ###
# Build the network: 3 input channels (the transform below always produces
# 3-channel tensors), 64 hidden units, and one output per class name.
TinyCNN_model = TinyCNN(
    input_shape=3,
    hidden_units=64,
    output_shape=len(class_names),
)

# Training-time objects.
# NOTE(review): nothing else in this script references loss_fn or optimizer;
# they look safe to remove for an inference-only app — confirm before deleting.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=TinyCNN_model.parameters(), lr=0.001)

# Preprocessing for every input image: resize to 128x128, convert to
# grayscale replicated across 3 channels, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

# Restore the trained weights, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file — fine for the checkpoint that
# ships with the app, but it must never be pointed at untrusted files.
TinyCNN_model.load_state_dict(
    torch.load(f="TinyCNN_3.pth", map_location=torch.device("cpu"))
)
### 3. Predict function ###
# Emotion labels in the positional order the emoji lookup uses (the original
# if/elif chain mapped these names to positions 0..6 of the emoji directory
# listing). This order is alphabetical, matching a sorted directory listing.
_EMOJI_ORDER = ("angry", "disgust", "fear", "happy", "neutral", "sad", "surprise")


def _emoji_path(emotion):
    """Return the path of the emoji image for *emotion*, or None if unavailable.

    The directory listing is sorted because os.listdir() returns entries in
    arbitrary order, which made the index-based mapping nondeterministic.
    NOTE(review): assumes the sorted filenames in "emojis/" line up with
    _EMOJI_ORDER (e.g. angry.png, disgust.png, ...) — TODO confirm.
    """
    if emotion not in _EMOJI_ORDER:
        return None
    emoji_files = sorted(os.listdir("emojis"))
    index = _EMOJI_ORDER.index(emotion)
    if index >= len(emoji_files):
        # Fewer emoji files than expected: show no emoji rather than crash.
        return None
    return "emojis/" + emoji_files[index]


# Create predict function
def predict(img):
    """Classify the facial expression in *img* and pick a matching emoji.

    Args:
        img: the uploaded image as delivered by the Gradio input component
            (a PIL image, since the input is created with type='pil'), or
            None when the user submits without uploading anything.

    Returns:
        A tuple ``(pred_labels_and_probs, emoji_path)``:
        ``pred_labels_and_probs`` maps every class name to its softmax
        probability (the format gr.Label expects), and ``emoji_path`` is the
        emoji image path for the top class, or None when no emoji applies.
        Both are None when *img* is None.
    """
    if img is None:
        # Nothing to classify; clear both output components.
        return None, None

    # Transform the target image and add a batch dimension
    batch = transform(img).unsqueeze(dim=0)

    # Put model into evaluation mode and turn on inference mode
    TinyCNN_model.eval()
    with torch.inference_mode():
        # Turn the prediction logits into prediction probabilities
        pred_probs = torch.softmax(TinyCNN_model(batch), dim=1)

    # One probability per class, keyed by class name (gr.Label's input format)
    pred_labels_and_probs = {
        name: float(pred_probs[0][i]) for i, name in enumerate(class_names)
    }

    top_class = class_names[torch.argmax(pred_probs).item()]
    return pred_labels_and_probs, _emoji_path(top_class)
### 4. Gradio app ###
# Create title, description and article strings
title = "Expression Detection"
description = "An app to predict emotions from the list.[Angry, Disgust, Fear, Happy, Neutral, Sad, Surprise]. The model can predict on single face only. So upload an image which has only one face"
article = "Created as a college project."

# Create examples list from "examples/" directory. os.listdir() returns
# entries in arbitrary order, so sort them to keep the examples gallery
# order stable across runs and machines.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

# Create Gradio interface: one uploaded PIL image in; the class-probability
# dictionary and the emoji image out, in the order predict() returns them.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(sources=["upload"], type='pil'),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Image(label="Emotion"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the app!
demo.launch()
|