jerome-white commited on
Commit
dc3994d
·
1 Parent(s): 99722b5

Update to model name heading

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -42,7 +42,7 @@ def load(repo):
42
  'chain',
43
  'sample',
44
  parameter,
45
- 'model',
46
  'value',
47
  ]
48
  dataset = load_dataset(repo)
@@ -60,7 +60,7 @@ def summarize(df, ci=0.95):
60
  interval = hdi(values, ci)
61
 
62
  agg = {
63
- 'model': i,
64
  'ability': values.median(),
65
  'uncertainty': interval.upper - interval.lower,
66
  }
@@ -68,7 +68,7 @@ def summarize(df, ci=0.95):
68
 
69
  return agg
70
 
71
- groups = df.groupby('model', sort=False)
72
  records = it.starmap(_aggregate, groups)
73
 
74
  return pd.DataFrame.from_records(records)
@@ -84,7 +84,7 @@ def rank(df, ascending, name='rank'):
84
  return df.reset_index(names=name)
85
 
86
  def compare(df, model_1, model_2):
87
- mcol = 'model'
88
  models = [
89
  model_1,
90
  model_2,
@@ -141,7 +141,7 @@ class RankPlotter(DataPlotter):
141
  xmax=self.df['upper'],
142
  alpha=0.5)
143
  ax.set_ylabel('')
144
- ax.set_yticks(self.y, self.df['model'])
145
 
146
  class ComparisonPlotter(DataPlotter):
147
  def __init__(self, df, model_1, model_2, ci=0.95):
@@ -219,8 +219,9 @@ with gr.Blocks() as demo:
219
 
220
  ''')
221
  with gr.Column():
222
- models = sorted(df['model'].unique(), key=lambda x: x.lower())
223
- drops = ft.partial(gr.Dropdown, choices=models)
 
224
  inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
225
 
226
  button = gr.Button(value='Compare!')
 
42
  'chain',
43
  'sample',
44
  parameter,
45
+ 'parameter_value',
46
  'value',
47
  ]
48
  dataset = load_dataset(repo)
 
60
  interval = hdi(values, ci)
61
 
62
  agg = {
63
+ 'parameter_value': i,
64
  'ability': values.median(),
65
  'uncertainty': interval.upper - interval.lower,
66
  }
 
68
 
69
  return agg
70
 
71
+ groups = df.groupby('parameter_value', sort=False)
72
  records = it.starmap(_aggregate, groups)
73
 
74
  return pd.DataFrame.from_records(records)
 
84
  return df.reset_index(names=name)
85
 
86
  def compare(df, model_1, model_2):
87
+ mcol = 'parameter_value'
88
  models = [
89
  model_1,
90
  model_2,
 
141
  xmax=self.df['upper'],
142
  alpha=0.5)
143
  ax.set_ylabel('')
144
+ ax.set_yticks(self.y, self.df['parameter_value'])
145
 
146
  class ComparisonPlotter(DataPlotter):
147
  def __init__(self, df, model_1, model_2, ci=0.95):
 
219
 
220
  ''')
221
  with gr.Column():
222
+ models = df['parameter_value'].unique()
223
+ choices = sorted(models, key=lambda x: x.lower())
224
+ drops = ft.partial(gr.Dropdown, choices=choices)
225
  inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
226
 
227
  button = gr.Button(value='Compare!')