jerome-white commited on
Commit
b553146
·
1 Parent(s): bd4de40

Use histogram for comparison plots

Browse files
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -12,7 +12,6 @@ import seaborn as sns
12
  import matplotlib.pyplot as plt
13
  from datasets import load_dataset
14
  from scipy.special import expit
15
- # from matplotlib.ticker import MultipleLocator
16
 
17
  from hdinterval import HDI, HDInterval
18
 
@@ -133,35 +132,50 @@ class RankPlotter(DataPlotter):
133
  ax.set_yticks(self.y, self.df['model'])
134
 
135
  class ComparisonPlotter(DataPlotter):
 
 
 
 
 
 
 
 
136
  def __init__(self, df, model_1, model_2, ci):
137
  super().__init__(compare(df, model_1, model_2))
138
- self.hdi = HDInterval(self.df)
139
  self.ci = ci
140
 
141
  def draw(self, ax):
142
- interval = self.hdi(self.ci)
143
 
144
- sns.ecdfplot(self.df, ax=ax)
 
 
 
145
 
146
  (_, color, *_) = sns.color_palette()
147
- ax.axvline(x=self.df.median(),
148
- color=color,
149
- linestyle='dashed')
150
- ax.axvspan(xmin=interval.lower,
151
- xmax=interval.upper,
152
- alpha=0.15,
153
  color=color)
154
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
155
- # for i in ('x', 'y'):
156
- # lim = getattr(ax, f'set_{i}lim')
157
- # lim(-0.01, 1.01)
158
- # ax.xaxis.set_major_locator(MultipleLocator(base=0.1, offset=0))
 
 
 
 
159
 
160
  try:
161
- ci_mid = self.hdi.at(0.5)
162
  ax.text(x=0.01,
163
  y=0.99,
164
- s=f'0.5-min HDI: {ci_mid:.0%}',
165
  horizontalalignment='left',
166
  verticalalignment='top',
167
  transform=ax.transAxes)
@@ -177,10 +191,11 @@ class ComparisonMenu:
177
  self.ci = ci
178
 
179
  def __call__(self, model_1, model_2, ci):
180
- ci /= 100
181
- cp = ComparisonPlotter(self.df, model_1, model_2, ci)
 
182
 
183
- return cp.plot()
184
 
185
  def build_and_get(self):
186
  models = self.df['model'].unique()
@@ -246,13 +261,14 @@ def layout(tab):
246
  gr.Markdown('''
247
 
248
  Probability that Model 1 is preferred to Model 2. The
249
- solid blue curve is a CDF of that distribution;
250
- formally the inverse logit of the difference in model
251
- abilities. The dashed orange vertical line is the
252
- median, while the band surrounding it is the [highest
253
- density
254
- interval](https://cran.r-project.org/package=HDInterval)
255
- of your choice (default 95%).
 
256
 
257
  ''')
258
  with gr.Column():
 
12
  import matplotlib.pyplot as plt
13
  from datasets import load_dataset
14
  from scipy.special import expit
 
15
 
16
  from hdinterval import HDI, HDInterval
17
 
 
132
  ax.set_yticks(self.y, self.df['model'])
133
 
134
  class ComparisonPlotter(DataPlotter):
135
+ _uncertain = 0.5
136
+
137
+ @staticmethod
138
+ def to_relative(hdi, ax):
139
+ (lhs, rhs) = ax.get_xlim()
140
+ length = rhs - lhs
141
+ yield from (abs(lhs - x) / length for x in hdi)
142
+
143
  def __init__(self, df, model_1, model_2, ci):
144
  super().__init__(compare(df, model_1, model_2))
145
+ self.interval = HDInterval(self.df)
146
  self.ci = ci
147
 
148
  def draw(self, ax):
149
+ hdi = self.interval(self.ci)
150
 
151
+ ax = sns.histplot(self.df, stat='density')
152
+
153
+ top = max(x.get_height() for x in ax.patches)
154
+ y = top * 1.05
155
 
156
  (_, color, *_) = sns.color_palette()
157
+ (xmin, xmax) = self.to_relative(hdi, ax)
158
+ linestyle = 'dashed' if self._uncertain in hdi else 'solid'
159
+ ax.axhline(y=y,
160
+ xmin=xmin,
161
+ xmax=xmax,
162
+ linestyle=linestyle,
163
  color=color)
164
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
165
+
166
+ x = (hdi.lower + hdi.upper) / 2
167
+ ax.text(x=x,
168
+ y=y,
169
+ s=f'{self.ci:.0%} HDI',
170
+ backgroundcolor='white',
171
+ horizontalalignment='center',
172
+ verticalalignment='center')
173
 
174
  try:
175
+ ci_min = self.interval.at(self._uncertain)
176
  ax.text(x=0.01,
177
  y=0.99,
178
+ s=f'0.5 \u2248\u2208 {ci_min:.0%} HDI',
179
  horizontalalignment='left',
180
  verticalalignment='top',
181
  transform=ax.transAxes)
 
191
  self.ci = ci
192
 
193
  def __call__(self, model_1, model_2, ci):
194
+ if model_1 and model_2:
195
+ ci /= 100
196
+ cp = ComparisonPlotter(self.df, model_1, model_2, ci)
197
 
198
+ return cp.plot()
199
 
200
  def build_and_get(self):
201
  models = self.df['model'].unique()
 
261
  gr.Markdown('''
262
 
263
  Probability that Model 1 is preferred to Model 2. The
264
+ histogram is represents the distribution of inverse
265
+ logit of the difference in model abilities. The
266
+ horizontal line above the histogram marks the chosen
267
+ [highest density
268
+ interval](https://cran.r-project.org/package=HDInterval). The
269
+ line is dashed if the interval overlaps 0.5, solid
270
+ otherwise. The HDI in the upper left denotes the
271
+ smallest approximate HDI that is inclusive of 0.5.
272
 
273
  ''')
274
  with gr.Column():