import itertools as it
import functools as ft
from string import Template
from pathlib import Path
from dataclasses import fields, asdict

import pandas as pd
import gradio as gr
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.special import expit
from datasets import load_dataset, get_dataset_split_names
from matplotlib.ticker import FixedLocator, StrMethodFormatter

from hdinterval import HDI, HDInterval


#
#
#
def read_md(md):
    """Return the text of ``docs/<md>.md``.

    Any suffix already present on *md* is replaced by ``.md``.
    """
    path = Path('docs', md)
    return (path
            .with_suffix('.md')
            .read_text())


def compare(df, model_1, model_2):
    """Return ``expit(model_1 - model_2)`` for every (chain, sample) pair.

    Rows of *df* whose ``source`` is one of the two models are pivoted
    so that each model's ``value`` becomes a column indexed by
    ``chain`` and ``sample``; the column difference is then passed
    through the logistic function.
    """
    columns = 'source'
    mask = df[columns].isin([model_1, model_2])
    view = df[mask].pivot(
        index=['chain', 'sample'],
        columns=columns,
        values='value',
    )

    return expit(view[model_1] - view[model_2])


#
#
#
class DataSummarizer:
    """Collapse per-sample values into one ranked row per source."""

    def __init__(self, name, ci):
        self.name = name
        self.ci = ci

    def __call__(self, df):
        """Summarize *df* and return it with a 1-based ``rank`` column.

        Sources are ordered by descending ``value``; ties are broken by
        ascending ``uncertainty``, which is dropped from the result.
        """
        df = (self
              .summarize(df)
              .sort_values(by=['value', 'uncertainty'],
                           ascending=[False, True])
              .drop(columns='uncertainty')
              .reset_index(drop=True))
        df.index += 1  # ranks start at one, not zero

        return df.reset_index(names='rank')

    def aggregate(self, source, df):
        """Build the summary record for a single *source* group.

        The record holds the source, the median of ``value``, the width
        of the interval computed at ``self.ci``, and the interval's own
        dataclass fields.
        """
        values = df['value']
        hdi = HDInterval(values)
        interval = hdi(self.ci)

        agg = {
            'source': source,
            'value': values.median(),
            'uncertainty': interval.width(),
        }
        agg.update(asdict(interval))

        return agg

    def summarize(self, df):
        """Return a frame with one aggregated record per ``source``.

        Groups keep their order of first appearance (``sort=False``).
        """
        groups = df.groupby('source', sort=False)
        records = it.starmap(self.aggregate, groups)

        return pd.DataFrame.from_records(records)


#
#
#
class DataPlotter:
    """Base class that owns figure creation; subclasses implement draw."""

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Create a figure, let :meth:`draw` fill it, and return it."""
        fig = plt.figure(dpi=200)
        try:
            ax = fig.gca()
            self.draw(ax)
            ax.grid(visible=True,
                    axis='both',
                    alpha=0.25,
                    linestyle='dotted')
            fig.tight_layout()
        finally:
            # pyplot keeps every figure it creates alive until it is
            # closed; without this each call would leak a figure. The
            # returned Figure object itself remains usable.
            plt.close(fig)

        return fig

    def draw(self, ax):
        """Render onto *ax*; must be provided by subclasses."""
        raise NotImplementedError()


class ComparisonPlotter(DataPlotter):
    """Histogram of :func:`compare` for two models, with interval overlay."""

    _uncertain = 0.5
    _theta = '\u03b8'

    def __init__(self, df, model_1, model_2, ci):
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        """Draw the density histogram and its annotations on *ax*."""
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Target the supplied axes explicitly instead of relying on
        # pyplot's notion of the "current" axes.
        ax = sns.histplot(data=self.df,
                          stat='density',
                          color=c_hist,
                          ax=ax)
        ax.set_xlabel(f'{self._theta}$_{{1}}$ - {self._theta}$_{{2}}$')
        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        """Annotate *ax* with the interval level reported for 0.5.

        ``self.interval.at`` supplies the level. If it raises
        ``OverflowError`` the note shows 100% with a "not in" sign; if
        it raises ``FloatingPointError`` no note is drawn.
        """
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return

        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median and the *hdi* span, labelling the median on top.

        The median is drawn as a dashed vertical line, the interval
        from ``hdi.lower`` to ``hdi.upper`` as a translucent band, and a
        secondary x-axis on top carries a single tick at the median.
        """
        x = self.df.median()
        zorder = ax.zorder - 1

        # Show one more decimal place than the bottom axis labels do.
        # When the first label has no decimal point (including labels
        # that are still empty) fall back to two places.
        (label, *_) = ax.get_xticklabels()
        (_, dot, fraction) = label.get_text().rpartition('.')
        decimals = len(fraction) + 1 if dot else 2
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))


#
#
#
class ComparisonMenu:
    """Input widgets and click handler for the model comparison plot."""

    def __init__(self, df, ci=95):
        self.df = df
        self.ci = ci  # default interval, as a percentage

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure, or ``None`` if a model is unset.

        *ci* is a percentage; a missing value (``None``) falls back to
        the menu's default instead of failing on the division.
        """
        if model_1 and model_2:
            if ci is None:
                ci = self.ci
            ci /= 100
            cp = ComparisonPlotter(self.df, model_1, model_2, ci)
            return cp.plot()

    def build_and_get(self):
        """Yield two model dropdowns followed by the interval input."""
        models = self.df['source'].unique()
        choices = sorted(models, key=lambda x: x.lower())

        for i in range(1, 3):
            label = f'Model {i}'
            yield gr.Dropdown(label=label, choices=choices)

        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)


#
#
#
class ParameterLayout:
    """Base layout: a ranked summary table plus subclass-specific extras."""

    def __init__(self, name, ci=0.95, **kwargs):
        # Extra keyword arguments are accepted and ignored so every
        # layout can be constructed with the same call.
        self.name = name
        self.ci = ci
        self.summarize = DataSummarizer(self.name, self.ci)

    def __str__(self):
        return self.name

    def __call__(self, df):
        """Render the table for *df*, then hand it to :meth:`extra`."""
        self.extra(df, self.table(df))

    def table(self, df):
        """Create and return the read-only summary ``gr.Dataframe``.

        Interval fields are headed ``'<ci> HDI <field>'``; all other
        columns are title-cased.
        """
        view = self.summarize(df)
        columns = {
            x.name: f'{self.ci:.0%} HDI {x.name}' for x in fields(HDI)
        }
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style
                .format(precision=4))

        with gr.Row():
            return gr.Dataframe(view, interactive=False)

    def extra(self, df, *args):
        """Add components beyond the table; provided by subclasses."""
        raise NotImplementedError()


class PersonLayout(ParameterLayout):
    """Layout for the ``ability`` parameter with a comparison tool."""

    def __init__(self, ci=0.95, **kwargs):
        super().__init__('ability', ci, **kwargs)
        text = Template(read_md(self.name))
        self.text = text.substitute(ci=ComparisonPlotter._uncertain)

    def extra(self, df, *args):
        """Lay out the help text, the plot, and the comparison menu."""
        with gr.Row():
            with gr.Column():
                gr.Markdown(self.text)

            with gr.Column(scale=3):
                display = gr.Plot()

            with gr.Column():
                menu = ComparisonMenu(df)
                inputs = list(menu.build_and_get())
                button = gr.Button(value='Compare!')
                button.click(menu, inputs=inputs, outputs=[display])


class ItemLayout(ParameterLayout):
    """Layout for item parameters; selecting a row shows its document."""

    _prefix = Path('jerome-white', 'leaderboard-documents')

    def __init__(self, name, framework, ci=0.95, **kwargs):
        super().__init__(name, ci, **kwargs)

        # Repository ids always use forward slashes, whatever the host
        # operating system's path separator is.
        repo = f'{self._prefix.as_posix()}-{framework}'
        ds = load_dataset(repo, split='train')
        self.df = (ds
                   .to_pandas()
                   .set_index('doc'))

    def extra(self, df, *args):
        """Wire row selection on the summary table to a JSON viewer."""
        (frame, ) = args
        frame.select(self.df_select_callback,
                     inputs=[frame],
                     outputs=[gr.JSON()])

    def df_select_callback(self, df: pd.DataFrame, evt: gr.SelectData):
        """Return the ``info`` entry for the document in the selected row.

        The document key is read from the ``Source`` column of the
        displayed frame.
        """
        index = df.columns.get_loc('Source')
        doc = evt.row_value[index]

        return self.df.loc[(doc, 'info')]


#
#
#
with gr.Blocks() as demo:
    # Repository ids always use forward slashes (see ItemLayout).
    path = Path('jerome-white', 'leaderboard-item-response').as_posix()
    splits = get_dataset_split_names(path)

    # Maps each value of the ``parameter`` column to its layout factory.
    pmap = {
        'alpha': ft.partial(ItemLayout, name='discrimination'),
        'beta': ft.partial(ItemLayout, name='difficulty'),
        'theta': PersonLayout,
    }

    with gr.Row():
        gr.Markdown('# OpenLLM Leaderboard IRT')

    with gr.Accordion('About this Space', open=False):
        gr.Markdown(read_md('info'))

    for s in splits:
        df = load_dataset(path, split=s).to_pandas()
        with gr.Tab(s.upper()):
            for (i, g) in df.groupby('parameter'):
                try:
                    factory = pmap[i]
                except KeyError:
                    raise KeyError(f'Unsupported parameter: {i!r}') from None
                layout = factory(framework=s)
                tab = str(layout)
                with gr.Tab(tab.title()):
                    layout(g)

demo.launch(ssr_mode=False)