File size: 4,451 Bytes
23770d2
 
 
 
 
 
8c532f5
23770d2
8c532f5
 
 
 
23770d2
 
8c532f5
 
 
 
 
 
 
 
 
 
 
 
23770d2
 
 
8c532f5
 
 
 
23770d2
 
 
 
 
 
 
 
 
 
 
 
8c532f5
 
 
 
 
23770d2
8c532f5
 
 
23770d2
 
 
8c532f5
23770d2
 
69b1c36
 
 
 
 
23770d2
4699c28
 
 
 
 
 
 
 
 
8c532f5
 
4699c28
8c532f5
 
 
 
23770d2
 
4699c28
 
23770d2
 
4699c28
23770d2
 
4699c28
23770d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c131cd
23770d2
 
 
 
 
8c532f5
4699c28
 
23770d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from pathlib import Path
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np
import altair as alt
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Load the Ocsai-D web model (image-classification head used as a regressor
# over drawing originality). Downloaded from the HuggingFace hub on first run.
modelname = "POrg/ocsai-d-web"
model = AutoModelForImageClassification.from_pretrained(modelname)
model.eval()  # Set the model to evaluation mode
# Matching processor: converts PIL images into the tensor format the model expects.
image_processor = AutoImageProcessor.from_pretrained(modelname)

# Blank prompt backgrounds (MTCI figural-test items), keyed by prompt id.
# "Images_11" is opened explicitly later as the editor's default background.
prompt_images = {
    f"Images_{item}": f"./blanks/Images_{item}_blank.png"
    for item in (11, 4, 17, 9, 3, 8, 13, 15, 12, 7, 56, 19)
}

# Norm table mapping score_norm values to percentiles.
# NOTE(review): get_percentile() reads column 0 positionally (`.iloc[-1, 0]`),
# so this assumes 'percentile' is the first CSV column — verify the file layout.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
# Baseline distribution curve; per-submission markers are layered on top of it.
base_chart = alt.Chart(dist).mark_line().encode(
    x='percentile',
    y='score_norm'
)


def get_percentile(score):
    """Return the percentile for a normalized originality score.

    Finds the last row of the norm table whose ``score_norm`` does not
    exceed ``score`` and returns its percentile (column 0 — presumably
    'percentile'; the table is read positionally here).

    Previously this raised IndexError when ``score`` fell below every
    tabulated value; now it falls back to the table's first percentile.
    """
    matches = dist[dist['score_norm'] <= score]
    if matches.empty:
        # score is below the smallest tabulated score_norm — report the floor
        return dist.iloc[0, 0]
    return matches.iloc[-1, 0]


def inverse_scale(logits):
    """Map a 0-1 normalized score back onto the original JRT scale.

    Inverts the min-max scaling applied at training time: 0 maps to the
    JRT minimum (-3.024) and 1 maps to the maximum (3.164).
    """
    jrt_min, jrt_span = -3.024, 6.188
    return logits * jrt_span + jrt_min


def get_predictions(img):
    """Score a drawing for originality with the Ocsai-D model.

    Parameters
    ----------
    img : PIL.Image.Image
        RGB image of the completed drawing.

    Returns
    -------
    dict
        'originality': score clamped to [0, 1], rounded to 2 decimals
        'jrt': the same score mapped back to the JRT scale, rounded
        'percentile': percentile of the score in the norm distribution
    """
    inputs = image_processor(img, return_tensors="pt")
    prediction = model(**inputs)
    score = prediction.logits[0].detach().numpy()[0]
    score = min(max(score, 0), 1)  # clamp to the trained 0-1 range
    return {
        'originality': np.round(score, 2),
        # BUG FIX: was inverse_scale(0), which reported the scale minimum
        # (-3.02) as the JRT value regardless of the actual score.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score)
    }


def classify_image(img_dict: dict):
    """Score the drawn image and plot it against the norm distribution.

    Gradio's ImageEditor passes a dict with 'background', 'layers' and
    'composite'; only the composite (the merged drawing) is scored.
    Returns a layered Altair chart, or None when nothing has been drawn.
    """
    composite = img_dict['composite']
    if composite is None:
        return None

    preds = get_predictions(composite.convert('RGB'))

    marker_data = pd.DataFrame({
        'y': [preds['originality']],
        'x': [preds['percentile']],
        'text': [f"Percentile: {int(preds['percentile'])}"],
    })

    # Red triangle marking where this drawing lands on the distribution curve.
    marker = alt.Chart(marker_data).mark_point(
        shape='triangle', size=200, filled=True, color='red'
    ).encode(x='x', y='y')

    # Percentile label offset slightly from the marker.
    caption = alt.Chart(marker_data).mark_text(
        align='left', baseline='middle', dx=10, dy=-10, fontSize=14
    ).encode(y='y', x='x', text='text')

    return base_chart + marker + caption


def update_editor(background, img_editor):
    """Reset the editor state to the selected blank background.

    Mutates and returns the same editor-state dict: the new background is
    swapped in, drawn layers are discarded, and the composite is cleared.
    """
    img_editor.update(background=background, layers=[], composite=None)
    return img_editor


# Drawing canvas, pre-loaded with the "Images_11" blank as its background.
# Brush is limited to three fixed dark greys (pencil-like marks); image
# transforms and the layers panel are disabled to keep the task simple.
editor = gr.ImageEditor(type='pil',
                        value=dict(
                            background=Image.open(prompt_images['Images_11']),
                            composite=None,
                            layers=[]
                        ),
                        brush=gr.Brush(
                            default_size=2,
                            colors=["#000000", '#333333', '#666666'],
                            color_mode="fixed"
                        ),
                        transforms=[],
                        sources=('upload', 'clipboard'),
                        layers=False
                        )

# One clickable example per prompt: a blank background with nothing drawn yet.
examples = [
    [dict(background=Image.open(path), composite=None, layers=[])]
    for path in prompt_images.values()
]

# Wire the editor and scoring function into a Gradio interface.
# cache_examples=False: example outputs depend on the model, so they are
# computed on demand rather than precomputed at startup.
demo = gr.Interface(fn=classify_image,
                     inputs=[editor],
                     outputs=gr.Plot(),
                     title="Ocsai-D",
                     description="Complete the drawing and classify the originality. Choose the brush icon below the image to start editing.\n\nModel from *A Comparison of Supervised and Unsupervised Learning Methods in Automated Scoring of Figural Tests of Creativity* ([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\nExamples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)).",
                     examples=examples,
                     cache_examples=False
                     )

# Blocking call: serves the app until interrupted.
demo.launch(debug=True)