|
import itertools as it |
|
import functools as ft |
|
from pathlib import Path |
|
from dataclasses import fields, asdict |
|
|
|
import pandas as pd |
|
import gradio as gr |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
from datasets import load_dataset, get_dataset_split_names |
|
from scipy.special import expit |
|
from matplotlib.ticker import FixedLocator, StrMethodFormatter |
|
|
|
from hdinterval import HDI, HDInterval |
|
|
|
|
|
|
|
|
|
def compare(df, model_1, model_2):
    '''Return the logistic transform of the difference between two models.

    Rows of *df* whose ``source`` is one of the two models are reshaped
    so that each (``chain``, ``sample``) pair becomes a row and each
    model becomes a column of ``value`` entries. The result is
    ``expit(model_1 - model_2)`` computed row by row.

    Args:
        df: Frame with ``chain``, ``sample``, ``source`` and ``value``
            columns.
        model_1: ``source`` label used as the minuend.
        model_2: ``source`` label used as the subtrahend.

    Returns:
        The transformed differences, indexed by (``chain``, ``sample``).
    '''
    mcol = 'source'
    # NOTE: the query string below refers to this local by name.
    models = [model_1, model_2]

    subset = df.query(f'{mcol} in @models')
    wide = subset.pivot(index=['chain', 'sample'],
                        columns=mcol,
                        values='value')
    (left, right) = (wide[model_1], wide[model_2])

    return expit(left - right)
|
|
|
|
|
|
|
|
|
class DataSummarizer:
    '''Collapse per-source samples into a ranked summary table.

    Calling an instance with a frame returns one row per ``source``,
    ordered by descending ``value`` (ties broken by ascending
    ``uncertainty``), with a 1-based ``rank`` column in front.
    '''

    def __init__(self, name, ci):
        '''Store the parameter name and the interval mass to request.'''
        self.name = name
        self.ci = ci

    def __call__(self, df):
        '''Summarize *df* and return the ranked table.'''
        ranked = self.summarize(df)
        ranked = ranked.sort_values(by=['value', 'uncertainty'],
                                    ascending=[False, True])
        # The width was only needed as a tie-breaker for the ordering.
        ranked = ranked.drop(columns='uncertainty')
        ranked = ranked.reset_index(drop=True)

        # Shift the fresh positional index so ranks start at one, then
        # promote it to a regular column.
        ranked.index += 1
        return ranked.reset_index(names='rank')

    def aggregate(self, source, df):
        '''Build the summary record for a single source.

        Args:
            source: Label of the group being summarized.
            df: Rows belonging to that group; only ``value`` is read.

        Returns:
            A dict holding the source, the median value, the interval
            width (as ``uncertainty``) and the interval's own fields.
        '''
        values = df['value']
        interval = HDInterval(values)(self.ci)

        return {
            'source': source,
            'value': values.median(),
            'uncertainty': interval.width(),
            # Interval fields come last so they win on a key clash.
            **asdict(interval),
        }

    def summarize(self, df):
        '''Return one aggregate row per ``source``, in first-seen order.'''
        grouped = df.groupby('source', sort=False)
        rows = (self.aggregate(source, group) for (source, group) in grouped)

        return pd.DataFrame.from_records(rows)
|
|
|
|
|
|
|
|
|
class DataPlotter:
    '''Base class that owns figure creation; subclasses supply ``draw``.'''

    def __init__(self, df):
        '''Keep a reference to the data that ``draw`` will render.'''
        self.df = df

    def plot(self):
        '''Create a figure, let ``draw`` populate it, and return it.'''
        figure = plt.figure(dpi=200)
        axes = figure.gca()

        self.draw(axes)

        # The grid is applied after drawing so it covers the final ticks.
        grid_style = {
            'visible': True,
            'axis': 'both',
            'alpha': 0.25,
            'linestyle': 'dotted',
        }
        axes.grid(**grid_style)
        figure.tight_layout()

        return figure

    def draw(self, ax):
        '''Render ``self.df`` onto *ax*; must be overridden.'''
        raise NotImplementedError()
|
|
|
class ComparisonPlotter(DataPlotter):
    '''Histogram of the pairwise comparison between two models.

    The plotted data is whatever ``compare(df, model_1, model_2)``
    returns; the shaded band and the corner note are derived from an
    ``HDInterval`` built over that same data.
    '''

    # Reference value whose inclusion in the interval is reported by
    # ``min_inclusive``.
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        '''
        Args:
            df: Samples forwarded to ``compare``.
            model_1: First model label forwarded to ``compare``.
            model_2: Second model label forwarded to ``compare``.
            ci: Interval mass requested from the ``HDInterval`` when
                drawing.
        '''
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        '''Draw the density histogram and its annotations onto *ax*.'''
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Target the axes we were handed explicitly. Without ``ax=``
        # seaborn draws on pyplot's global "current" axes, which is only
        # the right one as long as nothing else touched pyplot since the
        # figure was created.
        sns.histplot(data=self.df,
                     stat='density',
                     color=c_hist,
                     ax=ax)
        ax.set_xlabel('\u03B1$_{1}$ - \u03B1$_{2}$')

        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        '''Write the note about ``_uncertain`` in the upper-left corner.

        ``self.interval.at`` is asked for the interval mass associated
        with ``_uncertain``. An ``OverflowError`` is reported as the
        value not being an element of the 100% HDI; a
        ``FloatingPointError`` suppresses the note altogether.
        NOTE(review): the exact meaning of those two exceptions is
        defined by ``HDInterval`` -- confirm against that package.
        '''
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return

        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        '''Mark the median and the interval, and label the median on top.

        Args:
            ax: Axes that already hold the histogram.
            hdi: Object exposing ``lower`` and ``upper`` bounds.
            color: Colour shared by the median line and the shaded span.
        '''
        x = self.df.median()
        # One step below the axes' own zorder for the shaded span.
        zorder = ax.zorder - 1

        decimals = self._tick_decimals(ax)
        # The quadruple braces survive the f-string as ``{{1}}`` and are
        # then reduced to ``{1}`` by the formatter's own ``str.format``
        # pass, leaving valid mathtext subscripts.
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        # A secondary axis with a single tick carries the median label.
        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))

    @staticmethod
    def _tick_decimals(ax, default=2):
        '''Return one more decimal place than the first x tick label shows.

        Falls back to *default* when there are no tick labels or the
        first label has no decimal point. (``str.split`` never returns
        an empty list, so testing the split result for emptiness could
        never select the fallback.)
        '''
        labels = ax.get_xticklabels()
        if not labels:
            return default

        (_, point, fraction) = labels[0].get_text().rpartition('.')
        return len(fraction) + 1 if point else default
|
|
|
|
|
|
|
|
|
class ComparisonMenu:
    '''Input widgets and click handler for the model comparison plot.'''

    def __init__(self, df, ci=95):
        '''
        Args:
            df: Samples handed to ``ComparisonPlotter`` on each click.
            ci: Initial value of the HDI field, as a percentage.
        '''
        self.df = df
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        '''Return the comparison figure, or ``None`` if a model is unset.

        *ci* arrives as a percentage and is converted to a fraction
        before being passed on.
        '''
        if not (model_1 and model_2):
            return None

        fraction = ci / 100
        plotter = ComparisonPlotter(self.df, model_1, model_2, fraction)

        return plotter.plot()

    def build_and_get(self):
        '''Yield the two model dropdowns followed by the HDI number field.

        Components are created as they are yielded, so this must be
        consumed inside the layout context where they should appear.
        '''
        unique = self.df['source'].unique()
        choices = sorted(unique, key=lambda name: name.lower())

        yield from (
            gr.Dropdown(label=f'Model {n}', choices=choices) for n in (1, 2)
        )
        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
|
|
|
|
|
|
|
|
|
class ParameterLayout:
    '''Base layout for one parameter: a summary table plus extras.

    Calling an instance renders the table and then hands both the raw
    data and the table component to ``extra``, which subclasses define.
    '''

    def __init__(self, name, ci=0.95, **kwargs):
        '''
        Args:
            name: Parameter name; also the string form of the layout.
            ci: Interval mass used by the summarizer and column titles.
            **kwargs: Accepted and ignored.
        '''
        (self.name, self.ci) = (name, ci)
        self.summarize = DataSummarizer(self.name, self.ci)

    def __str__(self):
        return self.name

    def __call__(self, df):
        '''Render the table for *df*, then the subclass-specific extras.'''
        frame = self.table(df)
        self.extra(df, frame)

    def table(self, df):
        '''Create and return the summary ``gr.Dataframe`` in its own row.'''
        summary = self.summarize(df)

        # Title-case every column, then let the interval fields override
        # with their "<ci> HDI <field>" headings.
        headings = {column: column.title() for column in summary.columns}
        headings.update(
            (x.name, f'{self.ci:.0%} HDI {x.name}') for x in fields(HDI)
        )

        styled = summary.rename(columns=headings).style.format(precision=4)

        with gr.Row():
            frame = gr.Dataframe(styled)

        return frame

    def extra(self, df, *args):
        '''Render anything beyond the table; must be overridden.'''
        raise NotImplementedError()
|
|
|
class PersonLayout(ParameterLayout):
    '''Layout for the ``ability`` parameter, with a comparison tool.'''

    def __init__(self, ci=0.95, **kwargs):
        '''
        Args:
            ci: Interval mass forwarded to the base layout.
            **kwargs: Forwarded to the base layout unchanged.
        '''
        super().__init__('ability', ci, **kwargs)

    def extra(self, df, *args):
        '''Add a row with help text, the plot area and the input menu.

        Args:
            df: Samples offered to ``ComparisonMenu`` for comparison.
            *args: Unused here.
        '''
        with gr.Row():
            with gr.Column():
                gr.Markdown(f'''

                Probability that Model 1 is preferred to Model 2. The
                histogram represents the distribution of the
                difference in estimated model abilities. The dashed
                vertical line is its median. The shaded region
                demarcates the chosen [highest density
                interval](https://cran.r-project.org/package=HDInterval)
                (HDI). The note in the upper left denotes the smallest
                HDI that is inclusive of
                {ComparisonPlotter._uncertain}.

                ''')

            with gr.Column(scale=3):
                display = gr.Plot()

            with gr.Column():
                menu = ComparisonMenu(df)
                # Consuming the generator here creates the widgets
                # inside this column.
                inputs = list(menu.build_and_get())
                button = gr.Button(value='Compare!')
                button.click(menu, inputs=inputs, outputs=[display])
|
|
|
class ItemLayout(ParameterLayout):
    '''Layout for item parameters; selecting a row shows its document.'''

    # Common prefix of the per-framework document repositories.
    _prefix = Path('jerome-white', 'leaderboard-documents')

    def __init__(self, name, framework, ci=0.95, **kwargs):
        '''
        Args:
            name: Parameter name forwarded to the base layout.
            framework: Suffix selecting which document dataset to load.
            ci: Interval mass forwarded to the base layout.
            **kwargs: Forwarded to the base layout unchanged.
        '''
        super().__init__(name, ci, **kwargs)

        # Repository ids always use forward slashes; ``as_posix`` keeps
        # the id valid on platforms where ``str(Path)`` would produce
        # backslashes.
        repository = f'{self._prefix.as_posix()}-{framework}'
        ds = load_dataset(repository, split='train')
        self.df = (ds
                   .to_pandas()
                   .set_index('doc'))

    def extra(self, df, *args):
        '''Attach a select handler to the table that fills a JSON view.

        Args:
            df: Unused here.
            *args: Exactly one element, the table component.
        '''
        (frame, ) = args
        frame.select(self.df_select_callback,
                     inputs=[frame],
                     outputs=[gr.JSON()])

    def df_select_callback(self, df: pd.DataFrame, evt: gr.SelectData):
        '''Return the ``info`` entry for the document in the selected row.

        The document key is read from the ``Source`` column of the row
        reported by the selection event.
        '''
        index = df.columns.get_loc('Source')
        doc = evt.row_value[index]
        return self.df.loc[(doc, 'info')]
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Repository ids always use forward slashes; ``as_posix`` keeps the
    # id valid on platforms where ``str(Path)`` would produce backslashes.
    path = Path('jerome-white', 'leaderboard-item-response').as_posix()
    splits = get_dataset_split_names(path)

    # Maps each value of the ``parameter`` column to the layout that
    # renders it. Every factory is called with ``framework=<split>``.
    pmap = {
        'alpha': ft.partial(ItemLayout, name='discrimination'),
        'beta': ft.partial(ItemLayout, name='difficulty'),
        'theta': PersonLayout,
    }

    # One outer tab per dataset split, one inner tab per parameter.
    for s in splits:
        df = load_dataset(path, split=s).to_pandas()
        with gr.Tab(s):
            for (i, g) in df.groupby('parameter'):
                # Index directly so an unmapped parameter fails with a
                # KeyError naming it, instead of an opaque "'NoneType'
                # object is not callable".
                layout = pmap[i](framework=s)
                tab = str(layout)
                with gr.Tab(tab.title()):
                    layout(g)

demo.launch()
|
|