Spaces:
Running
Running
Peter Organisciak
commited on
Commit
·
8c532f5
1
Parent(s):
370875f
Fix logits parsing
Browse files- gradio-ocsai-d.py +45 -37
gradio-ocsai-d.py
CHANGED
@@ -1,28 +1,36 @@
|
|
1 |
from pathlib import Path
|
2 |
-
from transformers import pipeline
|
3 |
from PIL import Image
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
6 |
import numpy as np
|
7 |
import altair as alt
|
|
|
8 |
|
|
|
|
|
|
|
|
|
9 |
|
10 |
prompt_images = {
|
11 |
-
"Images_11": "blanks/Images_11_blank.png",
|
12 |
-
"Images_4": "blanks/Images_4_blank.png",
|
13 |
-
"Images_17": "blanks/Images_17_blank.png",
|
14 |
-
"Images_9": "blanks/Images_9_blank.png",
|
15 |
-
"Images_3": "blanks/Images_3_blank.png",
|
16 |
-
"Images_8": "blanks/Images_8_blank.png",
|
17 |
-
"Images_13": "blanks/Images_13_blank.png",
|
18 |
-
"Images_15": "blanks/Images_15_blank.png",
|
19 |
-
"Images_12": "blanks/Images_12_blank.png",
|
20 |
-
"Images_7": "blanks/Images_7_blank.png",
|
21 |
-
"Images_56": "blanks/Images_56_blank.png",
|
22 |
-
"Images_19": "blanks/Images_19_blank.png",
|
23 |
}
|
24 |
|
25 |
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
def get_percentile(score):
|
@@ -35,41 +43,43 @@ def inverse_scale(logits):
|
|
35 |
return logits * (scaler_params['range']) + scaler_params['min']
|
36 |
|
37 |
|
38 |
-
def get_predictions(
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
return {
|
44 |
-
'originality'
|
45 |
-
'jrt'
|
46 |
-
'percentile': get_percentile(
|
47 |
}
|
48 |
|
49 |
|
50 |
-
|
51 |
-
x='percentile',
|
52 |
-
y='score_norm'
|
53 |
-
)
|
54 |
-
|
55 |
-
def classify_image(img_dict):
|
56 |
# gradio passes a dictionary with background, composite, and layers
|
57 |
# the composite is what we want
|
58 |
-
img = img_dict['composite']
|
59 |
-
p = get_predictions(
|
60 |
|
61 |
-
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
65 |
align='left',
|
66 |
baseline='middle',
|
67 |
-
dx=
|
68 |
).encode(
|
69 |
y='y',
|
70 |
text='text'
|
71 |
)
|
72 |
-
return base_chart +
|
73 |
|
74 |
|
75 |
def update_editor(background, img_editor):
|
@@ -80,8 +90,6 @@ def update_editor(background, img_editor):
|
|
80 |
return img_editor
|
81 |
|
82 |
|
83 |
-
classifier = pipeline("image-classification", model='POrg/ocsai-d-large')
|
84 |
-
|
85 |
editor = gr.ImageEditor(type='pil',
|
86 |
value=dict(
|
87 |
background=Image.open(prompt_images['Images_11']),
|
@@ -106,7 +114,7 @@ demo = gr.Interface(fn=classify_image,
|
|
106 |
inputs=[editor],
|
107 |
outputs=gr.Plot(),
|
108 |
title="Ocsai-D",
|
109 |
-
description="Complete the drawing and classify the originality.
|
110 |
examples=examples
|
111 |
)
|
112 |
|
|
|
1 |
from pathlib import Path
|
|
|
2 |
from PIL import Image
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
5 |
import numpy as np
|
6 |
import altair as alt
|
7 |
+
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
8 |
|
9 |
+
modelname = "POrg/ocsai-d-web"
|
10 |
+
model = AutoModelForImageClassification.from_pretrained(modelname)
|
11 |
+
model.eval() # Set the model to evaluation mode
|
12 |
+
image_processor = AutoImageProcessor.from_pretrained(modelname)
|
13 |
|
14 |
prompt_images = {
|
15 |
+
"Images_11": "./blanks/Images_11_blank.png",
|
16 |
+
"Images_4": "./blanks/Images_4_blank.png",
|
17 |
+
"Images_17": "./blanks/Images_17_blank.png",
|
18 |
+
"Images_9": "./blanks/Images_9_blank.png",
|
19 |
+
"Images_3": "./blanks/Images_3_blank.png",
|
20 |
+
"Images_8": "./blanks/Images_8_blank.png",
|
21 |
+
"Images_13": "./blanks/Images_13_blank.png",
|
22 |
+
"Images_15": "./blanks/Images_15_blank.png",
|
23 |
+
"Images_12": "./blanks/Images_12_blank.png",
|
24 |
+
"Images_7": "./blanks/Images_7_blank.png",
|
25 |
+
"Images_56": "./blanks/Images_56_blank.png",
|
26 |
+
"Images_19": "./blanks/Images_19_blank.png",
|
27 |
}
|
28 |
|
29 |
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
|
30 |
+
base_chart = alt.Chart(dist).mark_line().encode(
|
31 |
+
x='percentile',
|
32 |
+
y='score_norm'
|
33 |
+
)
|
34 |
|
35 |
|
36 |
def get_percentile(score):
|
|
|
43 |
return logits * (scaler_params['range']) + scaler_params['min']
|
44 |
|
45 |
|
46 |
+
def get_predictions(img):
|
47 |
+
inputs = image_processor(img, return_tensors="pt")
|
48 |
+
prediction = model(**inputs)
|
49 |
+
score = prediction.logits[0].detach().numpy()[0]
|
50 |
+
score = min(max(score, 0), 1)
|
51 |
return {
|
52 |
+
'originality': np.round(score, 2),
|
53 |
+
'jrt': np.round(inverse_scale(0), 2),
|
54 |
+
'percentile': get_percentile(score)
|
55 |
}
|
56 |
|
57 |
|
58 |
+
def classify_image(img_dict: dict):
|
|
|
|
|
|
|
|
|
|
|
59 |
# gradio passes a dictionary with background, composite, and layers
|
60 |
# the composite is what we want
|
61 |
+
img = img_dict['composite'].convert('RGB')
|
62 |
+
p = get_predictions(img)
|
63 |
|
64 |
+
score_round = str(np.round(p['originality'], 2))
|
65 |
+
label = f"Percentile: {p['percentile']}; Normalized Score: {score_round}"
|
66 |
+
label_df = pd.DataFrame({'y': [p['originality']], 'text': [label]})
|
67 |
|
68 |
+
rule = alt.Chart(label_df).mark_rule(
|
69 |
+
color='red'
|
70 |
+
).encode(
|
71 |
+
y='y'
|
72 |
+
)
|
73 |
+
|
74 |
+
txt = alt.Chart(label_df).mark_text(
|
75 |
align='left',
|
76 |
baseline='middle',
|
77 |
+
dx=0, dy=-8
|
78 |
).encode(
|
79 |
y='y',
|
80 |
text='text'
|
81 |
)
|
82 |
+
return base_chart + rule + txt
|
83 |
|
84 |
|
85 |
def update_editor(background, img_editor):
|
|
|
90 |
return img_editor
|
91 |
|
92 |
|
|
|
|
|
93 |
editor = gr.ImageEditor(type='pil',
|
94 |
value=dict(
|
95 |
background=Image.open(prompt_images['Images_11']),
|
|
|
114 |
inputs=[editor],
|
115 |
outputs=gr.Plot(),
|
116 |
title="Ocsai-D",
|
117 |
+
description="Complete the drawing and classify the originality. Choose the brush icon below the image to start editing.\n\nModel from *A Comparison of Supervised and Unsupervised Learning Methods in Automated Scoring of Figural Tests of Creativity* ([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\nExamples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)).",
|
118 |
examples=examples
|
119 |
)
|
120 |
|