Commit
·
9fe0c6d
1
Parent(s):
87caf61
Basic implementation and layout
Browse files
* Load data
* Build tabs
* Generate tables
* Plot ability
- app.py +304 -0
- hdinterval.py +78 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools as it
|
2 |
+
import functools as ft
|
3 |
+
from pathlib import Path
|
4 |
+
from dataclasses import fields, asdict
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import gradio as gr
|
8 |
+
import seaborn as sns
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from datasets import load_dataset, get_dataset_split_names
|
11 |
+
from scipy.special import expit
|
12 |
+
from matplotlib.ticker import FixedLocator, StrMethodFormatter
|
13 |
+
|
14 |
+
from hdinterval import HDI, HDInterval
|
15 |
+
|
16 |
+
#
|
17 |
+
#
|
18 |
+
#
|
19 |
+
def summarize(df, ci):
    """Collapse the samples of every source into a single summary row.

    The frame is grouped by its ``source`` column (group order follows
    first appearance).  Each group contributes one record holding the
    source label, the median of its ``value`` column, the width of the
    highest-density interval at credible mass ``ci`` (as ``uncertainty``)
    and the interval's own fields.

    Returns a new ``DataFrame`` with one row per source.
    """
    def _describe(source, group):
        samples = group['value']
        interval = HDInterval(samples)(ci)

        # Interval fields come last so they follow the fixed columns.
        return {
            'source': source,
            'value': samples.median(),
            'uncertainty': interval.width(),
            **asdict(interval),
        }

    by_source = df.groupby('source', sort=False)

    return pd.DataFrame.from_records(
        _describe(source, group) for (source, group) in by_source
    )
|
38 |
+
|
39 |
+
def rank(df, ascending, name='rank'):
    """Order summary rows and prepend a 1-based rank column.

    Rows are sorted by ``value`` in the requested direction; ties are
    broken by ``uncertainty`` in the opposite direction.  The
    ``uncertainty`` column is dropped from the result and the position of
    each row (starting at 1) is exposed as a leading column called
    ``name``.  The input frame is not modified.
    """
    ordered = df.sort_values(
        by=['value', 'uncertainty'],
        ascending=[ascending, not ascending],
    )
    ordered = ordered.drop(columns='uncertainty')
    ordered = ordered.reset_index(drop=True)

    # Shift the fresh 0-based index so that ranks start at one.
    ordered.index = ordered.index + 1

    return ordered.reset_index(names=name)
|
48 |
+
|
49 |
+
def compare(df, model_1, model_2):
    """Return the logistic transform of the per-sample value difference.

    Rows whose ``source`` is one of the two models are pivoted so that
    each (``chain``, ``sample``) pair becomes a row and each model a
    column.  The result is ``expit(model_1 - model_2)`` for every pair,
    indexed by (``chain``, ``sample``).
    """
    mcol = 'source'
    models = [model_1, model_2]

    # ``@models`` is resolved by pandas from this function's locals.
    subset = df.query(f'{mcol} in @models')
    wide = subset.pivot(index=['chain', 'sample'],
                        columns=mcol,
                        values='value')
    difference = wide[model_1] - wide[model_2]

    return expit(difference)
|
62 |
+
|
63 |
+
#
|
64 |
+
#
|
65 |
+
#
|
66 |
+
class DataPlotter:
    """Base class that renders a data set onto a fresh matplotlib figure.

    Subclasses implement ``draw`` to put their artists on the supplied
    axes; ``plot`` takes care of figure creation, the grid and layout.
    """

    def __init__(self, df):
        self.df = df

    def plot(self):
        """Create a figure, let ``draw`` populate it and return it.

        The figure is removed from pyplot's global registry before it is
        returned.  Figures created through ``plt.figure`` are otherwise
        kept alive by pyplot until closed, so a long-running process
        would accumulate one figure per call.  The returned ``Figure``
        object itself remains usable for rendering.
        """
        fig = plt.figure(dpi=200)

        try:
            ax = fig.gca()
            self.draw(ax)
            ax.grid(visible=True,
                    axis='both',
                    alpha=0.25,
                    linestyle='dotted')
            fig.tight_layout()
        finally:
            # Release pyplot's reference even if drawing failed.
            plt.close(fig)

        return fig

    def draw(self, ax):
        """Draw onto ``ax``; must be provided by subclasses."""
        raise NotImplementedError()
|
85 |
+
|
86 |
+
class RankPlotter(DataPlotter):
    """Dot-and-interval plot of the highest ranked sources.

    The data is summarised and ranked (ascending) into a column named by
    ``_y``; the last ``top`` rows are kept and ordered by descending rank.
    """
    # Name of the rank column, which doubles as the plot's y coordinate.
    _y = 'y'

    @ft.cached_property
    def y(self):
        """The rank column of the plotted frame."""
        return self.df[self._y]

    def __init__(self, df, ci=0.95, top=10):
        self.ci = ci

        ranked = rank(summarize(df, self.ci), True, self._y)
        best = ranked.tail(top)
        best = best.sort_values(by=self._y, ascending=False)

        super().__init__(best)

    def draw(self, ax):
        """Scatter the values and overlay each row's interval as a line."""
        self.df.plot.scatter('value', self._y, ax=ax)

        lower = self.df['lower']
        upper = self.df['upper']
        ax.hlines(self.y, xmin=lower, xmax=upper, alpha=0.5)

        # Reuse the label pandas assigned, adding the interval level.
        title = ax.get_xlabel().title()
        ax.set_xlabel('{} (with {:.0%} HDI)'.format(title, self.ci))

        # Replace numeric rank ticks with the source names.
        ax.set_ylabel('')
        ax.set_yticks(self.y, self.df['source'])
|
114 |
+
|
115 |
+
class ComparisonPlotter(DataPlotter):
    """Histogram of the pairwise comparison between two models.

    The plotted data is the series returned by ``compare`` for the two
    models.  An ``HDInterval`` built over that series supplies both the
    shaded interval and the annotation in the corner of the plot.
    """
    # Value whose smallest enclosing HDI is reported by ``min_inclusive``.
    _uncertain = 0.5

    def __init__(self, df, model_1, model_2, ci):
        super().__init__(compare(df, model_1, model_2))
        self.interval = HDInterval(self.df)
        self.ci = ci

    def draw(self, ax):
        """Draw the histogram, the HDI overlay and the annotation on ``ax``."""
        hdi = self.interval(self.ci)
        (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)

        # Pass the axes explicitly: without ``ax=`` seaborn draws on
        # pyplot's implicit "current" axes, which is only the right one
        # as long as no other figure was created or selected in between.
        ax = sns.histplot(data=self.df,
                          stat='density',
                          color=c_hist,
                          ax=ax)
        ax.set_xlabel('\u03B1$_{1}$ - \u03B1$_{2}$')

        self.pr(ax, hdi, c_hdi)
        self.min_inclusive(ax)

    def min_inclusive(self, ax):
        """Annotate ``ax`` with the HDI level associated with ``_uncertain``.

        ``HDInterval.at`` is asked for the level at ``_uncertain``.  If it
        raises ``OverflowError`` the annotation reports that the value is
        not in the 100% HDI; if it raises ``FloatingPointError`` (no
        convergence) nothing is drawn.
        """
        try:
            ci = self.interval.at(self._uncertain)
            inclusive = '\u2208'
        except OverflowError:
            ci = 1
            inclusive = '\u2209'
        except FloatingPointError:
            return

        ax.text(x=0.02,
                y=0.975,
                s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
                fontsize='small',
                fontstyle='italic',
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes)

    def pr(self, ax, hdi, color):
        """Mark the median and shade the interval ``hdi`` on ``ax``.

        The median is drawn as a dashed vertical line and labelled on a
        secondary axis along the top edge.
        """
        x = self.df.median()
        # Keep the shaded span behind the axes' own artists.
        zorder = ax.zorder - 1

        # Derive the label precision from the first x tick label: one
        # digit more than the text after its last '.'.
        (label, *_) = ax.get_xticklabels()
        parts = label.get_text().split('.')
        decimals = len(parts[-1]) + 1 if parts else 2
        fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'

        ax.axvline(x=x,
                   color=color,
                   linestyle='dashed')
        ax.axvspan(xmin=hdi.lower,
                   xmax=hdi.upper,
                   alpha=0.15,
                   color=color,
                   zorder=zorder)

        # A top axis with a single tick at the median carries the label.
        ax_ = ax.secondary_xaxis('top')
        ax_.xaxis.set_major_locator(FixedLocator([x]))
        ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))
|
175 |
+
|
176 |
+
#
|
177 |
+
#
|
178 |
+
#
|
179 |
+
class ComparisonMenu:
    """Input widgets and click handler for the model comparison plot.

    ``build_and_get`` creates the Gradio inputs; the instance itself is
    the callback that turns their values into a figure.
    """

    def __init__(self, df, ci=95):
        self.df = df
        # Default interval level, expressed as a percentage.
        self.ci = ci

    def __call__(self, model_1, model_2, ci):
        """Return the comparison figure, or ``None`` if a model is missing.

        ``ci`` is a percentage.  When it is ``None`` (for example because
        the number field was cleared) the level given at construction is
        used instead of failing on the division.
        """
        if not (model_1 and model_2):
            return None

        if ci is None:
            ci = self.ci
        cp = ComparisonPlotter(self.df, model_1, model_2, ci / 100)

        return cp.plot()

    def build_and_get(self):
        """Yield the input components: two model dropdowns, then the level.

        Dropdown choices are the distinct ``source`` values of the data,
        ordered case-insensitively.
        """
        models = self.df['source'].unique()
        choices = sorted(models, key=lambda x: x.lower())

        for i in range(1, 3):
            yield gr.Dropdown(label=f'Model {i}', choices=choices)

        yield gr.Number(value=self.ci,
                        label='HDI (%)',
                        minimum=0,
                        maximum=100)
|
203 |
+
|
204 |
+
#
|
205 |
+
#
|
206 |
+
#
|
207 |
+
class DocumentationReader:
    """Look up Markdown documents below a root directory by name.

    ``reader[name]`` returns the text of ``<root>/<name>.md``.  ``root``
    must provide ``joinpath`` (e.g. a ``pathlib.Path``).
    """
    _suffix = '.md'

    def __init__(self, root):
        self.root = root

    def __getitem__(self, item):
        """Return the contents of the document called ``item``.

        The file is decoded as UTF-8 rather than with the locale's
        default encoding, so the result does not depend on the host.
        Note that ``with_suffix`` replaces any suffix ``item`` already
        has.  Errors from ``read_text`` (such as a missing file)
        propagate to the caller.
        """
        path = self.root.joinpath(item).with_suffix(self._suffix)

        return path.read_text(encoding='utf-8')
|
219 |
+
|
220 |
+
#
|
221 |
+
#
|
222 |
+
#
|
223 |
+
def layout_a(df, ci=0.95):
    """Lay out the ranking table and the interactive comparison plot.

    Must be called inside an active Gradio Blocks context.  The first row
    shows the sources ranked from highest to lowest value, with the
    interval columns labelled by the level ``ci``.  The second row holds
    an explanatory note, the plot area and the comparison controls.
    """
    with gr.Row():
        view = rank(summarize(df, ci), False)

        # Interval fields get a descriptive header; every other column is
        # simply title-cased.
        columns = {x.name: f'{ci:.0%} HDI {x.name}' for x in fields(HDI)}
        for i in view.columns:
            columns.setdefault(i, i.title())
        view = (view
                .rename(columns=columns)
                .style.format(precision=4))

        gr.Dataframe(view)

    with gr.Row():
        with gr.Column():
            gr.Markdown(f'''

            Probability that Model 1 has higher ability than Model
            2. The histogram represents the distribution of the
            difference in estimated model abilities. The dashed
            vertical line is its median. The shaded region demarcates
            the chosen [highest density
            interval](https://cran.r-project.org/package=HDInterval)
            (HDI). The note in the upper left denotes the smallest HDI
            that is inclusive of {ComparisonPlotter._uncertain}.

            ''')

        with gr.Column(scale=3):
            display = gr.Plot()

        with gr.Column():
            menu = ComparisonMenu(df)
            inputs = list(menu.build_and_get())
            button = gr.Button(value='Compare!')
            # The menu object is itself the click handler.
            button.click(menu, inputs=inputs, outputs=[display])
|
258 |
+
|
259 |
+
def layout_b(df, ci=0.95):
    """Create one tab per ``parameter`` value, each with a ranking table.

    Must be called inside an active Gradio Blocks context.  Within every
    tab the group's sources are summarised at level ``ci`` and ranked
    from highest to lowest value.
    """
    for (parameter, group) in df.groupby('parameter'):
        with gr.Tab(parameter):
            with gr.Row():
                table = rank(summarize(group, ci), False)

                # Descriptive headers for the interval fields, title
                # case for everything else.
                headers = {}
                for field in fields(HDI):
                    headers[field.name] = f'{ci:.0%} HDI {field.name}'
                for column in table.columns:
                    headers.setdefault(column, column.title())

                renamed = table.rename(columns=headers)
                styled = renamed.style.format(precision=4)

                gr.Dataframe(styled)
|
274 |
+
|
275 |
+
#
|
276 |
+
#
|
277 |
+
#
|
278 |
+
def grouper(df):
    """Build a function that labels an index entry of ``df``.

    The returned callable takes an index label and answers ``'person'``
    when that row's ``parameter`` is ``'theta'`` and ``'item'``
    otherwise, which makes it usable as a ``groupby`` key.
    """
    def classifier(idx):
        if df.loc[idx, 'parameter'] == 'theta':
            return 'person'

        return 'item'

    return classifier
|
284 |
+
|
285 |
+
#
|
286 |
+
#
|
287 |
+
#
|
288 |
+
# Build the interface: one tab per dataset split, each showing the
# 'person' layout next to the 'item' layout.
with gr.Blocks() as demo:
    # A dataset identifier is "<namespace>/<name>", not a filesystem
    # path; composing it with pathlib would produce a backslash on
    # Windows, so it is spelled out directly.
    path = 'jerome-white/leaderboard-item-response'
    splits = get_dataset_split_names(path)

    for s in splits:
        data = load_dataset(path, split=s)
        df = data.to_pandas()
        # Rows are labelled 'person' or 'item' by the grouper.
        groups = df.groupby(grouper(df), sort=False)

        with gr.Tab(s):
            with gr.Row():
                with gr.Column():
                    layout_a(groups.get_group('person'))
                with gr.Column():
                    layout_b(groups.get_group('item'))

demo.launch(server_name='0.0.0.0', debug=True)
|
hdinterval.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
import operator as op
|
4 |
+
import itertools as it
|
5 |
+
import functools as ft
|
6 |
+
import statistics as st
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
@dataclass
class HDI:
    """An interval described by its lower and upper bound.

    Instances unpack into ``(lower, upper)`` and support ``in`` tests,
    with both bounds counted as part of the interval.
    """
    lower: float
    upper: float

    def __iter__(self):
        bounds = (self.lower, self.upper)
        return iter(bounds)

    def __contains__(self, item):
        return self.lower <= item and item <= self.upper

    def width(self):
        """Return the distance between the two bounds."""
        return self.upper - self.lower
|
22 |
+
|
23 |
+
class HDInterval:
    """Highest-density intervals over a collection of numeric samples.

    Calling an instance with a credible mass returns the corresponding
    interval; ``at`` searches for the mass at which an interval bound
    reaches a given target value.
    """

    @ft.cached_property
    def values(self):
        """The finite samples in ascending order (computed once).

        Raises ``AttributeError`` if no finite sample is available.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            raise AttributeError('Empty data set')

        return view

    def __init__(self, values):
        self._values = values

    #
    # See https://cran.r-project.org/package=HDInterval
    #
    def __call__(self, ci=0.95):
        """Return the narrowest interval covering a share ``ci`` of samples.

        For ``ci == 1`` this is the full range of the data.  Otherwise
        every candidate pairs one of the lowest ``exclude`` samples with
        the sample ``n - exclude`` positions above it, and the pair with
        the smallest difference wins (the first one on ties).
        """
        if ci == 1:
            args = (self.values[x] for x in (0, -1))
        else:
            n = len(self.values)
            # Number of samples allowed to fall outside the interval.
            exclude = n - math.floor(n * ci)

            left = it.islice(self.values, exclude)
            right = it.islice(self.values, n - exclude, None)

            diffs = ((x, y, y - x) for (x, y) in zip(left, right))
            (*args, _) = min(diffs, key=op.itemgetter(-1))

        return HDI(*args)

    def _at(self, target, tolerance, ci=1, jump=1):
        """Bisect the credible mass until a bound falls inside ``tolerance``.

        ``tolerance`` is an ``HDI`` acting as the acceptance window.  The
        mass is lowered while ``target`` lies inside the current interval
        and raised otherwise, with the step halved on every round.

        Raises ``OverflowError`` once the mass exceeds 1, which happens
        on the second round when ``target`` is outside the full data
        range.  A search that never satisfies the window ends in
        ``RecursionError``, which ``at`` relies on.
        """
        if ci > 1:
            raise OverflowError()

        hdi = self(ci)
        if any(x in tolerance for x in hdi):
            return ci

        adjust = op.sub if target in hdi else op.add
        ci = adjust(ci, jump)
        jump /= 2

        return self._at(target, tolerance, ci, jump)

    def at(self, target, tolerance=1e-4):
        """Return the credible mass whose interval has a bound at ``target``.

        A bound counts as reaching the target when it lies within
        ``[target, target + tolerance]``.  If the search does not
        converge, the tolerance is multiplied by ten (with a warning)
        and the search is repeated while the tolerance is below 1.

        Raises ``ValueError`` for a tolerance that is not positive,
        ``OverflowError`` (from the search) when the mass would have to
        exceed 1, and ``FloatingPointError`` when no tolerance below 1
        leads to convergence.
        """
        # Validate explicitly rather than with ``assert``: assertions are
        # removed under ``-O`` and a non-positive tolerance would never
        # grow past the loop bound below.
        if not tolerance > 0:
            raise ValueError(f'tolerance must be positive: {tolerance}')

        while tolerance < 1:
            hdi = HDI(target, target + tolerance)
            try:
                return self._at(target, hdi)
            except RecursionError:
                # Widen the acceptance window and try again.
                tolerance *= 10
                warnings.warn(f'Tolerance reduced: {tolerance}')

        raise FloatingPointError('Unable to converge')
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
gradio
|
3 |
+
matplotlib
|
4 |
+
pandas
|
5 |
+
seaborn
|
6 |
+
scipy
|