"""Gradio demo for Ocsai-D: score the originality of a completed figural drawing.

The user completes one of the prompt drawings in an image editor. The composite
image is scored with the first logit of a transformers image-classification
model, and the score is plotted against a reference distribution of normalized
scores read from ``score_norm_distribution.csv``.
"""
from pathlib import Path
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np
import altair as alt
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor

modelname = "POrg/ocsai-d-web"
model = AutoModelForImageClassification.from_pretrained(modelname)
model.eval()  # Set the model to evaluation mode (inference only)
image_processor = AutoImageProcessor.from_pretrained(modelname)

# Blank prompt drawings offered as the editor background and as examples.
prompt_images = {
    "Images_11": "./blanks/Images_11_blank.png",
    "Images_4": "./blanks/Images_4_blank.png",
    "Images_17": "./blanks/Images_17_blank.png",
    "Images_9": "./blanks/Images_9_blank.png",
    "Images_3": "./blanks/Images_3_blank.png",
    "Images_8": "./blanks/Images_8_blank.png",
    "Images_13": "./blanks/Images_13_blank.png",
    "Images_15": "./blanks/Images_15_blank.png",
    "Images_12": "./blanks/Images_12_blank.png",
    "Images_7": "./blanks/Images_7_blank.png",
    "Images_56": "./blanks/Images_56_blank.png",
    "Images_19": "./blanks/Images_19_blank.png",
}

# Reference distribution with (at least) 'percentile' and 'score_norm' columns.
# NOTE(review): get_percentile assumes rows are ordered by ascending score_norm.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
base_chart = alt.Chart(dist).mark_line().encode(
    x='percentile',
    y='score_norm'
)

# Min-max scaler parameters that mapped the JRT range onto 0-1.
_JRT_SCALER = {'min': -3.024, 'max': 3.164, 'range': 6.188}


def get_percentile(score):
    """Return the percentile of the last reference row with score_norm <= score.

    Relies on ``dist`` being sorted by ascending ``score_norm``. A score below
    every reference value yields 0.0 instead of raising IndexError.
    """
    below = dist.loc[dist['score_norm'] <= score, 'percentile']
    if below.empty:
        return 0.0
    return below.iloc[-1]


def inverse_scale(logits):
    """Undo the min-max scaling that was done from the JRT range to 0-1.

    Works on scalars as well as numpy arrays (elementwise).
    """
    return logits * _JRT_SCALER['range'] + _JRT_SCALER['min']


def get_predictions(img):
    """Score one RGB PIL image.

    Returns a dict with:
      - 'originality': model score clipped to [0, 1], rounded to 2 decimals
      - 'jrt': that score mapped back onto the JRT scale, rounded to 2 decimals
      - 'percentile': the score's percentile in the reference distribution
    """
    inputs = image_processor(img, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        prediction = model(**inputs)
    score = prediction.logits[0].detach().numpy()[0]
    # Scores live on the 0-1 min-max scale; clip out-of-range model outputs.
    score = min(max(score, 0), 1)
    return {
        'originality': np.round(score, 2),
        # Map the actual score back to the JRT scale (previously inverse_scale(0),
        # which always reported the JRT minimum).
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }


def classify_image(img_dict: dict):
    """Score the editor's composite drawing and plot it on the reference curve.

    Gradio passes a dictionary with 'background', 'composite' and 'layers';
    the composite is the image that gets scored. Returns an Altair chart of
    the reference distribution with the drawing's score marked, or None when
    there is no image to score.
    """
    if not img_dict:
        return None
    img = img_dict.get('composite')
    if img is None:
        return None

    p = get_predictions(img.convert('RGB'))
    label = f"Percentile: {int(p['percentile'])}"
    label_df = pd.DataFrame({'y': [p['originality']], 'x': [p['percentile']], 'text': [label]})

    point = alt.Chart(label_df).mark_point(
        shape='triangle', size=200, filled=True, color='red'
    ).encode(
        x='x',
        y='y'
    )
    txt = alt.Chart(label_df).mark_text(
        align='left', baseline='middle', dx=10, dy=-10, fontSize=14
    ).encode(
        y='y',
        x='x',
        text='text'
    )
    return base_chart + point + txt


def update_editor(background, img_editor):
    """Reset an ImageEditor value to a new background with no drawn layers.

    Mutates ``img_editor`` in place and returns it.
    """
    img_editor['background'] = background
    img_editor['layers'] = []
    img_editor['composite'] = None
    return img_editor


def _load_image(path):
    """Read an image fully into memory and close the underlying file handle."""
    with Image.open(path) as im:
        return im.copy()


editor = gr.ImageEditor(
    type='pil',
    value=dict(
        background=_load_image(prompt_images['Images_11']),
        composite=None,
        layers=[]
    ),
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed"
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False
)

examples = [
    [dict(background=_load_image(path), composite=None, layers=[])]
    for path in prompt_images.values()
]

demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description="Complete the drawing and classify the originality. Choose the brush icon below the image to start editing.\n\nModel from *A Comparison of Supervised and Unsupervised Learning Methods in Automated Scoring of Figural Tests of Creativity* ([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\nExamples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)).",
    examples=examples,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch(debug=True)