Ocsai_Drawings / gradio-ocsai-d.py
Peter Organisciak
turn off example caching
4699c28
from pathlib import Path
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np
import altair as alt
from transformers import AutoModelForImageClassification, AutoImageProcessor
modelname = "POrg/ocsai-d-web"
# The image processor and the classifier are loaded from the same checkpoint id,
# so preprocessing matches what the model expects.
image_processor = AutoImageProcessor.from_pretrained(modelname)
model = AutoModelForImageClassification.from_pretrained(modelname)
# Inference only: switch the model to evaluation mode once, at startup.
model.eval()
# Blank prompt sheets, keyed by prompt id ("Images_<n>"), each stored at
# ./blanks/Images_<n>_blank.png. Insertion order is preserved and determines
# the order in which the examples are listed.
prompt_images = {
    f"Images_{num}": f"./blanks/Images_{num}_blank.png"
    for num in (11, 4, 17, 9, 3, 8, 13, 15, 12, 7, 56, 19)
}
# Reference distribution of normalized scores; every column is parsed as float.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
# Background curve for the result plot: score_norm (y) against percentile (x).
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)
def get_percentile(score, distribution=None):
    """Return the percentile of ``score`` within a reference score distribution.

    Selects the last row whose 'score_norm' is <= ``score``. This assumes the
    rows are ordered by ascending 'score_norm' (TODO: confirm for
    score_norm_distribution.csv).

    Args:
        score: normalized score. get_predictions passes a value clamped to [0, 1].
        distribution: DataFrame with 'percentile' and 'score_norm' columns.
            Defaults to the module-level ``dist``.

    Returns:
        The 'percentile' value of the matching row. If ``score`` is below every
        'score_norm' entry, the first row's percentile is returned instead of
        raising IndexError on an empty selection.
    """
    if distribution is None:
        distribution = dist
    # Look up by column name (the same name the chart uses) rather than position 0.
    at_or_below = distribution.loc[distribution['score_norm'] <= score, 'percentile']
    if at_or_below.empty:
        # Score lies below the whole reference curve: clamp to the first row.
        return distribution['percentile'].iloc[0]
    return at_or_below.iloc[-1]
def inverse_scale(logits):
    """Map a value from the 0-1 min-max-scaled space back onto the JRT range.

    This undoes the scaling of JRT scores into [0, 1]. The JRT range used has
    min = -3.024 and max = 3.164, so the span is 6.188. Only ``*`` and ``+`` are
    applied, so scalars and numpy arrays are both accepted.
    """
    jrt_min = -3.024
    jrt_range = 6.188  # == 3.164 - (-3.024), kept as the literal used in training
    return logits * jrt_range + jrt_min
def get_predictions(img):
    """Score one drawing with the Ocsai-D model.

    Args:
        img: image accepted by ``image_processor``. classify_image passes an RGB
            PIL image.

    Returns:
        dict with:
          'originality': the model output clamped to [0, 1], rounded to 2 decimals.
          'jrt': the same clamped score mapped back to the JRT scale via
              inverse_scale, rounded to 2 decimals.
          'percentile': the clamped score's percentile from get_percentile.
    """
    inputs = image_processor(img, return_tensors="pt")
    prediction = model(**inputs)
    # First (only) image in the batch, first logit.
    score = prediction.logits[0].detach().numpy()[0]
    # Scores live in the 0-1 min-max-scaled space (see inverse_scale); clamp overshoot.
    score = min(max(score, 0), 1)
    return {
        'originality': np.round(score, 2),
        # Was inverse_scale(0), which always reported the JRT minimum (-3.02).
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }
def classify_image(img_dict: dict):
    """Score the drawing in the editor and plot it on the reference distribution.

    Args:
        img_dict: value from the gradio ImageEditor (``type='pil'``): a dict with
            'background', 'layers' and 'composite'. A None or empty value is
            tolerated in case the editor sends no payload.

    Returns:
        An Altair layered chart: the reference curve, a red triangle at
        (percentile, originality), and a "Percentile: N" label. Returns None
        when there is no composite image to score.
    """
    # gradio passes a dictionary with background, composite, and layers;
    # the composite (everything flattened) is what we score.
    if not img_dict:
        return None
    img = img_dict.get('composite')
    if img is None:
        return None

    p = get_predictions(img.convert('RGB'))
    label = f"Percentile: {int(p['percentile'])}"
    label_df = pd.DataFrame({'y': [p['originality']],
                             'x': [p['percentile']],
                             'text': [label]})

    # Both overlay layers share the same one-row data source.
    marker_base = alt.Chart(label_df)
    point = marker_base.mark_point(
        shape='triangle',
        size=200,
        filled=True,
        color='red'
    ).encode(
        x='x',
        y='y'
    )
    txt = marker_base.mark_text(
        align='left',
        baseline='middle',
        dx=10, dy=-10,
        fontSize=14
    ).encode(
        y='y',
        x='x',
        text='text'
    )
    return base_chart + point + txt
def update_editor(background, img_editor):
    """Reset an ImageEditor value dict to an empty canvas on ``background``.

    Mutates ``img_editor`` in place: the background is replaced, drawn layers
    are dropped and the composite is cleared. Returns the same dict object.
    """
    img_editor.update(background=background, layers=[], composite=None)
    return img_editor
# Drawing canvas. It starts on the Images_11 blank sheet with nothing drawn.
editor = gr.ImageEditor(
    type='pil',
    value={
        'background': Image.open(prompt_images['Images_11']),
        'composite': None,
        'layers': [],
    },
    # Input sources limited to upload and clipboard; no transform tools,
    # layer controls off.
    sources=('upload', 'clipboard'),
    transforms=[],
    layers=False,
    # Thin brush with a fixed palette of black and two greys.
    brush=gr.Brush(
        colors=["#000000", '#333333', '#666666'],
        default_size=2,
        color_mode="fixed",
    ),
)
# One example row per prompt. Each row holds a single ImageEditor value with the
# blank sheet as background and nothing drawn yet.
examples = [
    [dict(background=Image.open(blank_path), composite=None, layers=[])]
    for blank_path in prompt_images.values()
]
# Single-input / single-plot interface around classify_image.
demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description=(
        "Complete the drawing and classify the originality. "
        "Choose the brush icon below the image to start editing.\n\n"
        "Model from *A Comparison of Supervised and Unsupervised Learning Methods "
        "in Automated Scoring of Figural Tests of Creativity* "
        "([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\n"
        "Examples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/))."
    ),
    examples=examples,
    # Example outputs are not precomputed / cached.
    cache_examples=False,
)
demo.launch(debug=True)