jerome-white's picture
Add document information
dacb5dc
raw
history blame
8.05 kB
import itertools as it
import functools as ft
from pathlib import Path
from dataclasses import fields, asdict
import pandas as pd
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset, get_dataset_split_names
from scipy.special import expit
from matplotlib.ticker import FixedLocator, StrMethodFormatter
from hdinterval import HDI, HDInterval
#
#
#
def compare(df, model_1, model_2):
    """Per-draw probability that ``model_1`` is preferred to ``model_2``.

    Keeps the rows of *df* whose ``source`` names one of the two models,
    aligns their ``value`` draws on ``(chain, sample)``, and passes the
    per-draw difference ``model_1 - model_2`` through the logistic
    function (``expit``).

    Returns a Series indexed by ``(chain, sample)``.
    """
    key = 'source'
    models = [model_1, model_2]
    subset = df.query(f'{key} in @models')
    wide = subset.pivot(index=['chain', 'sample'], columns=key, values='value')
    difference = wide[model_1] - wide[model_2]
    return expit(difference)
#
#
#
class DataSummarizer:
    """Rank each ``source`` by the median of its ``value`` draws.

    ``name`` is stored but not used by this class. ``ci`` is the
    probability mass passed to ``HDInterval`` when computing each
    source's highest-density interval.
    """

    def __init__(self, name, ci):
        self.name = name
        self.ci = ci

    def __call__(self, df):
        """Summarize *df* and return it with a 1-based ``rank`` column.

        Rows are ordered by descending ``value``. Ties are broken by
        ascending ``uncertainty`` (HDI width), and that column is then
        dropped from the result.
        """
        summary = self.summarize(df)
        ordered = summary.sort_values(
            by=['value', 'uncertainty'],
            ascending=[False, True],
        )
        ordered = ordered.drop(columns='uncertainty').reset_index(drop=True)
        # Shift the 0-based positional index so ranks start at 1.
        ordered.index = ordered.index + 1
        return ordered.reset_index(names='rank')

    def aggregate(self, source, df):
        """Summarize the draws of a single source.

        Returns a dict with the source name, the median ``value``, the
        interval width as ``uncertainty``, and every field of the
        interval object returned by ``HDInterval`` (via ``asdict``).
        """
        draws = df['value']
        interval = HDInterval(draws)(self.ci)
        record = {
            'source': source,
            'value': draws.median(),
            'uncertainty': interval.width(),
        }
        return {**record, **asdict(interval)}

    def summarize(self, df):
        """Build one summary row per ``source`` group, in first-seen order."""
        rows = (
            self.aggregate(key, group)
            for (key, group) in df.groupby('source', sort=False)
        )
        return pd.DataFrame.from_records(rows)
#
#
#
class DataPlotter:
    """Base class for figures built from ``self.df``.

    Subclasses implement :meth:`draw`. :meth:`plot` supplies the figure,
    a light dotted grid, and the final layout.
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Create, decorate, and return a new matplotlib figure.

        The figure is created through pyplot so that it is the current
        figure while :meth:`draw` runs. It is then closed, i.e.
        deregistered from pyplot, even if drawing fails, so repeated calls
        do not accumulate open figures. The returned Figure object remains
        usable for rendering/saving by the caller.
        """
        fig = plt.figure(dpi=200)
        try:
            ax = fig.gca()
            self.draw(ax)
            ax.grid(visible=True,
                    axis='both',
                    alpha=0.25,
                    linestyle='dotted')
            fig.tight_layout()
        finally:
            plt.close(fig)
        return fig

    def draw(self, ax):
        """Render onto *ax*; must be overridden by subclasses."""
        raise NotImplementedError()
class ComparisonPlotter(DataPlotter):
    """Plot the head-to-head comparison of two models.

    ``self.df`` holds the output of ``compare``: one value per posterior
    draw, equal to ``expit`` of the difference between the two models'
    ``value`` draws. The figure shows their histogram, the median as a
    dashed line, the requested HDI as a shaded span, and a note about the
    smallest HDI that contains ``_uncertain``.
    """

    # Value reported on by min_inclusive; expit(0) == 0.5, i.e. a zero
    # difference between the two models.
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        """Compare *model_1* with *model_2* using the draws in *df*.

        *ci* is the HDI mass as a fraction (ComparisonMenu passes the
        on-screen percentage divided by 100).
        """
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        """Draw the histogram and its annotations onto *ax*."""
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)
        # Pass ax explicitly: without it seaborn draws on pyplot's global
        # "current" axes, which is only correct if no other figure was
        # created in the meantime.
        ax = sns.histplot(data=self.df,
                          stat='density',
                          color=c_hist,
                          ax=ax)
        ax.set_xlabel('\u03B1$_{1}$ - \u03B1$_{2}$')
        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        """Note, in the upper-left corner, the smallest HDI containing ``_uncertain``.

        Uses ``HDInterval.at``. An OverflowError is reported as the value
        lying outside the 100% HDI. A FloatingPointError suppresses the
        note entirely.
        """
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            # NOTE(review): presumably raised when no HDI short of 100%
            # contains the value -- confirm against hdinterval.
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return
        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median and HDI, and label the median on a top x-axis.

        The label reads "Pr(M1 > M2) = <median>", printed with one more
        decimal than the first bottom-axis tick label shows.
        """
        x = self.df.median()
        # One below the axes' zorder; presumably so the shaded span is
        # drawn behind the histogram bars.
        zorder = ax.zorder - 1
        decimals = self._label_decimals(ax)
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'
        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)
        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))

    @staticmethod
    def _label_decimals(ax, default=2):
        """Return one more decimal than the first x tick label shows.

        Falls back to *default* when there are no tick labels or the
        first label has no fractional part.
        """
        labels = ax.get_xticklabels()
        if not labels:
            return default
        (_, point, fraction) = labels[0].get_text().partition('.')
        return len(fraction) + 1 if point else default
#
#
#
class ComparisonMenu:
    """Callback and input components for the two-model comparison panel.

    ``ci`` is the default HDI mass in percent (0-100). It is the initial
    value of the "HDI (%)" number box.
    """

    def __init__(self, df, ci=95):
        self.df = df
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure, or None until both models are chosen.

        *ci* is a percentage and is divided by 100 before being passed to
        ComparisonPlotter. A cleared number box (None) falls back to the
        default given at construction instead of raising TypeError.
        """
        if not (model_1 and model_2):
            return None
        if ci is None:
            ci = self.ci
        cp = ComparisonPlotter(self.df, model_1, model_2, ci / 100)
        return cp.plot()

    def build_and_get(self):
        """Yield the inputs: "Model 1" and "Model 2" dropdowns, then the HDI box.

        Dropdown choices are the distinct ``source`` values of ``self.df``,
        sorted case-insensitively.
        """
        models = self.df['source'].unique()
        choices = sorted(models, key=lambda x: x.lower())
        for i in range(1, 3):
            label = f'Model {i}'
            yield gr.Dropdown(label=label, choices=choices)
        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
#
#
#
class ParameterLayout:
    """Base layout for one parameter tab: a summary table plus extras.

    ``ci`` is the HDI mass (a fraction) used by the summary table. Extra
    keyword arguments are accepted and ignored here.
    """

    def __init__(self, name, ci=0.95, **kwargs):
        self.name = name
        self.ci = ci
        self.summarize = DataSummarizer(self.name, self.ci)

    def __str__(self):
        return self.name

    def __call__(self, df):
        """Render the summary table for *df*, then the subclass extras."""
        frame = self.table(df)
        self.extra(df, frame)

    def table(self, df):
        """Render the ranked summary of *df* in its own row and return the component."""
        summary = self.summarize(df)
        # HDI fields get a descriptive header; every other column is
        # simply title-cased.
        headers = {x.name: f'{self.ci:.0%} HDI {x.name}' for x in fields(HDI)}
        for column in summary.columns:
            headers.setdefault(column, column.title())
        styled = summary.rename(columns=headers).style.format(precision=4)
        with gr.Row():
            return gr.Dataframe(styled)

    def extra(self, df, *args):
        """Add subclass-specific components; *args* holds the table component."""
        raise NotImplementedError()
class PersonLayout(ParameterLayout):
    """Layout for the model ``ability`` parameter.

    Besides the summary table, adds a panel for comparing two models: an
    explanation, the comparison plot, and the model/HDI inputs wired to
    a ComparisonMenu callback.
    """

    def __init__(self, ci=0.95, **kwargs):
        super().__init__('ability', ci, **kwargs)

    def extra(self, df, *args):
        """Add the two-model comparison panel for the draws in *df*."""
        with gr.Row():
            with gr.Column():
                gr.Markdown(f'''
                Probability that Model 1 is preferred to Model 2. The
                histogram represents the distribution of the
                logistic-transformed difference in estimated model
                abilities. The dashed vertical line is its median. The
                shaded region demarcates the chosen [highest density
                interval](https://cran.r-project.org/package=HDInterval)
                (HDI). The note in the upper left denotes the smallest
                HDI that is inclusive of
                {ComparisonPlotter._uncertain}.
                ''')
            with gr.Column(scale=3):
                display = gr.Plot()
            with gr.Column():
                menu = ComparisonMenu(df)
                inputs = list(menu.build_and_get())
                button = gr.Button(value='Compare!')
                button.click(menu, inputs=inputs, outputs=[display])
class ItemLayout(ParameterLayout):
    """Layout for an item (document) parameter with a per-document viewer.

    Loads the ``train`` split of the documents dataset for *framework*,
    indexed by ``doc``. Selecting a row of the summary table shows that
    document's ``info`` entry in a JSON component.
    """

    # Dataset repository id prefix; "-<framework>" is appended to it.
    _prefix = Path('jerome-white', 'leaderboard-documents')

    def __init__(self, name, framework, ci=0.95, **kwargs):
        super().__init__(name, ci, **kwargs)
        # as_posix() keeps the '/' of the repository id on every OS;
        # str(Path(...)) would join with a backslash on Windows.
        repo = f'{self._prefix.as_posix()}-{framework}'
        ds = load_dataset(repo, split='train')
        self.df = (ds
                   .to_pandas()
                   .set_index('doc'))

    def extra(self, df, *args):
        """Show the selected row's document info as JSON.

        Expects exactly one extra positional argument: the Dataframe
        component returned by ``table``.
        """
        (frame, ) = args
        frame.select(self.df_select_callback,
                     inputs=[frame],
                     outputs=[gr.JSON()])

    def df_select_callback(self, df: pd.DataFrame, evt: gr.SelectData):
        """Return the ``info`` entry for the document in the selected row.

        The document id is read from the row's ``Source`` column.
        """
        index = df.columns.get_loc('Source')
        doc = evt.row_value[index]
        return self.df.loc[(doc, 'info')]
#
#
#
with gr.Blocks() as demo:
    # Dataset repository id; as_posix() keeps the '/' separator on every
    # OS, whereas str(Path(...)) would use a backslash on Windows.
    path = Path('jerome-white', 'leaderboard-item-response').as_posix()
    splits = get_dataset_split_names(path)
    # Maps each value of the dataset's 'parameter' column to the layout
    # factory that renders it; factories are called with framework=<split>.
    pmap = {
        'alpha': ft.partial(ItemLayout, name='discrimination'),
        'beta': ft.partial(ItemLayout, name='difficulty'),
        'theta': PersonLayout,
    }
    for s in splits:
        df = load_dataset(path, split=s).to_pandas()
        with gr.Tab(s):
            for (i, g) in df.groupby('parameter'):
                try:
                    factory = pmap[i]
                except KeyError:
                    # Fail with a clear message rather than the opaque
                    # "'NoneType' object is not callable".
                    raise ValueError(
                        f'No layout registered for parameter {i!r} '
                        f'(split {s!r})'
                    ) from None
                layout = factory(framework=s)
                tab = str(layout)
                with gr.Tab(tab.title()):
                    layout(g)

if __name__ == '__main__':
    demo.launch()