jerome-white commited on
Commit
9fe0c6d
·
1 Parent(s): 87caf61

Basic implementation and layout

Browse files

* Load data
* Build tabs
* Generate tables
* Plot ability

Files changed (3) hide show
  1. app.py +304 -0
  2. hdinterval.py +78 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools as it
2
+ import functools as ft
3
+ from pathlib import Path
4
+ from dataclasses import fields, asdict
5
+
6
+ import pandas as pd
7
+ import gradio as gr
8
+ import seaborn as sns
9
+ import matplotlib.pyplot as plt
10
+ from datasets import load_dataset, get_dataset_split_names
11
+ from scipy.special import expit
12
+ from matplotlib.ticker import FixedLocator, StrMethodFormatter
13
+
14
+ from hdinterval import HDI, HDInterval
15
+
16
+ #
17
+ #
18
+ #
19
+ def summarize(df, ci):
20
+ def _aggregate(i, g):
21
+ values = g['value']
22
+ hdi = HDInterval(values)
23
+ interval = hdi(ci)
24
+
25
+ agg = {
26
+ 'source': i,
27
+ 'value': values.median(),
28
+ 'uncertainty': interval.width(),
29
+ }
30
+ agg.update(asdict(interval))
31
+
32
+ return agg
33
+
34
+ groups = df.groupby('source', sort=False)
35
+ records = it.starmap(_aggregate, groups)
36
+
37
+ return pd.DataFrame.from_records(records)
38
+
39
+ def rank(df, ascending, name='rank'):
40
+ df = (df
41
+ .sort_values(by=['value', 'uncertainty'],
42
+ ascending=[ascending, not ascending])
43
+ .drop(columns='uncertainty')
44
+ .reset_index(drop=True))
45
+ df.index += 1
46
+
47
+ return df.reset_index(names=name)
48
+
49
+ def compare(df, model_1, model_2):
50
+ mcol = 'source'
51
+ models = [
52
+ model_1,
53
+ model_2,
54
+ ]
55
+ view = (df
56
+ .query(f'{mcol} in @models')
57
+ .pivot(index=['chain', 'sample'],
58
+ columns=mcol,
59
+ values='value'))
60
+
61
+ return expit(view[model_1] - view[model_2])
62
+
63
+ #
64
+ #
65
+ #
66
+ class DataPlotter:
67
+ def __init__(self, df):
68
+ self.df = df
69
+
70
+ def plot(self):
71
+ fig = plt.figure(dpi=200)
72
+
73
+ ax = fig.gca()
74
+ self.draw(ax)
75
+ ax.grid(visible=True,
76
+ axis='both',
77
+ alpha=0.25,
78
+ linestyle='dotted')
79
+ fig.tight_layout()
80
+
81
+ return fig
82
+
83
+ def draw(self, ax):
84
+ raise NotImplementedError()
85
+
86
+ class RankPlotter(DataPlotter):
87
+ _y = 'y'
88
+
89
+ @ft.cached_property
90
+ def y(self):
91
+ return self.df[self._y]
92
+
93
+ def __init__(self, df, ci=0.95, top=10):
94
+ self.ci = ci
95
+ view = rank(summarize(df, self.ci), True, self._y)
96
+ view = (view
97
+ .tail(top)
98
+ .sort_values(by=self._y, ascending=False))
99
+
100
+ super().__init__(view)
101
+
102
+ def draw(self, ax):
103
+ self.df.plot.scatter('value', self._y, ax=ax)
104
+ ax.hlines(self.y,
105
+ xmin=self.df['lower'],
106
+ xmax=self.df['upper'],
107
+ alpha=0.5)
108
+ ax.set_xlabel('{} (with {:.0%} HDI)'.format(
109
+ ax.get_xlabel().title(),
110
+ self.ci,
111
+ ))
112
+ ax.set_ylabel('')
113
+ ax.set_yticks(self.y, self.df['source'])
114
+
115
+ class ComparisonPlotter(DataPlotter):
116
+ _uncertain = 0.5
117
+
118
+ def __init__(self, df, model_1, model_2, ci):
119
+ super().__init__(compare(df, model_1, model_2))
120
+ self.interval = HDInterval(self.df)
121
+ self.ci = ci
122
+
123
+ def draw(self, ax):
124
+ hdi = self.interval(self.ci)
125
+ (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)
126
+
127
+ ax = sns.histplot(data=self.df,
128
+ stat='density',
129
+ color=c_hist)
130
+ ax.set_xlabel('\u03B1$_{1}$ - \u03B1$_{2}$')
131
+
132
+ self.pr(ax, hdi, c_hdi)
133
+ self.min_inclusive(ax)
134
+
135
+ def min_inclusive(self, ax):
136
+ try:
137
+ ci = self.interval.at(self._uncertain)
138
+ inclusive = '\u2208'
139
+ except OverflowError:
140
+ ci = 1
141
+ inclusive = '\u2209'
142
+ except FloatingPointError:
143
+ return
144
+
145
+ ax.text(x=0.02,
146
+ y=0.975,
147
+ s=f'{self._uncertain} {inclusive} {ci:.0%} HDI',
148
+ fontsize='small',
149
+ fontstyle='italic',
150
+ horizontalalignment='left',
151
+ verticalalignment='top',
152
+ transform=ax.transAxes)
153
+
154
+ def pr(self, ax, hdi, color):
155
+ x = self.df.median()
156
+ zorder = ax.zorder - 1
157
+
158
+ (label, *_) = ax.get_xticklabels()
159
+ parts = label.get_text().split('.')
160
+ decimals = len(parts[-1]) + 1 if parts else 2
161
+ fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'
162
+
163
+ ax.axvline(x=x,
164
+ color=color,
165
+ linestyle='dashed')
166
+ ax.axvspan(xmin=hdi.lower,
167
+ xmax=hdi.upper,
168
+ alpha=0.15,
169
+ color=color,
170
+ zorder=zorder)
171
+
172
+ ax_ = ax.secondary_xaxis('top')
173
+ ax_.xaxis.set_major_locator(FixedLocator([x]))
174
+ ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))
175
+
176
+ #
177
+ #
178
+ #
179
+ class ComparisonMenu:
180
+ def __init__(self, df, ci=95):
181
+ self.df = df
182
+ self.ci = ci
183
+
184
+ def __call__(self, model_1, model_2, ci):
185
+ if model_1 and model_2:
186
+ ci /= 100
187
+ cp = ComparisonPlotter(self.df, model_1, model_2, ci)
188
+
189
+ return cp.plot()
190
+
191
+ def build_and_get(self):
192
+ models = self.df['source'].unique()
193
+ choices = sorted(models, key=lambda x: x.lower())
194
+
195
+ for i in range(1, 3):
196
+ label = f'Model {i}'
197
+ yield gr.Dropdown(label=label, choices=choices)
198
+
199
+ yield gr.Number(value=self.ci,
200
+ label='HDI (%)',
201
+ minimum=0,
202
+ maximum=100)
203
+
204
+ #
205
+ #
206
+ #
207
+ class DocumentationReader:
208
+ _suffix = '.md'
209
+
210
+ def __init__(self, root):
211
+ self.root = root
212
+
213
+ def __getitem__(self, item):
214
+ return (self
215
+ .root
216
+ .joinpath(item)
217
+ .with_suffix(self._suffix)
218
+ .read_text())
219
+
220
+ #
221
+ #
222
+ #
223
+ def layout_a(df, ci=0.95):
224
+ with gr.Row():
225
+ view = rank(summarize(df, ci), False)
226
+ columns = { x.name: f'{ci:.0%} HDI {x.name}' for x in fields(HDI) }
227
+ for i in view.columns:
228
+ columns.setdefault(i, i.title())
229
+ view = (view
230
+ .rename(columns=columns)
231
+ .style.format(precision=4))
232
+
233
+ gr.Dataframe(view)
234
+
235
+ with gr.Row():
236
+ with gr.Column():
237
+ gr.Markdown(f'''
238
+
239
+ Probability that Model 1 has higher ability than Model
240
+ 2. The histogram is represents the distribution of the
241
+ difference in estimated model abilities. The dashed
242
+ vertical line is its median. The shaded region demarcates
243
+ the chosen [highest density
244
+ interval](https://cran.r-project.org/package=HDInterval)
245
+ (HDI). The note in the upper left denotes the smallest HDI
246
+ that is inclusive of {ComparisonPlotter._uncertain}.
247
+
248
+ ''')
249
+
250
+ with gr.Column(scale=3):
251
+ display = gr.Plot()
252
+
253
+ with gr.Column():
254
+ menu = ComparisonMenu(df)
255
+ inputs = list(menu.build_and_get())
256
+ button = gr.Button(value='Compare!')
257
+ button.click(menu, inputs=inputs, outputs=[display])
258
+
259
+ def layout_b(df, ci=0.95):
260
+ for (i, g) in df.groupby('parameter'):
261
+ with gr.Tab(i):
262
+ with gr.Row():
263
+ view = rank(summarize(g, ci), False)
264
+ columns = {
265
+ x.name: f'{ci:.0%} HDI {x.name}' for x in fields(HDI)
266
+ }
267
+ for i in view.columns:
268
+ columns.setdefault(i, i.title())
269
+ view = (view
270
+ .rename(columns=columns)
271
+ .style.format(precision=4))
272
+
273
+ gr.Dataframe(view)
274
+
275
+ #
276
+ #
277
+ #
278
+ def grouper(df):
279
+ def classifier(idx):
280
+ theta = df.loc[(idx, 'parameter')] == 'theta'
281
+ return 'person' if theta else 'item'
282
+
283
+ return classifier
284
+
285
+ #
286
+ #
287
+ #
288
+ with gr.Blocks() as demo:
289
+ path = str(Path('jerome-white', 'leaderboard-item-response'))
290
+ splits = get_dataset_split_names(path)
291
+
292
+ for s in splits:
293
+ data = load_dataset(path, split=s)
294
+ df = data.to_pandas()
295
+ groups = df.groupby(grouper(df), sort=False)
296
+
297
+ with gr.Tab(s):
298
+ with gr.Row():
299
+ with gr.Column():
300
+ layout_a(groups.get_group('person'))
301
+ with gr.Column():
302
+ layout_b(groups.get_group('item'))
303
+
304
+ demo.launch(server_name='0.0.0.0', debug=True)
hdinterval.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ import operator as op
4
+ import itertools as it
5
+ import functools as ft
6
+ import statistics as st
7
+ from dataclasses import dataclass
8
+
9
+ @dataclass
10
+ class HDI:
11
+ lower: float
12
+ upper: float
13
+
14
+ def __iter__(self):
15
+ yield from (self.lower, self.upper)
16
+
17
+ def __contains__(self, item):
18
+ return self.lower <= item <= self.upper
19
+
20
+ def width(self):
21
+ return self.upper - self.lower
22
+
23
+ class HDInterval:
24
+ @ft.cached_property
25
+ def values(self):
26
+ view = sorted(filter(math.isfinite, self._values))
27
+ if not view:
28
+ raise AttributeError('Empty data set')
29
+
30
+ return view
31
+
32
+ def __init__(self, values):
33
+ self._values = values
34
+
35
+ #
36
+ # See https://cran.r-project.org/package=HDInterval
37
+ #
38
+ def __call__(self, ci=0.95):
39
+ if ci == 1:
40
+ args = (self.values[x] for x in (0, -1))
41
+ else:
42
+ n = len(self.values)
43
+ exclude = n - math.floor(n * ci)
44
+
45
+ left = it.islice(self.values, exclude)
46
+ right = it.islice(self.values, n - exclude, None)
47
+
48
+ diffs = ((x, y, y - x) for (x, y) in zip(left, right))
49
+ (*args, _) = min(diffs, key=op.itemgetter(-1))
50
+
51
+ return HDI(*args)
52
+
53
+ def _at(self, target, tolerance, ci=1, jump=1):
54
+ if ci > 1:
55
+ raise OverflowError()
56
+
57
+ hdi = self(ci)
58
+ if any(x in tolerance for x in hdi):
59
+ return ci
60
+
61
+ adjust = op.sub if target in hdi else op.add
62
+ ci = adjust(ci, jump)
63
+ jump /= 2
64
+
65
+ return self._at(target, tolerance, ci, jump)
66
+
67
+ def at(self, target, tolerance=1e-4):
68
+ assert tolerance > 0
69
+
70
+ while tolerance < 1:
71
+ hdi = HDI(target, target + tolerance)
72
+ try:
73
+ return self._at(target, hdi)
74
+ except RecursionError:
75
+ tolerance *= 10
76
+ warnings.warn(f'Tolerance reduced: {tolerance}')
77
+
78
+ raise FloatingPointError('Unable to converge')
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets
2
+ gradio
3
+ matplotlib
4
+ pandas
5
+ seaborn
6
+ scipy