from pathlib import Path

from transformers import pipeline
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np
import altair as alt

# Blank prompt drawings offered to the user; keys are prompt ids, values are
# image paths (relative to the working directory).
prompt_images = {
    "Images_11": "blanks/Images_11_blank.png",
    "Images_4": "blanks/Images_4_blank.png",
    "Images_17": "blanks/Images_17_blank.png",
    "Images_9": "blanks/Images_9_blank.png",
    "Images_3": "blanks/Images_3_blank.png",
    "Images_8": "blanks/Images_8_blank.png",
    "Images_13": "blanks/Images_13_blank.png",
    "Images_15": "blanks/Images_15_blank.png",
    "Images_12": "blanks/Images_12_blank.png",
    "Images_7": "blanks/Images_7_blank.png",
    "Images_56": "blanks/Images_56_blank.png",
    "Images_19": "blanks/Images_19_blank.png",
}

# Reference distribution used to turn a normalized score into a percentile
# and to draw the background curve of the result chart.
# NOTE(review): get_percentile assumes the rows are sorted by ascending
# 'score_norm' and that the first column holds the percentile -- confirm
# against the CSV.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)


def get_percentile(score):
    """Return the reference percentile for a normalized originality score.

    Selects the rows of ``dist`` whose ``score_norm`` is <= ``score`` and
    returns the first column of the last such row. If ``score`` is below
    every reference score, the smallest value of the first column is
    returned instead of raising ``IndexError`` on an empty selection.
    """
    at_or_below = dist[dist['score_norm'] <= score]
    if at_or_below.empty:
        # Score is lower than anything in the reference table.
        return dist.iloc[:, 0].min()
    return at_or_below.iloc[-1, 0]


def inverse_scale(logits):
    """Undo the min-max scaling from the JRT range to 0-1.

    Computes ``logits * range + min`` with ``range = max - min``. This
    works element-wise on numpy arrays as well as on scalars.
    """
    scaler_params = {'min': -3.024, 'max': 3.164, 'range': 6.188}
    return logits * scaler_params['range'] + scaler_params['min']


def get_predictions(img_dict):
    """Score a drawing coming from the Gradio image editor.

    Args:
        img_dict: the value Gradio's ImageEditor passes -- a dict with
            'background', 'layers' and 'composite'. Only 'composite' (the
            flattened drawing) is classified.

    Returns:
        dict with:
            'originality': the model score, rounded to 2 decimals;
            'jrt': that score mapped back to the JRT range via
                ``inverse_scale``, rounded to 2 decimals;
            'percentile': the reference percentile from ``get_percentile``.
    """
    img = img_dict['composite']
    predictions = classifier(img)
    # NOTE(review): assumes the top (only) label's score is the 0-1
    # normalized originality score -- confirm against the model card.
    score = predictions[0]['score']
    return {
        'originality': np.round(score, 2),
        # Map the drawing's own scaled score back to the JRT range.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }


# Background curve of the result chart: score_norm against percentile.
base_chart = alt.Chart(dist).mark_line().encode(
    x='percentile',
    y='score_norm'
)


def classify_image(img_dict):
    """Score a drawing and plot it against the reference distribution.

    Args:
        img_dict: the ImageEditor value (see ``get_predictions``).

    Returns:
        An Altair layered chart: ``base_chart``, plus a red horizontal rule
        at the drawing's normalized score, plus a text label giving its
        percentile and score.
    """
    p = get_predictions(img_dict)
    percentile_mark = alt.Chart(
        pd.DataFrame({'y': [p['originality']]})
    ).mark_rule(color='red').encode(y='y')

    # Text annotation for the percentile mark
    label = f"Percentile: {p['percentile']}; Normalized Score: {p['originality']}"
    text = alt.Chart(
        pd.DataFrame({'y': [p['originality']], 'text': [label]})
    ).mark_text(
        align='left',
        baseline='middle',
        dx=7,
        dy=-8  # Nudges text to right so it doesn't overlap with the line
    ).encode(
        y='y',
        text='text'
    )
    return base_chart + percentile_mark + text


def update_editor(background, img_editor):
    """Reset an ImageEditor value to a new background with no drawing.

    Mutates ``img_editor`` in place: it sets 'background', clears 'layers'
    and sets 'composite' to None. It then returns the same dict.
    """
    img_editor['background'] = background
    img_editor['layers'] = []
    img_editor['composite'] = None
    return img_editor


classifier = pipeline("image-classification", model='POrg/ocsai-d-large')

editor = gr.ImageEditor(
    type='pil',
    value=dict(
        background=Image.open(prompt_images['Images_11']),
        composite=None,
        layers=[]
    ),
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed"
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False
)

# One editor value per blank prompt, in prompt_images order.
examples = [
    dict(background=Image.open(path), composite=None, layers=[])
    for path in prompt_images.values()
]

demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description="Complete the drawing and classify the originality. Examples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)). Choose the brush icon below the image to start editing.",
    examples=examples
)

demo.launch(debug=True)