jerome-white commited on
Commit
fdccb4e
·
1 Parent(s): 31e584a

Allow user to specify HDI

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -132,7 +132,7 @@ class RankPlotter(DataPlotter):
132
  ax.set_yticks(self.y, self.df['model'])
133
 
134
  class ComparisonPlotter(DataPlotter):
135
- def __init__(self, df, model_1, model_2, ci=0.95):
136
  super().__init__(compare(df, model_1, model_2))
137
  self.hdi = HDInterval(self.df)
138
  self.ci = ci
@@ -152,8 +152,19 @@ class ComparisonPlotter(DataPlotter):
152
  color=color)
153
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
154
 
155
- def cplot(df, ci=0.95):
156
- def _plot(model_1, model_2):
 
 
 
 
 
 
 
 
 
 
 
157
  cp = ComparisonPlotter(df, model_1, model_2, ci)
158
  return cp.plot()
159
 
@@ -213,17 +224,24 @@ def layout(tab):
213
  solid blue curve is a CDF of that distribution;
214
  formally the inverse logit of the difference in model
215
  abilities. The dashed orange vertical line is the
216
- median, while the band surrounding it is its 95%
217
- [highest density
218
- interval](https://cran.r-project.org/package=HDInterval).
 
219
 
220
  ''')
221
  with gr.Column():
 
 
222
  models = df['model'].unique()
223
  choices = sorted(models, key=lambda x: x.lower())
224
- drops = ft.partial(gr.Dropdown, choices=choices)
225
- inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
226
 
 
 
 
 
227
  button = gr.Button(value='Compare!')
228
  button.click(cplot(df), inputs=inputs, outputs=[display])
229
 
 
132
  ax.set_yticks(self.y, self.df['model'])
133
 
134
  class ComparisonPlotter(DataPlotter):
135
+ def __init__(self, df, model_1, model_2, ci):
136
  super().__init__(compare(df, model_1, model_2))
137
  self.hdi = HDInterval(self.df)
138
  self.ci = ci
 
152
  color=color)
153
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
154
 
155
+ try:
156
+ ci_mid = self.hdi.at(0.5)
157
+ ax.text(x=0.01,
158
+ y=0.99,
159
+ s=f'0.5-min HDI: {ci_mid:.0%}',
160
+ horizontalalignment='left',
161
+ verticalalignment='top',
162
+ transform=ax.transAxes)
163
+ except ArithmeticError:
164
+ pass
165
+
166
+ def cplot(df):
167
+ def _plot(model_1, model_2, ci):
168
  cp = ComparisonPlotter(df, model_1, model_2, ci)
169
  return cp.plot()
170
 
 
224
  solid blue curve is a CDF of that distribution;
225
  formally the inverse logit of the difference in model
226
  abilities. The dashed orange vertical line is the
227
+ median, while the band surrounding it is the [highest
228
+ density
229
+ interval](https://cran.r-project.org/package=HDInterval)
230
+ of your choice (default 95%).
231
 
232
  ''')
233
  with gr.Column():
234
+ ci = gr.Number(value=0.95, minimum=0, maximum=1, step=1e-2)
235
+
236
  models = df['model'].unique()
237
  choices = sorted(models, key=lambda x: x.lower())
238
+ partial = ft.partial(gr.Dropdown, choices=choices)
239
+ drops = (partial(label=f'Model {x}') for x in range(1, 3))
240
 
241
+ inputs = [
242
+ *drops,
243
+ ci,
244
+ ]
245
  button = gr.Button(value='Compare!')
246
  button.click(cplot(df), inputs=inputs, outputs=[display])
247