### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
from model import create_ResNetb34_model
from timeit import default_timer as timer
from typing import Tuple, Dict
# Setup class names
class_names = ["Autistic", "Non_Autistic"]
### 2. Model and transforms preparation ###
resnet34, resnet34_transforms = create_ResNetb34_model(num_classes=len(class_names))
# Load saved weights
resnet34.load_state_dict(torch.load(f="eyetracking_model.pth",
                                    map_location=torch.device("cpu")))
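# NOTE: create_ResNetb34_model is defined in model.py (not shown here). A
# minimal sketch of what such a factory might look like, assuming it wraps
# torchvision's resnet34 and swaps out the final fully connected layer; this
# is an illustration, not the repo's actual implementation:
#
#   import torchvision
#
#   def create_ResNetb34_model(num_classes: int):
#       weights = torchvision.models.ResNet34_Weights.DEFAULT
#       transforms = weights.transforms()  # matching preprocessing pipeline
#       model = torchvision.models.resnet34(weights=weights)
#       model.fc = torch.nn.Linear(in_features=512, out_features=num_classes)
#       return model, transforms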
### 3. Predict function ###
# Create predict function
def predict(img) -> Tuple[Dict, float]:
    """Transforms and performs a prediction on img.

    :param img: target image.
    :return: prediction dictionary and time taken in seconds.
    """
    # Start the timer
    start_time = timer()

    # Ensure 3 channels, then transform the image and add a batch dimension
    img = img.convert("RGB")
    img = resnet34_transforms(img).unsqueeze(0)
    # Put model into evaluation mode and turn on inference mode
    resnet34.eval()
    with torch.inference_mode():
        # Pass the transformed image through the model and turn the
        # prediction logits into prediction probabilities
        pred_probs = torch.softmax(resnet34(img), dim=1)

    # Create a prediction label and prediction probability dictionary for each
    # prediction class (this is the required format for Gradio's output parameter)
    pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
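    # The resulting dict looks like {"Autistic": 0.93, "Non_Autistic": 0.07}
    # (values here are illustrative, not real model output)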
    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Return the prediction dictionary and prediction time
    return pred_labels_and_probs, pred_time
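# Optional local sanity check (illustrative; assumes examples/ holds at least
# one image -- uncomment to test predict() outside the Gradio UI):
#
#   from PIL import Image
#   sample = Image.open(os.path.join("examples", os.listdir("examples")[0]))
#   print(predict(sample))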
### 4. Gradio app ###
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Create title, description and article strings
title = "Eye Tracking diagnosis"
description = """A feature-extractor computer vision model for identifying autism in children from visualizations of eye-tracking records, using deep neural networks."""
article = """Eye-tracking is the process of capturing, tracking and measuring eye movements or the absolute point of gaze (POG), i.e. the point in the visual scene where the eye gaze is focused.
**Visualization of Eye-tracking Scanpaths** represents the sequence of consecutive fixations and saccades as a trace through time and space that may overlap itself.
We used the pre-trained CNN ResNet34 as a feature extractor and a DNN as a binary classifier to accurately identify autism in children.
We trained the suggested models on a publicly available dataset consisting of visualizations of eye-tracking scanpaths of children, labeled as autistic and non-autistic.
The ResNet34 model outperformed the others, with an accuracy of 95.41%."""
# Create the Gradio demo
demo = gr.Interface(fn=predict, # mapping function from input to output
inputs=gr.Image(type="pil",source="upload"), # what are the inputs?
outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
examples=example_list,
title=title,
description=description,
article=article)
# Launch the demo!
demo.launch()
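# demo.launch() starts a local server. Optionally, Gradio's documented
# share=True flag also creates a temporary public link, e.g.:
#
#   demo.launch(share=True)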