Spaces:
Runtime error
Runtime error
Upload 12 files
Browse files- app.py +78 -0
- examples/TC001_39.png +0 -0
- examples/TC067_43.png +0 -0
- examples/TC079_49.png +0 -0
- examples/TC120_30.png +0 -0
- examples/TS031_07.png +0 -0
- examples/TS046_15.png +0 -0
- examples/TS072_14.png +0 -0
- examples/TS126_01.png +0 -0
- eyetracking_model.pth +3 -0
- model.py +37 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 1. Imports and class names setup ###
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from model import create_ResNetb34_model
|
7 |
+
from timeit import default_timer as timer
|
8 |
+
from typing import Tuple, Dict
|
9 |
+
|
10 |
+
# Setup class names
|
11 |
+
class_names = ["Autistic", "Non_Autistic"]
|
12 |
+
|
13 |
+
### 2. Model and transforms preparation ###
|
14 |
+
|
15 |
+
|
16 |
+
resnet34, resnet34_transforms = create_ResNetb34_model(num_classes=len(class_names) )
|
17 |
+
|
18 |
+
# Load saved weights
|
19 |
+
resnet34.load_state_dict(torch.load(f="eyetracking_model.pth",
|
20 |
+
map_location=torch.device("cpu"),))
|
21 |
+
### 3. Predict function ###
|
22 |
+
# Create predict function
|
23 |
+
def predict(img)-> Tuple[Dict, float]:
|
24 |
+
"""
|
25 |
+
Transforms and performs a prediction on img.
|
26 |
+
:param img: target image .
|
27 |
+
:return: prediction and time taken.
|
28 |
+
"""
|
29 |
+
# Start the timer
|
30 |
+
start_time=timer()
|
31 |
+
# Transform the target image and add a batch dimension
|
32 |
+
img=img.convert('RGB')
|
33 |
+
img = resnet34_transforms(img).unsqueeze(0)
|
34 |
+
# put model into evaluation mode and turn infarance mode
|
35 |
+
resnet34.eval()
|
36 |
+
with torch.inference_mode():
|
37 |
+
|
38 |
+
# Pass the transformed image through the model and turn the prediction logits into prediction probabilities
|
39 |
+
pred_probs=torch.softmax(resnet34(img),dim=1)
|
40 |
+
# Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
|
41 |
+
|
42 |
+
pred_labels_and_probs={class_names[i]:float(pred_probs[0][i]) for i in range(len(class_names))}
|
43 |
+
# Calculate the prediction time
|
44 |
+
pred_time = round(timer() - start_time, 5)
|
45 |
+
|
46 |
+
# Return the prediction dictionary and prediction time
|
47 |
+
return pred_labels_and_probs, pred_time
|
48 |
+
|
49 |
+
### 4. Gradio app ###
|
50 |
+
example_list = [["examples/" + example] for example in os.listdir("examples")]
|
51 |
+
|
52 |
+
# Create title, description and article strings
|
53 |
+
import gradio as gr
|
54 |
+
|
55 |
+
# Create title, description and article strings
|
56 |
+
title = "Eye Tracking diagnosis"
|
57 |
+
description = """A feature extractor computer vision model to Identification of Autism in Children Using visualization of eyetracking records and Deep Neural Networks."""
|
58 |
+
|
59 |
+
|
60 |
+
article = """Eye-tracking is the process of capturing, tracking and measuring eye movements or the absolute point of gaze (POG), which refers to the point where the eye gaze is focused in the visual scene.
|
61 |
+
**Visualization of Eye-tracking Scanpaths** represents the sequence of consecutive fixations and saccades as a trace through time and space that may overlap itself.
|
62 |
+
We used pre-trained CNN model ResNet34 as feature extractors and a DNN model as a binary classifier to identify autism in children accurately.
|
63 |
+
We used a publicly available dataset to train the suggested models, which consisted of Visualization of Eye-tracking Scanpaths of children diagnosed with autism and controls classed as autistic and non-autistic. The Resnet34
|
64 |
+
model outperformed the others, with an accuracy of 95.41%."""
|
65 |
+
|
66 |
+
# Create the Gradio demo
|
67 |
+
demo = gr.Interface(fn=predict, # mapping function from input to output
|
68 |
+
inputs=gr.Image(type="pil",source="upload"), # what are the inputs?
|
69 |
+
outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
|
70 |
+
gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
|
71 |
+
examples=example_list,
|
72 |
+
title=title,
|
73 |
+
description=description,
|
74 |
+
article=article)
|
75 |
+
|
76 |
+
|
77 |
+
# Launch the demo!
|
78 |
+
demo.launch()
|
examples/TC001_39.png
ADDED
![]() |
examples/TC067_43.png
ADDED
![]() |
examples/TC079_49.png
ADDED
![]() |
examples/TC120_30.png
ADDED
![]() |
examples/TS031_07.png
ADDED
![]() |
examples/TS046_15.png
ADDED
![]() |
examples/TS072_14.png
ADDED
![]() |
examples/TS126_01.png
ADDED
![]() |
eyetracking_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f3a453fc5073a6ee3941de6ed466bac7f371e0b7aca2edde37e82efd3ea64a9
|
3 |
+
size 87338101
|
model.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from torchvision import transforms
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
def create_ResNetb34_model(num_classes:int=3,seed:int=42):
|
8 |
+
"""
|
9 |
+
Creates an ResNetb34 feature extractor model and transforms.
|
10 |
+
:param num_classes: number of classes in classifier head.
|
11 |
+
Defaults to 3.
|
12 |
+
:param seed: random seed value.
|
13 |
+
Defaults to 42.
|
14 |
+
:return: feature extractor model.
|
15 |
+
transforms (torchvision.transforms): ResNetb34 image transforms.
|
16 |
+
"""
|
17 |
+
# 1. Setup pretrained EffNetB1 weights
|
18 |
+
weigts = torchvision.models.ResNet34_Weights.DEFAULT
|
19 |
+
# 2. Get EffNetB2 transforms
|
20 |
+
transform = transforms.Compose([
|
21 |
+
weigts.transforms(),
|
22 |
+
|
23 |
+
#transforms.RandomHorizontalFlip(),
|
24 |
+
])
|
25 |
+
# 3. Setup pretrained model
|
26 |
+
model=torchvision.models.resnet34(weights= "DEFAULT")
|
27 |
+
|
28 |
+
# 4. Freeze the base layers in the model (this will freeze all layers to begin with)
|
29 |
+
for param in model.parameters():
|
30 |
+
param.requires_grad=True
|
31 |
+
|
32 |
+
# 5. Change classifier head with random seed for reproducibility
|
33 |
+
torch.manual_seed(seed)
|
34 |
+
model.classifier=nn.Sequential(nn.Dropout(p=0.2,inplace=True),
|
35 |
+
nn.Linear(in_features=612,out_features=num_classes))
|
36 |
+
return model,transform
|
37 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchvision==0.14.1
|
3 |
+
gradio==3.16.2
|