Elbhnasy committed
Commit 201ff2c · 1 Parent(s): 0690cd9

Upload 12 files

app.py ADDED
@@ -0,0 +1,78 @@
+ ### 1. Imports and class names setup ###
+ import gradio as gr
+ import os
+ import torch
+
+ from model import create_ResNetb34_model
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+
+ # Setup class names
+ class_names = ["Autistic", "Non_Autistic"]
+
+ ### 2. Model and transforms preparation ###
+
+ # Create the model and its associated transforms
+ resnet34, resnet34_transforms = create_ResNetb34_model(num_classes=len(class_names))
+
+ # Load saved weights
+ resnet34.load_state_dict(torch.load(f="eyetracking_model.pth",
+                                     map_location=torch.device("cpu")))
+
+ ### 3. Predict function ###
+ def predict(img) -> Tuple[Dict, float]:
+     """
+     Transforms and performs a prediction on img.
+     :param img: target image.
+     :return: prediction probabilities and time taken.
+     """
+     # Start the timer
+     start_time = timer()
+
+     # Transform the target image and add a batch dimension
+     img = img.convert("RGB")
+     img = resnet34_transforms(img).unsqueeze(0)
+
+     # Put the model into evaluation mode and turn on inference mode
+     resnet34.eval()
+     with torch.inference_mode():
+         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
+         pred_probs = torch.softmax(resnet34(img), dim=1)
+
+     # Create a prediction label and prediction probability dictionary for each prediction class
+     # (this is the required format for Gradio's output parameter)
+     pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Return the prediction dictionary and prediction time
+     return pred_labels_and_probs, pred_time
+
+ ### 4. Gradio app ###
+ # Build the example list from the images in the examples/ directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create title, description and article strings
+ title = "Eye Tracking Diagnosis"
+ description = """A feature-extractor computer vision model for identifying autism in children from visualizations of eye-tracking records, using deep neural networks."""
+
+ article = """Eye tracking is the process of capturing, tracking and measuring eye movements or the absolute point of gaze (POG), the point in the visual scene where the eye gaze is focused.
+ **Visualization of eye-tracking scanpaths** represents the sequence of consecutive fixations and saccades as a trace through time and space that may overlap itself.
+ We used the pre-trained CNN model ResNet34 as a feature extractor and a DNN model as a binary classifier to identify autism in children accurately.
+ We trained the suggested models on a publicly available dataset consisting of visualizations of eye-tracking scanpaths of children diagnosed with autism and controls, labelled as autistic and non-autistic. The ResNet34
+ model outperformed the others, with an accuracy of 95.41%."""
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,  # mapping function from input to output
+                     inputs=gr.Image(type="pil", source="upload"),  # what are the inputs?
+                     outputs=[gr.Label(num_top_classes=2, label="Predictions"),  # what are the outputs?
+                              gr.Number(label="Prediction time (s)")],  # our fn has two outputs, so we return two outputs
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch()
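
For a quick sanity check of the inference path without launching the Gradio UI, the same steps app.py performs can be run directly in a Python session. This is only a sketch under a few assumptions: the files above are available locally (model.py, eyetracking_model.pth and the examples/ directory), the pinned dependencies from requirements.txt are installed, and the example filename is simply one of the images listed below.

    # Hypothetical local smoke test -- not part of the uploaded files
    import torch
    from PIL import Image
    from model import create_ResNetb34_model

    class_names = ["Autistic", "Non_Autistic"]

    # Rebuild the model and load the uploaded checkpoint on CPU
    resnet34, resnet34_transforms = create_ResNetb34_model(num_classes=len(class_names))
    resnet34.load_state_dict(torch.load("eyetracking_model.pth", map_location=torch.device("cpu")))
    resnet34.eval()

    # Run one of the bundled example scanpath images through the same pipeline as predict()
    image = Image.open("examples/TC001_39.png").convert("RGB")
    batch = resnet34_transforms(image).unsqueeze(0)
    with torch.inference_mode():
        probs = torch.softmax(resnet34(batch), dim=1)

    print({name: float(probs[0][i]) for i, name in enumerate(class_names)})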
examples/TC001_39.png ADDED
examples/TC067_43.png ADDED
examples/TC079_49.png ADDED
examples/TC120_30.png ADDED
examples/TS031_07.png ADDED
examples/TS046_15.png ADDED
examples/TS072_14.png ADDED
examples/TS126_01.png ADDED
eyetracking_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f3a453fc5073a6ee3941de6ed466bac7f371e0b7aca2edde37e82efd3ea64a9
+ size 87338101
model.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ import torchvision
+ from torchvision import transforms
+ from torch import nn
+
+
+ def create_ResNetb34_model(num_classes: int = 3, seed: int = 42):
+     """
+     Creates a ResNet34 feature extractor model and transforms.
+     :param num_classes: number of classes in the classifier head.
+         Defaults to 3.
+     :param seed: random seed value.
+         Defaults to 42.
+     :return: feature extractor model and
+         transforms (torchvision.transforms): ResNet34 image transforms.
+     """
+     # 1. Setup pretrained ResNet34 weights
+     weights = torchvision.models.ResNet34_Weights.DEFAULT
+
+     # 2. Get the ResNet34 transforms
+     transform = transforms.Compose([
+         weights.transforms(),
+         # transforms.RandomHorizontalFlip(),
+     ])
+
+     # 3. Setup the pretrained model
+     model = torchvision.models.resnet34(weights=weights)
+
+     # 4. Keep the base layers trainable (set requires_grad to False to freeze them instead)
+     for param in model.parameters():
+         param.requires_grad = True
+
+     # 5. Replace the classifier head (ResNet's final layer is `fc`, with 512 input features),
+     #    seeding for reproducibility
+     torch.manual_seed(seed)
+     model.fc = nn.Sequential(nn.Dropout(p=0.2, inplace=True),
+                              nn.Linear(in_features=512, out_features=num_classes))
+     return model, transform
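
As a minimal sketch of how the helper above is expected to behave (assuming the torchvision version pinned below, which provides ResNet34_Weights), the returned model and transforms can be checked on a dummy input; the expected output is one logit per class.

    # Hypothetical shape check -- not part of the uploaded files
    import torch
    from model import create_ResNetb34_model

    model, transform = create_ResNetb34_model(num_classes=2, seed=42)
    model.eval()

    # The ResNet34 preset transforms accept a PIL image or a tensor; a random
    # 3x224x224 tensor stands in for a scanpath visualization here.
    dummy = torch.rand(3, 224, 224)
    batch = transform(dummy).unsqueeze(0)

    with torch.inference_mode():
        logits = model(batch)
    print(logits.shape)  # expected: torch.Size([1, 2])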
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.13.1
+ torchvision==0.14.1
+ gradio==3.16.2