Spaces:
Sleeping
Sleeping
Commit
·
7e85849
1
Parent(s):
dc3994d
Cleaner way of updating column names
Browse filesRename the column to what the previous code expected.
app.py
CHANGED
@@ -38,11 +38,12 @@ def hdi(values, ci=0.95):
|
|
38 |
#
|
39 |
def load(repo):
|
40 |
parameter = 'parameter'
|
|
|
41 |
items = [
|
42 |
'chain',
|
43 |
'sample',
|
44 |
parameter,
|
45 |
-
|
46 |
'value',
|
47 |
]
|
48 |
dataset = load_dataset(repo)
|
@@ -50,6 +51,7 @@ def load(repo):
|
|
50 |
return (dataset
|
51 |
.get('train')
|
52 |
.to_pandas()
|
|
|
53 |
.filter(items=items)
|
54 |
.query(f'{parameter} == "alpha"')
|
55 |
.drop(columns=parameter))
|
@@ -60,7 +62,7 @@ def summarize(df, ci=0.95):
|
|
60 |
interval = hdi(values, ci)
|
61 |
|
62 |
agg = {
|
63 |
-
'
|
64 |
'ability': values.median(),
|
65 |
'uncertainty': interval.upper - interval.lower,
|
66 |
}
|
@@ -68,7 +70,7 @@ def summarize(df, ci=0.95):
|
|
68 |
|
69 |
return agg
|
70 |
|
71 |
-
groups = df.groupby('
|
72 |
records = it.starmap(_aggregate, groups)
|
73 |
|
74 |
return pd.DataFrame.from_records(records)
|
@@ -84,7 +86,7 @@ def rank(df, ascending, name='rank'):
|
|
84 |
return df.reset_index(names=name)
|
85 |
|
86 |
def compare(df, model_1, model_2):
|
87 |
-
mcol = '
|
88 |
models = [
|
89 |
model_1,
|
90 |
model_2,
|
@@ -141,7 +143,7 @@ class RankPlotter(DataPlotter):
|
|
141 |
xmax=self.df['upper'],
|
142 |
alpha=0.5)
|
143 |
ax.set_ylabel('')
|
144 |
-
ax.set_yticks(self.y, self.df['
|
145 |
|
146 |
class ComparisonPlotter(DataPlotter):
|
147 |
def __init__(self, df, model_1, model_2, ci=0.95):
|
@@ -219,9 +221,8 @@ with gr.Blocks() as demo:
|
|
219 |
|
220 |
''')
|
221 |
with gr.Column():
|
222 |
-
models = df['
|
223 |
-
|
224 |
-
drops = ft.partial(gr.Dropdown, choices=choices)
|
225 |
inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
|
226 |
|
227 |
button = gr.Button(value='Compare!')
|
|
|
38 |
#
|
39 |
def load(repo):
|
40 |
parameter = 'parameter'
|
41 |
+
model = 'model'
|
42 |
items = [
|
43 |
'chain',
|
44 |
'sample',
|
45 |
parameter,
|
46 |
+
model,
|
47 |
'value',
|
48 |
]
|
49 |
dataset = load_dataset(repo)
|
|
|
51 |
return (dataset
|
52 |
.get('train')
|
53 |
.to_pandas()
|
54 |
+
.rename(columns={'parameter_value': model})
|
55 |
.filter(items=items)
|
56 |
.query(f'{parameter} == "alpha"')
|
57 |
.drop(columns=parameter))
|
|
|
62 |
interval = hdi(values, ci)
|
63 |
|
64 |
agg = {
|
65 |
+
'model': i,
|
66 |
'ability': values.median(),
|
67 |
'uncertainty': interval.upper - interval.lower,
|
68 |
}
|
|
|
70 |
|
71 |
return agg
|
72 |
|
73 |
+
groups = df.groupby('model', sort=False)
|
74 |
records = it.starmap(_aggregate, groups)
|
75 |
|
76 |
return pd.DataFrame.from_records(records)
|
|
|
86 |
return df.reset_index(names=name)
|
87 |
|
88 |
def compare(df, model_1, model_2):
|
89 |
+
mcol = 'model'
|
90 |
models = [
|
91 |
model_1,
|
92 |
model_2,
|
|
|
143 |
xmax=self.df['upper'],
|
144 |
alpha=0.5)
|
145 |
ax.set_ylabel('')
|
146 |
+
ax.set_yticks(self.y, self.df['model'])
|
147 |
|
148 |
class ComparisonPlotter(DataPlotter):
|
149 |
def __init__(self, df, model_1, model_2, ci=0.95):
|
|
|
221 |
|
222 |
''')
|
223 |
with gr.Column():
|
224 |
+
models = sorted(df['model'].unique(), key=lambda x: x.lower())
|
225 |
+
drops = ft.partial(gr.Dropdown, choices=models)
|
|
|
226 |
inputs = [ drops(label=f'Model {x}') for x in range(1, 3) ]
|
227 |
|
228 |
button = gr.Button(value='Compare!')
|