jerome-white's picture
Squeezed overview for easier reading
ab6dff0
import math
import operator as op
import itertools as it
import functools as ft
import collections as cl
from pathlib import Path
import pandas as pd
import gradio as gr
from datasets import load_dataset
# Lower/upper bounds of an interval, as returned by hdi().
HDI = cl.namedtuple('HDI', 'lower, upper')
# A model parameter: its name in the data, the kind of entity it
# describes (ptype), and a short human-readable label (gist).
Parameter = cl.namedtuple('Parameter', 'name, ptype, gist')

#
# See https://cran.r-project.org/package=HDInterval
#
def hdi(values, ci=0.95):
    """Return the narrowest interval spanning a `ci` share of `values`.

    Non-finite entries (NaN, +/-inf) are dropped before the search. The
    remaining values are sorted, and every window that leaves out
    ``n - floor(n * ci)`` points is compared; the window with the
    smallest width wins (the first one on ties).

    Args:
        values: iterable of real numbers.
        ci: share of the data the interval should cover, within [0, 1].

    Returns:
        An ``HDI(lower, upper)`` named tuple.

    Raises:
        ValueError: if `ci` is outside [0, 1], or if no finite values
            remain.
    """
    if not 0 <= ci <= 1:
        raise ValueError(f'ci must be between 0 and 1: {ci}')

    values = sorted(filter(math.isfinite, values))
    if not values:
        raise ValueError('Empty data set')

    n = len(values)
    # Keep at least one candidate window: with ci=1 nothing would be
    # excluded, leaving no (lower, upper) pairs to choose from.
    exclude = max(n - math.floor(n * ci), 1)
    candidates = zip(values[:exclude], values[n - exclude:])
    (lower, upper) = min(candidates, key=lambda pair: pair[1] - pair[0])

    return HDI(lower, upper)
#
#
#
def load(repo):
    """Fetch the `repo` dataset and group its rows by parameter.

    The 'train' entry of the dataset is converted to a pandas frame,
    reduced to the 'parameter', 'element', and 'value' columns, and
    grouped on 'parameter' without sorting the group keys.
    """
    key = 'parameter'
    wanted = [key, 'element', 'value']

    frame = load_dataset(repo).get('train').to_pandas()
    return frame.filter(items=wanted).groupby(key, sort=False)
def parameters(groups):
    """Yield the known Parameter for each group key that has one.

    `groups` is iterated as (key, group) pairs; keys without a matching
    Parameter are skipped. Results follow the iteration order of
    `groups`.
    """
    known = {
        p.name: p for p in (
            Parameter('alpha', 'prompt', 'discrimination'),
            Parameter('beta', 'prompt', 'difficulty'),
            Parameter('theta', 'model', 'ability'),
        )
    }

    for (key, _) in groups:
        found = known.get(key)
        if found is not None:
            yield found
@ft.singledispatch
def get(param, group):
    """Return the rows of `group` belonging to `param`.

    Dispatches on the type of `param`: a str is used directly as the
    group key, a Parameter is resolved through its name. Any other
    type raises TypeError.
    """
    raise TypeError(type(param))

@get.register(str)
def _(param, group):
    # Plain names are the group keys themselves.
    return group.get_group(param)

@get.register(Parameter)
def _(param, group):
    # Re-dispatch on the parameter's name.
    return get(param.name, group)
def summarize(param, df, ci=0.95):
    """Build one summary row per 'element' in `df`.

    Each row holds the element (under the column named by
    `param.ptype`), the median of its 'value' column (under
    `param.gist`), the width of its HDI ('uncertainty'), and the HDI
    bounds themselves. Elements keep their order of first appearance.
    """
    def _row(element, chunk):
        samples = chunk['value']
        bounds = hdi(samples, ci)
        return {
            param.ptype: element,
            param.gist: samples.median(),
            'uncertainty': bounds.upper - bounds.lower,
            **bounds._asdict(),
        }

    by_element = df.groupby('element', sort=False)
    rows = (_row(element, chunk) for (element, chunk) in by_element)

    return pd.DataFrame.from_records(rows)
def rank(param, df, ascending, name='rank'):
    """Order `df` by the parameter's value and number the rows from 1.

    Rows are sorted on the `param.gist` column in the direction given
    by `ascending`; ties are broken on 'uncertainty' in the opposite
    direction. The 'uncertainty' column is then dropped and the
    1-based position is added as the leading column `name`.
    """
    spread = 'uncertainty'

    ordered = df.sort_values(
        by=[param.gist, spread],
        ascending=[ascending, not ascending],
    )
    ordered = ordered.drop(columns=spread)
    ordered = ordered.reset_index(drop=True)
    ordered.index = ordered.index + 1  # positions start at one

    return ordered.reset_index(names=name)
def md_reader(name, prefix='_', encoding='utf-8'):
    """Return the contents of the markdown file associated with `name`.

    The file name is `prefix` followed by the upper-cased `name`, with
    its suffix set to '.md' (so 'readme' maps to '_README.md'). The
    path is relative to the current working directory.

    Args:
        name: base name of the document.
        prefix: string placed in front of the upper-cased name.
        encoding: text encoding used to decode the file. Given
            explicitly so the result does not depend on the platform's
            default encoding.

    Returns:
        The file contents as a str.
    """
    path = Path(f'{prefix}{name.upper()}').with_suffix('.md')
    return path.read_text(encoding=encoding)
#
#
#
# Page layout: a title, the readme, then one ranked table per parameter
# found in the data.
with gr.Blocks() as demo:
    data = load('jerome-white/alpaca-irt-stan')

    gr.Markdown('# Alpaca Item Response')
    with gr.Row():
        with gr.Column():
            gr.Markdown(md_reader('readme'))
        with gr.Column():
            # NOTE(review): intentionally empty; presumably keeps the
            # readme from spanning the full row -- confirm.
            pass

    # HDI bound columns get an explicit 'HDI ' prefix; every other
    # column is simply title-cased.
    hdi_labels = {x: f'HDI {x}' for x in HDI._fields}

    for param in parameters(data):
        with gr.Row():
            view = rank(param, summarize(param, get(param, data)), False)

            # A distinct loop name keeps the outer `param` intact.
            labels = dict(hdi_labels)
            for column in view.columns:
                labels.setdefault(column, column.title())

            view = (view
                    .rename(columns=labels)
                    .style.format(precision=4))
            gr.Dataframe(view, wrap=True)

demo.launch()