jerome-white commited on
Commit
7e85849
·
1 Parent(s): dc3994d

Cleaner way of updating column names

Browse files

Rename the column to what the previous code expected.

Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -38,11 +38,12 @@ def hdi(values, ci=0.95):
38
  #
39
  def load(repo):
40
  parameter = 'parameter'
 
41
  items = [
42
  'chain',
43
  'sample',
44
  parameter,
45
- 'parameter_value',
46
  'value',
47
  ]
48
  dataset = load_dataset(repo)
@@ -50,6 +51,7 @@ def load(repo):
50
  return (dataset
51
  .get('train')
52
  .to_pandas()
 
53
  .filter(items=items)
54
  .query(f'{parameter} == "alpha"')
55
  .drop(columns=parameter))
@@ -60,7 +62,7 @@ def summarize(df, ci=0.95):
60
  interval = hdi(values, ci)
61
 
62
  agg = {
63
- 'parameter_value': i,
64
  'ability': values.median(),
65
  'uncertainty': interval.upper - interval.lower,
66
  }
@@ -68,7 +70,7 @@ def summarize(df, ci=0.95):
68
 
69
  return agg
70
 
71
- groups = df.groupby('parameter_value', sort=False)
72
  records = it.starmap(_aggregate, groups)
73
 
74
  return pd.DataFrame.from_records(records)
@@ -84,7 +86,7 @@ def rank(df, ascending, name='rank'):
84
  return df.reset_index(names=name)
85
 
86
  def compare(df, model_1, model_2):
87
- mcol = 'parameter_value'
88
  models = [
89
  model_1,
90
  model_2,
@@ -141,7 +143,7 @@ class RankPlotter(DataPlotter):
141
  xmax=self.df['upper'],
142
  alpha=0.5)
143
  ax.set_ylabel('')
144
- ax.set_yticks(self.y, self.df['parameter_value'])
145
 
146
  class ComparisonPlotter(DataPlotter):
147
  def __init__(self, df, model_1, model_2, ci=0.95):
@@ -219,9 +221,8 @@ with gr.Blocks() as demo:
219
 
220
  ''')
221
  with gr.Column():
222
- models = df['parameter_value'].unique()
223
- choices = sorted(models, key=lambda x: x.lower())
224
- drops = ft.partial(gr.Dropdown, choices=choices)
225
  inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
226
 
227
  button = gr.Button(value='Compare!')
 
38
  #
39
  def load(repo):
40
  parameter = 'parameter'
41
+ model = 'model'
42
  items = [
43
  'chain',
44
  'sample',
45
  parameter,
46
+ model,
47
  'value',
48
  ]
49
  dataset = load_dataset(repo)
 
51
  return (dataset
52
  .get('train')
53
  .to_pandas()
54
+ .rename(columns={'parameter_value': model})
55
  .filter(items=items)
56
  .query(f'{parameter} == "alpha"')
57
  .drop(columns=parameter))
 
62
  interval = hdi(values, ci)
63
 
64
  agg = {
65
+ 'model': i,
66
  'ability': values.median(),
67
  'uncertainty': interval.upper - interval.lower,
68
  }
 
70
 
71
  return agg
72
 
73
+ groups = df.groupby('model', sort=False)
74
  records = it.starmap(_aggregate, groups)
75
 
76
  return pd.DataFrame.from_records(records)
 
86
  return df.reset_index(names=name)
87
 
88
  def compare(df, model_1, model_2):
89
+ mcol = 'model'
90
  models = [
91
  model_1,
92
  model_2,
 
143
  xmax=self.df['upper'],
144
  alpha=0.5)
145
  ax.set_ylabel('')
146
+ ax.set_yticks(self.y, self.df['model'])
147
 
148
  class ComparisonPlotter(DataPlotter):
149
  def __init__(self, df, model_1, model_2, ci=0.95):
 
221
 
222
  ''')
223
  with gr.Column():
224
+ models = sorted(df['model'].unique(), key=lambda x: x.lower())
225
+ drops = ft.partial(gr.Dropdown, choices=models)
 
226
  inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
227
 
228
  button = gr.Button(value='Compare!')