Spaces:
Sleeping
Sleeping
Commit
·
75b2724
1
Parent(s):
db4a8ee
Lift HDI calculation to its own module
Browse files
New HDI functionality allows the calculation of the smallest
interval that excludes a given value. The extra code this requires
makes it cleaner to move the HDI logic into its own module.
- app.py +14 -27
- hdinterval.py +87 -0
app.py
CHANGED
@@ -4,6 +4,7 @@ import itertools as it
|
|
4 |
import functools as ft
|
5 |
import collections as cl
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
import pandas as pd
|
9 |
import gradio as gr
|
@@ -12,27 +13,9 @@ import matplotlib.pyplot as plt
|
|
12 |
from datasets import load_dataset
|
13 |
from scipy.special import expit
|
14 |
|
15 |
-
|
16 |
-
TabGroup = cl.namedtuple('TabGroup', 'name, docs, dataset')
|
17 |
-
|
18 |
-
#
|
19 |
-
# See https://cran.r-project.org/package=HDInterval
|
20 |
-
#
|
21 |
-
def hdi(values, ci=0.95):
|
22 |
-
values = sorted(filter(math.isfinite, values))
|
23 |
-
if not values:
|
24 |
-
raise ValueError('Empty data set')
|
25 |
-
|
26 |
-
n = len(values)
|
27 |
-
exclude = n - math.floor(n * ci)
|
28 |
|
29 |
-
|
30 |
-
right = it.islice(values, n - exclude, None)
|
31 |
-
|
32 |
-
diffs = ((x, y, y - x) for (x, y) in zip(left, right))
|
33 |
-
(*args, _) = min(diffs, key=op.itemgetter(-1))
|
34 |
-
|
35 |
-
return HDI(*args)
|
36 |
|
37 |
#
|
38 |
#
|
@@ -60,14 +43,15 @@ def load(repo):
|
|
60 |
def summarize(df, ci=0.95):
|
61 |
def _aggregate(i, g):
|
62 |
values = g['value']
|
63 |
-
|
|
|
64 |
|
65 |
agg = {
|
66 |
'model': i,
|
67 |
'ability': values.median(),
|
68 |
-
'uncertainty': interval.
|
69 |
}
|
70 |
-
agg.update(interval
|
71 |
|
72 |
return agg
|
73 |
|
@@ -150,17 +134,20 @@ class RankPlotter(DataPlotter):
|
|
150 |
class ComparisonPlotter(DataPlotter):
|
151 |
def __init__(self, df, model_1, model_2, ci=0.95):
|
152 |
super().__init__(compare(df, model_1, model_2))
|
153 |
-
self.
|
|
|
154 |
|
155 |
def draw(self, ax):
|
|
|
|
|
156 |
sns.ecdfplot(self.df, ax=ax)
|
157 |
|
158 |
(_, color, *_) = sns.color_palette()
|
159 |
ax.axvline(x=self.df.median(),
|
160 |
color=color,
|
161 |
linestyle='dashed')
|
162 |
-
ax.axvspan(xmin=
|
163 |
-
xmax=
|
164 |
alpha=0.15,
|
165 |
color=color)
|
166 |
ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
|
@@ -205,7 +192,7 @@ def layout(tab):
|
|
205 |
|
206 |
with gr.Row():
|
207 |
view = rank(summarize(df), False)
|
208 |
-
columns = { x: f'HDI {x}' for x in HDI
|
209 |
for i in view.columns:
|
210 |
columns.setdefault(i, i.title())
|
211 |
view = (view
|
|
|
4 |
import functools as ft
|
5 |
import collections as cl
|
6 |
from pathlib import Path
|
7 |
+
from dataclasses import fields, asdict
|
8 |
|
9 |
import pandas as pd
|
10 |
import gradio as gr
|
|
|
13 |
from datasets import load_dataset
|
14 |
from scipy.special import expit
|
15 |
|
16 |
+
from hdinterval import HDI, HDInterval
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
TabGroup = cl.namedtuple('TabGroup', 'name, docs, dataset')
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
#
|
21 |
#
|
|
|
43 |
def summarize(df, ci=0.95):
|
44 |
def _aggregate(i, g):
|
45 |
values = g['value']
|
46 |
+
hdi = HDInterval(values)
|
47 |
+
interval = hdi(ci)
|
48 |
|
49 |
agg = {
|
50 |
'model': i,
|
51 |
'ability': values.median(),
|
52 |
+
'uncertainty': interval.width(),
|
53 |
}
|
54 |
+
agg.update(asdict(interval))
|
55 |
|
56 |
return agg
|
57 |
|
|
|
134 |
class ComparisonPlotter(DataPlotter):
|
135 |
def __init__(self, df, model_1, model_2, ci=0.95):
|
136 |
super().__init__(compare(df, model_1, model_2))
|
137 |
+
self.hdi = HDInterval(self.df)
|
138 |
+
self.ci = ci
|
139 |
|
140 |
def draw(self, ax):
|
141 |
+
interval = self.hdi(self.ci)
|
142 |
+
|
143 |
sns.ecdfplot(self.df, ax=ax)
|
144 |
|
145 |
(_, color, *_) = sns.color_palette()
|
146 |
ax.axvline(x=self.df.median(),
|
147 |
color=color,
|
148 |
linestyle='dashed')
|
149 |
+
ax.axvspan(xmin=interval.lower,
|
150 |
+
xmax=interval.upper,
|
151 |
alpha=0.15,
|
152 |
color=color)
|
153 |
ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
|
|
|
192 |
|
193 |
with gr.Row():
|
194 |
view = rank(summarize(df), False)
|
195 |
+
columns = { x.name: f'HDI {x.name}' for x in fields(HDI) }
|
196 |
for i in view.columns:
|
197 |
columns.setdefault(i, i.title())
|
198 |
view = (view
|
hdinterval.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
import operator as op
|
4 |
+
import itertools as it
|
5 |
+
import functools as ft
|
6 |
+
import statistics as st
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
@dataclass
class HDI:
    """Closed interval ``[lower, upper]`` produced by ``HDInterval``."""

    lower: float
    upper: float

    def __iter__(self):
        """Yield ``lower`` then ``upper`` so the interval unpacks like a pair."""
        yield self.lower
        yield self.upper

    def __contains__(self, item) -> bool:
        """Return whether ``item`` lies within the bounds, endpoints included."""
        return self.lower <= item <= self.upper

    def width(self) -> float:
        """Return the distance between the bounds (``upper - lower``)."""
        return self.upper - self.lower
|
22 |
+
|
23 |
+
class HDInterval:
    """Highest-density interval (HDI) calculator over a sample of values.

    The sample is stored as given; it is filtered down to its finite
    entries and sorted the first time ``values`` is accessed.

    See https://cran.r-project.org/package=HDInterval
    """

    def __init__(self, values):
        """Store the raw sample; no filtering or sorting happens here."""
        self._values = values

    @ft.cached_property
    def values(self):
        """Finite entries of the sample, sorted ascending (computed once).

        Raises:
            AttributeError: if the sample has no finite entries.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            # NOTE(review): AttributeError is kept so existing callers see
            # the same exception type; ValueError would describe an empty
            # data set more accurately -- confirm before changing.
            raise AttributeError('Empty data set')

        return view

    def __call__(self, ci=0.95):
        """Return the narrowest interval spanning a ``ci`` share of the sample.

        Args:
            ci: Share of the sorted sample the interval should span, in
                ``[0, 1]``. ``1`` yields the full range of the sample.

        Returns:
            An ``HDI`` holding the lower and upper bound.

        Raises:
            AttributeError: if the sample has no finite entries.
            ValueError: if ``ci`` is NaN or lies outside ``[0, 1]``.
        """
        values = self.values
        # Out-of-range (or NaN) shares used to surface as an opaque
        # ValueError from islice()/min()/floor(); fail with a clear one.
        if not 0 <= ci <= 1:
            raise ValueError(f'ci must be within [0, 1]: {ci}')

        if ci == 1:
            return HDI(values[0], values[-1])

        n = len(values)
        exclude = n - math.floor(n * ci)

        # Pair each candidate lower bound with the upper bound that sits
        # ``n - exclude`` positions further along the sorted sample, then
        # keep the pair with the smallest spread (first one wins on ties).
        left = it.islice(values, exclude)
        right = it.islice(values, n - exclude, None)
        diffs = ((x, y, y - x) for (x, y) in zip(left, right))
        (lower, upper, _) = min(diffs, key=op.itemgetter(-1))

        return HDI(lower, upper)

    def _at(self, target, tolerance, ci, jump):
        """Bisect on ``ci`` until an interval bound is close to ``target``.

        Each step shrinks ``ci`` by ``jump`` when ``target`` is inside the
        current interval and grows it otherwise, halving ``jump`` afterwards.

        Returns:
            The ``ci`` whose interval has a bound within ``tolerance`` of
            ``target``; ``1`` once ``ci`` exceeds 1; ``None`` when the
            search stalls because ``jump`` is too small to change ``ci``.
        """
        while not ci > 1:
            interval = self(ci)
            if any(math.isclose(x, target, abs_tol=tolerance)
                   for x in interval):
                return ci

            plus_minus = op.sub if target in interval else op.add
            moved = plus_minus(ci, jump)
            if moved == ci:
                # ``jump`` only gets smaller from here, so ``ci`` -- and
                # therefore the interval -- can never change again.
                return None

            ci = moved
            jump /= 2

        return 1

    def at(self, target, tolerance=1e-3):
        """Find the ``ci`` at which an interval bound reaches ``target``.

        The search starts at ``ci=1``. If it stalls, the tolerance is
        multiplied by ten (with a warning) and the search is repeated for
        as long as the tolerance stays below 1.

        Args:
            target: Value an interval bound should meet.
            tolerance: Absolute tolerance for the bound/target comparison.

        Returns:
            The matching ``ci``; ``1`` when ``target`` falls outside the
            full range of the sample.

        Raises:
            OverflowError: if no tolerance below 1 produces a match.
        """
        while tolerance < 1:
            ci = self._at(target, tolerance, 1, 1)
            if ci is not None:
                return ci

            tolerance *= 10
            warnings.warn(f'Tolerance reduced: {tolerance}')

        raise OverflowError(f'No interval bound found near {target}')
|
76 |
+
|
77 |
+
if __name__ == '__main__':
    # Manual smoke test: draw a uniform sample, look up the share at which
    # an interval bound reaches 0.5, and print that share with its interval.
    import numpy as np

    data = np.random.uniform(size=2000)
    # Alternative samples for manual experiments:
    # data = list(filter(lambda x: x > 0.7, data))
    # data = [0.5] * 10

    interval = HDInterval(data)
    point = interval.at(0.5)
    hdi = interval(point)
    print(point, hdi)
|