Spaces:
Running
Running
File size: 4,451 Bytes
23770d2 8c532f5 23770d2 8c532f5 23770d2 8c532f5 23770d2 8c532f5 23770d2 8c532f5 23770d2 8c532f5 23770d2 8c532f5 23770d2 69b1c36 23770d2 4699c28 8c532f5 4699c28 8c532f5 23770d2 4699c28 23770d2 4699c28 23770d2 4699c28 23770d2 4c131cd 23770d2 8c532f5 4699c28 23770d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from pathlib import Path

import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
# Hugging Face checkpoint providing both the classifier and its preprocessor.
modelname = "POrg/ocsai-d-web"
# .eval() returns the module itself, so the model is put into evaluation
# mode (e.g. dropout disabled) as part of the assignment.
model = AutoModelForImageClassification.from_pretrained(modelname).eval()
image_processor = AutoImageProcessor.from_pretrained(modelname)
# Blank prompt images keyed by prompt id. Every file follows the pattern
# ./blanks/<id>_blank.png; the tuple order below fixes the dict's iteration
# order (which is also the order the examples are built in).
prompt_images = {
    f"Images_{num}": f"./blanks/Images_{num}_blank.png"
    for num in (11, 4, 17, 9, 3, 8, 13, 15, 12, 7, 56, 19)
}
# Reference distribution: score_norm value at each percentile (all columns
# parsed as float).
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
# Line chart of that distribution; each classification layers a marker for
# the new drawing on top of it.
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)
def get_percentile(score, table=None):
    """Return the percentile of ``score`` in a reference score distribution.

    Args:
        score: normalized originality score (get_predictions passes a value
            already clamped to [0, 1]).
        table: DataFrame with 'percentile' and 'score_norm' columns. Defaults
            to the module-level ``dist`` read from score_norm_distribution.csv.

    Returns:
        The largest percentile whose ``score_norm`` is <= ``score``. A score
        below every ``score_norm`` in the table gets the table's lowest
        percentile instead of raising IndexError.
    """
    if table is None:
        table = dist
    # Select by column name (not position) and take max() so the result does
    # not depend on column order or on the table being pre-sorted.
    at_or_below = table.loc[table['score_norm'] <= score, 'percentile']
    if at_or_below.empty:
        return table['percentile'].min()
    return at_or_below.max()
def inverse_scale(logits):
    """Undo the min-max scaling that mapped the JRT range onto 0-1.

    Uses plain arithmetic, so scalars and numpy arrays both work.
    """
    jrt_min = -3.024
    jrt_range = 6.188  # JRT max (3.164) minus JRT min (-3.024)
    return logits * jrt_range + jrt_min
def get_predictions(img):
    """Run the Ocsai-D classifier on one image and summarize its score.

    Args:
        img: image handed straight to ``image_processor`` (classify_image
            supplies an RGB PIL image).

    Returns:
        dict with
          'originality': first logit of the first batch item, clamped to
              [0, 1] and rounded to 2 decimals;
          'jrt': that clamped score mapped back onto the JRT scale with
              ``inverse_scale``, rounded to 2 decimals;
          'percentile': ``get_percentile`` of the clamped score.
    """
    inputs = image_processor(img, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        prediction = model(**inputs)
    # The model output lives on the 0-1 min-max-scaled space (see
    # inverse_scale), so clamp predictions that fall outside it.
    score = prediction.logits[0].detach().numpy()[0]
    score = min(max(score, 0), 1)
    return {
        'originality': np.round(score, 2),
        # Scale the actual score back to JRT units (was inverse_scale(0),
        # which reported the JRT minimum for every drawing).
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }
def classify_image(img_dict: dict):
    """Score the drawing in the editor and plot it on the reference curve.

    Args:
        img_dict: the gradio ImageEditor value, a dict with 'background',
            'layers' and 'composite'. May be None/empty, or have no
            composite, when there is nothing to score.

    Returns:
        ``base_chart`` layered with a red triangle at (percentile,
        originality) and a "Percentile: N" text label, or None when there is
        no composite image to score.
    """
    # The composite is what gets scored. Use .get() and a falsy check so a
    # missing or empty editor value returns None instead of raising.
    if not img_dict:
        return None
    img = img_dict.get('composite')
    if img is None:
        return None
    p = get_predictions(img.convert('RGB'))
    label = f"Percentile: {int(p['percentile'])}"
    label_df = pd.DataFrame({'y': [p['originality']],
                             'x': [p['percentile']],
                             'text': [label]})
    # Marker for this drawing at (percentile, originality).
    point = alt.Chart(label_df).mark_point(
        shape='triangle',
        size=200,
        filled=True,
        color='red'
    ).encode(
        x='x',
        y='y'
    )
    # Text label offset up and to the right of the marker.
    txt = alt.Chart(label_df).mark_text(
        align='left',
        baseline='middle',
        dx=10, dy=-10,
        fontSize=14
    ).encode(
        y='y',
        x='x',
        text='text'
    )
    return base_chart + point + txt
def update_editor(background, img_editor):
    """Reset an ImageEditor value dict to ``background`` with no strokes.

    Sets 'background', empties 'layers' and clears 'composite' in place,
    then returns the same dict object.
    """
    img_editor.update(background=background, layers=[], composite=None)
    return img_editor
# Input component: a PIL-typed image editor that opens on the Images_11
# blank, with a small fixed greyscale brush and no transform tools or layers.
editor = gr.ImageEditor(
    type='pil',
    value={
        'background': Image.open(prompt_images['Images_11']),
        'composite': None,
        'layers': [],
    },
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed",
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False,
)
# One example row per blank prompt, in prompt_images order. Each row holds a
# single ImageEditor value: the blank as background, no strokes, no composite.
examples = [
    [dict(background=Image.open(path), composite=None, layers=[])]
    for path in prompt_images.values()
]
# Markdown shown under the title of the interface.
demo_description = (
    "Complete the drawing and classify the originality. "
    "Choose the brush icon below the image to start editing.\n\n"
    "Model from *A Comparison of Supervised and Unsupervised Learning Methods "
    "in Automated Scoring of Figural Tests of Creativity* "
    "([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\n"
    "Examples are from MTCI "
    "([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/))."
)
# The editor's composite is scored by classify_image; the result goes to a plot.
demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description=demo_description,
    examples=examples,
    cache_examples=False,  # do not pre-run the examples at startup
)
demo.launch(debug=True)
|