# Hugging Face Spaces app (status: Running)
from pathlib import Path

import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Ocsai-D checkpoint on the Hugging Face Hub, loaded once when the module runs.
modelname = "POrg/ocsai-d-web"
# .eval() switches the model to evaluation mode and returns the model itself.
model = AutoModelForImageClassification.from_pretrained(modelname).eval()
# Preprocessor paired with the same checkpoint (turns images into model inputs).
image_processor = AutoImageProcessor.from_pretrained(modelname)
# Blank prompt sheets used as editor backgrounds, keyed by prompt id.
# Insertion order matters: it sets the order of the Examples gallery, and
# "Images_11" is the editor's initial background.
prompt_images = {
    f"Images_{n}": f"./blanks/Images_{n}_blank.png"
    for n in (11, 4, 17, 9, 3, 8, 13, 15, 12, 7, 56, 19)
}
# Reference distribution of normalized originality scores. Its 'percentile'
# and 'score_norm' columns are used by get_percentile and drawn as the
# background curve of every result chart.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)
def get_percentile(score, table=None):
    """Return the reference percentile reached by a normalized score.

    Selects the rows of the reference table whose ``score_norm`` is at or
    below ``score`` and returns the highest ``percentile`` among them, so
    the result does not depend on the table's row order.

    Args:
        score: Normalized originality score (callers clamp it to [0, 1]).
        table: Optional DataFrame with ``percentile`` and ``score_norm``
            columns. Defaults to the module-level ``dist`` reference table.

    Returns:
        The matching percentile as a float, or ``0.0`` when ``score`` is
        below every tabulated ``score_norm`` (nobody in the reference
        distribution scored lower).
    """
    if table is None:
        table = dist
    reached = table.loc[table['score_norm'] <= score, 'percentile']
    # An empty selection means the score is below the whole table; indexing
    # into it would raise IndexError.
    if reached.empty:
        return 0.0
    return float(reached.max())
def inverse_scale(logits):
    """Map a 0-1 score back onto the original JRT scale.

    Undoes the min-max scaling (min -3.024, max 3.164, range 6.188) that
    mapped the JRT range onto 0-1. Uses only ``*`` and ``+``, so it also
    accepts NumPy arrays elementwise.
    """
    jrt_min = -3.024
    jrt_range = 6.188
    return logits * jrt_range + jrt_min
def get_predictions(img):
    """Score one drawing with the Ocsai-D model.

    Args:
        img: Image accepted by ``image_processor``. The caller passes a PIL
            image already converted to RGB.

    Returns:
        dict with:
          - ``'originality'``: first model logit clamped to [0, 1], rounded
            to 2 decimals;
          - ``'jrt'``: that clamped score mapped back to the JRT scale with
            ``inverse_scale``, rounded to 2 decimals;
          - ``'percentile'``: the clamped score's percentile in the
            reference distribution (``get_percentile``).
    """
    inputs = image_processor(img, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(**inputs)
    # First logit of the single image in the batch.
    raw = float(prediction.logits[0, 0])
    score = min(max(raw, 0.0), 1.0)
    return {
        'originality': np.round(score, 2),
        # Derived from the score itself; previously inverse_scale(0), a constant.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }
def classify_image(img_dict: dict):
    """Score the editor's drawing and plot it against the reference curve.

    Args:
        img_dict: Value from ``gr.ImageEditor``: a dict with ``'background'``,
            ``'layers'`` and ``'composite'`` keys. Only ``'composite'`` (the
            flattened drawing) is used. Presumably ``None`` when the editor
            holds no value, so it is guarded defensively.

    Returns:
        ``base_chart`` layered with a red triangle at the drawing's
        (percentile, originality) and a "Percentile: N" label, or ``None``
        when there is no image to score.
    """
    if not img_dict:
        return None
    img = img_dict.get('composite')
    if img is None:
        return None

    p = get_predictions(img.convert('RGB'))
    label = f"Percentile: {int(p['percentile'])}"
    label_df = pd.DataFrame({'y': [p['originality']],
                             'x': [p['percentile']],
                             'text': [label]})

    # Marker where this drawing lands on the reference curve.
    point = alt.Chart(label_df).mark_point(
        shape='triangle',
        size=200,
        filled=True,
        color='red'
    ).encode(
        x='x',
        y='y'
    )

    # Text label offset up and to the right of the marker.
    txt = alt.Chart(label_df).mark_text(
        align='left',
        baseline='middle',
        dx=10, dy=-10,
        fontSize=14
    ).encode(
        y='y',
        x='x',
        text='text'
    )
    return base_chart + point + txt
def update_editor(background, img_editor):
    """Put a new background on the editor value and discard any drawing.

    Mutates ``img_editor`` in place (background replaced, layers emptied,
    composite cleared) and returns that same object.
    """
    resets = (
        ('background', background),
        ('layers', []),
        ('composite', None),
    )
    for key, value in resets:
        img_editor[key] = value
    return img_editor
# Drawing canvas. It starts on the Images_11 blank, accepts upload/clipboard
# sources, offers a fixed three-shade grayscale brush, and disables the
# transform tools and layer controls.
editor = gr.ImageEditor(
    type='pil',
    value={
        'background': Image.open(prompt_images['Images_11']),
        'composite': None,
        'layers': [],
    },
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed",
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False,
)
# One example row per blank prompt, in prompt_images order. Each row holds a
# single ImageEditor value that starts from that blank.
examples = [
    [{'background': Image.open(path), 'composite': None, 'layers': []}]
    for path in prompt_images.values()
]
# Single-input interface: the drawing goes in, the percentile plot comes out.
demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description=(
        "Complete the drawing and classify the originality. "
        "Choose the brush icon below the image to start editing.\n\n"
        "Model from *A Comparison of Supervised and Unsupervised Learning Methods "
        "in Automated Scoring of Figural Tests of Creativity* "
        "([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\n"
        "Examples are from MTCI "
        "([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/))."
    ),
    examples=examples,
    cache_examples=False,
)
demo.launch(debug=True)