Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -166,11 +166,9 @@ def run_update_dataset():
|
|
166 |
)
|
167 |
|
168 |
|
169 |
-
def get_data(rl_env, task_id, path) -> pd.DataFrame:
|
170 |
"""
|
171 |
-
Get data from rl_env, filter by the given task_id, and drop the Task-ID column.
|
172 |
-
Also drops any columns that have no data (all values are NaN) or all values are 0.0.
|
173 |
-
:return: filtered data as a pandas DataFrame without the Task-ID column
|
174 |
"""
|
175 |
csv_path = path + "/" + rl_env + ".csv"
|
176 |
data = pd.read_csv(csv_path)
|
@@ -178,16 +176,15 @@ def get_data(rl_env, task_id, path) -> pd.DataFrame:
|
|
178 |
# Filter the data to only include rows where the "Task-ID" column matches the given task_id
|
179 |
filtered_data = data[data["Task-ID"] == task_id]
|
180 |
|
181 |
-
#
|
182 |
-
|
|
|
183 |
|
184 |
-
# Drop the "Task"
|
185 |
-
filtered_data = filtered_data.drop(columns=["Task"])
|
186 |
|
187 |
-
# Drop columns that have no data (all values are NaN)
|
188 |
filtered_data = filtered_data.dropna(axis=1, how='all')
|
189 |
-
|
190 |
-
# Drop columns where all values are 0.0
|
191 |
filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
|
192 |
|
193 |
# Convert User and Model columns to clickable links
|
@@ -252,18 +249,24 @@ with block:
|
|
252 |
gr.HTML(f"<p style='text-align: center;'>Get started π on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>")
|
253 |
|
254 |
path_ = download_leaderboard_dataset()
|
255 |
-
# gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
256 |
# ENVIRONMENT TABS
|
257 |
-
with gr.Tabs() as tabs:
|
258 |
for env_index in range(0, len(hivex_envs)):
|
259 |
hivex_env = hivex_envs[env_index]
|
260 |
with gr.Tab(f"{hivex_env['title']}") as env_tabs:
|
261 |
-
# TASK TABS
|
262 |
for task_id in range(0, hivex_env["task_count"]):
|
263 |
task_title = convert_to_title_case(get_task(hivex_env["hivex_env"], task_id, path_))
|
264 |
with gr.TabItem(f"Task {task_id}: {task_title}"):
|
265 |
with gr.Row():
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
row_count = len(data) # Number of rows in the data
|
268 |
|
269 |
gr_dataframe = gr.components.Dataframe(
|
@@ -272,6 +275,13 @@ with block:
|
|
272 |
datatype=["markdown", "markdown"],
|
273 |
row_count=(row_count, 'fixed') # Set to the exact number of rows in the data
|
274 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
|
277 |
scheduler = BackgroundScheduler()
|
|
|
166 |
)
|
167 |
|
168 |
|
169 |
+
def get_data(rl_env, task_id, path, selected_filters: List[str] = None) -> pd.DataFrame:
|
170 |
"""
|
171 |
+
Get data from rl_env, filter by the given task_id and selected filters, and drop the Task-ID column.
|
|
|
|
|
172 |
"""
|
173 |
csv_path = path + "/" + rl_env + ".csv"
|
174 |
data = pd.read_csv(csv_path)
|
|
|
176 |
# Filter the data to only include rows where the "Task-ID" column matches the given task_id
|
177 |
filtered_data = data[data["Task-ID"] == task_id]
|
178 |
|
179 |
+
# Apply selected filters for difficulty or pattern if provided
|
180 |
+
if selected_filters:
|
181 |
+
filtered_data = filtered_data[filtered_data['Pattern'].isin(selected_filters) | filtered_data['Difficulty'].isin(selected_filters)]
|
182 |
|
183 |
+
# Drop the "Task-ID" and "Task" columns
|
184 |
+
filtered_data = filtered_data.drop(columns=["Task-ID", "Task"])
|
185 |
|
186 |
+
# Drop columns that have no data (all values are NaN) or where all values are 0.0
|
187 |
filtered_data = filtered_data.dropna(axis=1, how='all')
|
|
|
|
|
188 |
filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
|
189 |
|
190 |
# Convert User and Model columns to clickable links
|
|
|
249 |
gr.HTML(f"<p style='text-align: center;'>Get started π on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>")
|
250 |
|
251 |
path_ = download_leaderboard_dataset()
|
|
|
252 |
# ENVIRONMENT TABS
|
253 |
+
with gr.Tabs() as tabs: # elem_classes="tab-buttons"
|
254 |
for env_index in range(0, len(hivex_envs)):
|
255 |
hivex_env = hivex_envs[env_index]
|
256 |
with gr.Tab(f"{hivex_env['title']}") as env_tabs:
|
|
|
257 |
for task_id in range(0, hivex_env["task_count"]):
|
258 |
task_title = convert_to_title_case(get_task(hivex_env["hivex_env"], task_id, path_))
|
259 |
with gr.TabItem(f"Task {task_id}: {task_title}"):
|
260 |
with gr.Row():
|
261 |
+
# CheckboxGroup for Difficulty/Pattern
|
262 |
+
selected_filters = gr.CheckboxGroup(
|
263 |
+
choices=list(pattern_map.values()),
|
264 |
+
label="Select Difficulty/Pattern",
|
265 |
+
value=list(pattern_map.values()) # Default to all selected
|
266 |
+
)
|
267 |
+
|
268 |
+
with gr.Row():
|
269 |
+
data = get_data(hivex_env["hivex_env"], task_id, path_, selected_filters=selected_filters.value)
|
270 |
row_count = len(data) # Number of rows in the data
|
271 |
|
272 |
gr_dataframe = gr.components.Dataframe(
|
|
|
275 |
datatype=["markdown", "markdown"],
|
276 |
row_count=(row_count, 'fixed') # Set to the exact number of rows in the data
|
277 |
)
|
278 |
+
|
279 |
+
# Update data based on checkbox selection
|
280 |
+
selected_filters.change(
|
281 |
+
lambda filters: get_data(hivex_env["hivex_env"], task_id, path_, filters),
|
282 |
+
inputs=[selected_filters],
|
283 |
+
outputs=[gr_dataframe]
|
284 |
+
)
|
285 |
|
286 |
|
287 |
scheduler = BackgroundScheduler()
|