Spaces:
Runtime error
Runtime error
### 1. Imports and class names setup ### | |
import gradio as gr | |
import os | |
import torch | |
from model import create_ResNetb34_model | |
from timeit import default_timer as timer | |
from typing import Tuple, Dict | |
# Class labels, in the same order as the model's output columns
# (predict() maps column i of the softmax output to class_names[i]).
class_names = ["Autistic", "Non_Autistic"]

### 2. Model and transforms preparation ###
# Build the ResNet34-based model and its matching preprocessing transforms.
# num_classes presumably sizes the classifier head -- see model.create_ResNetb34_model.
resnet34, resnet34_transforms = create_ResNetb34_model(num_classes=len(class_names))

# Load the saved weights, mapping every tensor to the CPU so the checkpoint also
# loads on machines without a GPU. The file is a state_dict, so weights_only=True
# is sufficient; it restricts unpickling to tensors and plain containers instead of
# arbitrary Python objects (requires torch >= 1.13).
resnet34.load_state_dict(
    torch.load(
        f="eyetracking_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# This app only runs inference, so put the model in evaluation mode once at startup.
resnet34.eval()
### 3. Predict function ###
def predict(img) -> Tuple[Dict[str, float], float]:
    """
    Classify a single image and time the prediction.

    :param img: the uploaded image as a PIL image (the Gradio input is created with
        type="pil"), or None when the form is submitted without an image.
    :return: a tuple of (dict mapping each class name to its softmax probability,
        elapsed prediction time in seconds rounded to 5 decimal places).
    :raises gr.Error: if no image was provided, so the UI shows a readable message
        instead of an AttributeError from calling .convert() on None.
    """
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    start_time = timer()

    # Force 3 channels: uploads may be RGBA, grayscale or palette images.
    img = img.convert("RGB")
    # Apply the model's preprocessing and add a leading batch dimension of size 1.
    batch = resnet34_transforms(img).unsqueeze(0)

    # Evaluation mode for the forward pass; inference_mode disables autograd tracking.
    resnet34.eval()
    with torch.inference_mode():
        # Turn the logits into class probabilities along the class dimension.
        pred_probs = torch.softmax(resnet34(batch), dim=1)

    # {class name: probability} is the format Gradio's Label output expects.
    probs = pred_probs[0].tolist()
    pred_labels_and_probs = {name: float(p) for name, p in zip(class_names, probs)}

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###
def collect_examples(directory: str = "examples") -> list:
    """
    Build Gradio's example list from the files in *directory*.

    :param directory: folder holding the example images.
    :return: one single-element list ``[path]`` per regular, non-hidden file, sorted
        by file name so the example order is stable between restarts; ``[]`` when
        the folder does not exist (the app still works, just without examples).
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        # Skip dotfiles (e.g. .gitkeep) and subdirectories -- they are not images.
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


example_list = collect_examples("examples")
import gradio as gr | |
# Title, description and article text shown on the Gradio page.
title = "Eye Tracking diagnosis"
description = """A feature extractor computer vision model to Identification of Autism in Children Using visualization of eyetracking records and Deep Neural Networks."""
article = """Eye-tracking is the process of capturing, tracking and measuring eye movements or the absolute point of gaze (POG), which refers to the point where the eye gaze is focused in the visual scene.
**Visualization of Eye-tracking Scanpaths** represents the sequence of consecutive fixations and saccades as a trace through time and space that may overlap itself.
We used pre-trained CNN model ResNet34 as feature extractors and a DNN model as a binary classifier to identify autism in children accurately.
We used a publicly available dataset to train the suggested models, which consisted of Visualization of Eye-tracking Scanpaths of children diagnosed with autism and controls classed as autistic and non-autistic. The Resnet34
model outperformed the others, with an accuracy of 95.41%."""


def _build_image_input():
    """
    Create an upload-only image input that passes a PIL image to predict().

    Gradio 4 replaced the ``source="upload"`` keyword with ``sources=["upload"]``
    and raises TypeError for the old keyword, so the new form is tried first and
    the old one is kept as a fallback for Gradio 3.x.
    """
    try:
        return gr.Image(type="pil", sources=["upload"])  # Gradio >= 4
    except TypeError:
        return gr.Image(type="pil", source="upload")  # Gradio 3.x


# Create the Gradio demo: one image in, (class probabilities, prediction time) out.
demo = gr.Interface(
    fn=predict,
    inputs=_build_image_input(),
    # predict() returns two values, so there are two output components.
    outputs=[
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # Pass None rather than an empty list so no empty examples table is rendered.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch()