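# HakimAiV2 / figures / main_figure_2a.py
# Author: scdrand23 · commit 814a594 ("not working version") · 3.3 kB
# (Header recovered from the hosting page; its "raw" / "history" / "blame" links were page navigation.)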
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json, os
from statannot import add_stat_annotation
from statannotations.Annotator import Annotator
# Load the per-task median scores and give each score column its display model name.
df = pd.read_csv('results/all_eval/all_metrics_median.csv')
metric = 'dice'
# Score column -> legend label. The unprefixed metric column holds BiomedParse's scores.
model_names = {
    metric: 'BiomedParse',
    f'medsam_{metric}': 'MedSAM (oracle box)',
    f'sam_{metric}': 'SAM (oracle box)',
    f'dino_medsam_{metric}': 'MedSAM (Grounding DINO)',
    f'dino_sam_{metric}': 'SAM (Grounding DINO)',
}
df = df.rename(columns=model_names)
score_vars = [*model_names.values()]
# x-axis order of the per-modality groups (the pooled "All" group is drawn before these).
modality_list = [
    'CT', 'MRI', 'X-Ray', 'Pathology', 'Ultrasound',
    'Fundus', 'Endoscope', 'Dermoscopy', 'OCT',
]
# Raw modality label -> display label. MRI sequences collapse into one "MRI" group.
# A raw label missing from this table raises KeyError.
mod_names = {
    'CT': 'CT',
    'MRI': 'MRI',
    'MRI-T2': 'MRI',
    'MRI-ADC': 'MRI',
    'MRI-FLAIR': 'MRI',
    'MRI-T1-Gd': 'MRI',
    'X-Ray': 'X-Ray',
    'pathology': 'Pathology',
    'ultrasound': 'Ultrasound',
    'fundus': 'Fundus',
    'endoscope': 'Endoscope',
    'dermoscopy': 'Dermoscopy',
    'OCT': 'OCT',
    'All': 'All',
}
df['modality'] = df['modality'].map(mod_names.__getitem__)
# Append a copy of every row relabelled "All" so the pooled distribution is plotted too.
all_df = df.assign(modality='All')
df = pd.concat([df, all_df])
# Long format: one row per (modality, task, model) score, as seaborn's hue grouping expects.
df_long = df[['modality', 'task', *score_vars]].melt(
    id_vars=['modality', 'task'], var_name='Model', value_name='Performance',
)
# Draw the grouped box plot: one group per modality (pooled "All" first), one box per model.
fig, ax = plt.subplots(figsize=(9, 6))
ax = sns.boxplot(data=df_long, x='modality', y='Performance', hue='Model', ax=ax, palette='Set2',
                 order=['All']+modality_list,
                 # NOTE: whis=2 puts the whiskers at the furthest data points within 2 x IQR
                 # of the quartiles; it does NOT mark the 5th/95th percentiles (that would be
                 # whis=[5, 95]). TODO(review): confirm which the figure caption describes.
                 whis=2, saturation=0.6, linewidth=0.8, fliersize=0.5)
# No frame: hide the top, right and left spines (an arrow replaces the left one below).
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
# Draw an arrow along the y axis, in axes-fraction coordinates, from just below the
# bottom of the axes to slightly above its top.
ax.annotate('', xy=(0, 1.05), xytext=(0, -0.01), arrowprops=dict(arrowstyle='->', lw=1, color='black'), xycoords='axes fraction')
plt.title('')
# Y-axis label for the metrics this script knows; any other metric keeps seaborn's
# default label (the y column name).
if metric == 'dice':
    plt.ylabel('Dice score', fontsize=18)
elif metric == 'assd':
    plt.ylabel('ASSD', fontsize=18)
plt.xlabel('')
plt.xticks(rotation=45, fontsize=16)
plt.yticks(fontsize=14)
# Axis thickness.
ax.spines['bottom'].set_linewidth(1)
ax.spines['left'].set_linewidth(1)
# ASSD is plotted on a log scale.
if metric == 'assd':
    plt.yscale('log')
# Legend above the plot, two entries per row, without a frame. Seaborn already labels
# the hue entries with the renamed model names, so no explicit labels are passed.
# (An earlier labels-only legend call was dropped: this call replaced it immediately, and
# labels-only legends attach names to artists in drawing order, not by model.)
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.4), ncol=2, fontsize=14, frameon=False)
# One-sided paired t-test per modality: BiomedParse vs MedSAM (oracle box).
# The alternative hypothesis must follow the metric's direction. For Dice, higher is better,
# so "BiomedParse is better" is 'greater'. For ASSD (a distance), lower is better, so it is
# 'less'. Previously 'less' was used unconditionally, which for Dice tested whether
# BiomedParse is WORSE. With hide_non_significant=True, only pairs supporting the
# hypothesis get a star.
# NOTE(review): assumes statannotations passes the first box of each pair as the first
# sample to the test, and that each modality's BiomedParse/MedSAM values are aligned by
# task (they come from the same rows of df before melting) — confirm.
test_alternative = 'less' if metric == 'assd' else 'greater'
box_pairs = [((modality, 'BiomedParse'), (modality, 'MedSAM (oracle box)'))
             for modality in ['All'] + modality_list]
annotator = Annotator(ax, box_pairs, data=df_long, x='modality', y='Performance', hue='Model',
                      order=['All'] + modality_list)
annotator.configure(test='t-test_paired', text_format='star', loc='inside', hide_non_significant=True)
annotator.apply_test(alternative=test_alternative)
annotator.annotate()
plt.tight_layout()
# Save the plot as PNG and PDF. Both use bbox_inches='tight' so the two files are cropped
# the same way around the legend placed above the axes (previously only the PDF was).
os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories
ax.get_figure().savefig(f'plots/{metric}_comparison.png', bbox_inches='tight')
ax.get_figure().savefig(f'plots/{metric}_comparison.pdf', bbox_inches='tight')