Add wandb overview
- app.py +36 -10
- opendashboards/assets/io.py +34 -18
- opendashboards/assets/plot.py +6 -0
- opendashboards/utils/plotting.py +7 -5
- opendashboards/utils/utils.py +5 -2
app.py
CHANGED
@@ -1,4 +1,5 @@
 # https://huggingface.co/docs/hub/en/spaces-github-actions
+import os
 import time
 import pandas as pd
 import streamlit as st
@@ -7,10 +8,21 @@ from opendashboards.assets import io, inspect, metric, plot
 # prompt-based completion score stats
 # instrospect specific RUN-UID-COMPLETION
 # cache individual file loads
-# Hotkey
+# Hotkey
+
+# TODO: limit the historical lookup to something reasonable (e.g. 30 days)
+# TODO: Add sidebar for filters such as tags, hotkeys, etc.
+# TODO: Show trends for runs (versions, hotkeys, etc.). An area chart would be nice, a gantt chart would be better
+# TODO: Add a search bar for runs
+# TODO: Find a reason to make a pie chart (task distribution, maybe)
+# TODO: remove repetition plots (it's not really a thing any more)
+# TODO: MINER SKILLSET STAR CHART
+# TODO: Status codes for runs vs time (from analysis notebook)
 
 WANDB_PROJECT = "opentensor-dev/alpha-validators"
-DEFAULT_FILTERS = {"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
+PROJECT_URL = f'https://wandb.ai/{WANDB_PROJECT}/table?workspace=default'
+MAX_RECENT_RUNS = 100
+DEFAULT_FILTERS = {}#{"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
 DEFAULT_SELECTED_HOTKEYS = None
 DEFAULT_TASK = 'qa'
 DEFAULT_COMPLETION_NTOP = 10
@@ -24,7 +36,7 @@ st.set_page_config(
         'About': f"""
         This dashboard is part of the OpenTensor project. \n
         To see runs in wandb, go to: \n
-        https://wandb.ai/{WANDB_PROJECT}/table?workspace=default
+        [Wandb Table](https://wandb.ai/{WANDB_PROJECT}/table?workspace=default) \n
         """
     },
     layout = "centered"
@@ -36,23 +48,31 @@ st.markdown('#')
 st.markdown('#')
 
 
+
 with st.spinner(text=f'Checking wandb...'):
-    df_runs = io.load_runs(project=WANDB_PROJECT, filters=DEFAULT_FILTERS, min_steps=10)
+    df_runs = io.load_runs(project=WANDB_PROJECT, filters=DEFAULT_FILTERS, min_steps=10, max_recent=MAX_RECENT_RUNS)
 
 metric.wandb(df_runs)
 
 # add vertical space
 st.markdown('#')
+
+runid_c1, runid_c2 = st.columns([3, 1])
+# make multiselect for run_ids with label on same line
+run_ids = runid_c1.multiselect('Select one or more weights and biases run by id:', df_runs['run_id'], key='run_id', default=df_runs['run_id'][:3], help=f'Select one or more runs to analyze. You can find the raw data for these runs [here]({PROJECT_URL}).')
+n_runs = len(run_ids)
+df_runs_subset = df_runs[df_runs['run_id'].isin(run_ids)]
+
 st.markdown('#')
 
-tab1, tab2, tab3, tab4 = st.tabs(["
+tab1, tab2, tab3, tab4 = st.tabs(["Run Data", "UID Health", "Completions", "Prompt-based scoring"])
 
 ### Wandb Runs ###
 with tab1:
 
     st.markdown('#')
     st.subheader(":violet[Run] Data")
-    with st.expander(f'Show :violet[
+    with st.expander(f'Show :violet[all] wandb runs'):
 
         edited_df = st.data_editor(
             df_runs.assign(Select=False).set_index('Select'),
@@ -60,20 +80,26 @@ with tab1:
             disabled=df_runs.columns,
             use_container_width=True,
         )
-        df_runs_subset = df_runs[edited_df.index==True]
-        n_runs = len(df_runs_subset)
+        if edited_df.index.any():
+            df_runs_subset = df_runs[edited_df.index==True]
+            n_runs = len(df_runs_subset)
 
     if n_runs:
         df = io.load_data(df_runs_subset, load=True, save=True)
         df = inspect.clean_data(df)
         print(f'\nNans in columns: {df.isna().sum()}')
        df_long = inspect.explode_data(df)
+        if 'rewards' in df_long:
+            df_long['rewards'] = df_long['rewards'].astype(float)
     else:
         st.info(f'You must select at least one run to load data')
         st.stop()
 
     metric.runs(df_long)
 
+    timeline_color = st.radio('Color by:', ['state', 'version', 'netuid'], key='timeline_color', horizontal=True)
+    plot.timeline(df_runs, color=timeline_color)
+
     st.markdown('#')
     st.subheader(":violet[Event] Data")
     with st.expander(f'Show :violet[raw] event data for **{n_runs} selected runs**'):
@@ -97,7 +123,7 @@ with tab2:
 
     uid_src = st.radio('Select task type:', step_types, horizontal=True, key='uid_src')
     df_uid = df_long[df_long.task.str.contains(uid_src)] if uid_src != 'all' else df_long
-
+
     metric.uids(df_uid, uid_src)
     uids = st.multiselect('UID:', sorted(df_uid['uids'].unique()), key='uid')
     with st.expander(f'Show UID health data for **{n_runs} selected runs** and **{len(uids)} selected UIDs**'):
@@ -158,7 +184,7 @@ with tab3:
     # completion_src = msg_col1.radio('Select one:', ['followup', 'answer'], horizontal=True, key='completion_src')
     completion_src = st.radio('Select task type:', step_types, horizontal=True, key='completion_src')
     df_comp = df_long[df_long.task.str.contains(completion_src)] if completion_src != 'all' else df_long
-
+
     completion_info.info(f"Showing **{completion_src}** completions for **{n_runs} selected runs**")
 
     completion_ntop = msg_col2.slider('Top k:', min_value=1, max_value=50, value=DEFAULT_COMPLETION_NTOP, key='completion_ntop')
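For reference, the run-selection flow added above reduces to a small, self-contained pattern: a multiselect drives a dataframe subset, and the script halts until at least one run is chosen. A minimal sketch (the dataframe contents are made up, not the real wandb export):

import pandas as pd
import streamlit as st

# Illustrative stand-in for the dataframe returned by io.load_runs
df_runs = pd.DataFrame({
    'run_id': ['a1', 'b2', 'c3', 'd4'],
    'state': ['finished', 'running', 'crashed', 'finished'],
})

# The multiselect drives the subset; the first few runs are pre-selected
run_ids = st.multiselect('Select runs:', df_runs['run_id'], default=df_runs['run_id'][:3])
df_runs_subset = df_runs[df_runs['run_id'].isin(run_ids)]

if len(df_runs_subset) == 0:
    st.info('You must select at least one run to load data')
    st.stop()  # halts the script here; nothing below runs until a selection is made

st.write(df_runs_subset)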
opendashboards/assets/io.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 import re
+import time
 import pandas as pd
 import streamlit as st
 
@@ -12,8 +13,16 @@ from pandas.api.types import (
     is_object_dtype,
 )
 
-@st.cache_data
-def load_runs(project, filters, min_steps=10):
+# @st.cache_data
+def load_runs(project, filters, min_steps=10, max_recent=100, local_path='wandb_runs.csv', local_stale_time=3600):
+    # TODO: clean up the caching logic (e.g. take into account the args)
+
+    dtypes = {'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category', 'start_time': 'datetime64[s]', 'end_time': 'datetime64[s]', 'duration': 'timedelta64[s]'}
+
+    if local_path and os.path.exists(local_path) and (time.time() - float(os.path.getmtime(local_path))) < local_stale_time:
+        frame = pd.read_csv(local_path)
+        return frame.astype({k:v for k,v in dtypes.items() if k in frame.columns})
+
     runs = []
     n_events = 0
     successful = 0
@@ -22,20 +31,25 @@ def load_runs(project, filters, min_steps=10):
 
     all_runs = utils.get_runs(project, filters)
     for i, run in enumerate(all_runs):
-
+        if i > max_recent:
+            break
         summary = run.summary
         step = summary.get('_step',-1) + 1
         if step < min_steps:
             msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
             continue
-
+
         prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
-        progress.progress(i/len(all_runs),f'{prog_msg}... **fetching** `{run.name}`')
-
+        progress.progress(min(i/len(all_runs),1),f'{prog_msg}... **fetching** `{run.name}`')
+
         duration = summary.get('_runtime')
         end_time = summary.get('_timestamp')
         # extract values for selected tags
-        rules = {
+        rules = {
+            'version': re.compile('^\\d\.\\d+\.\\d+$'),
+            'spec_version': re.compile('\\d{4}$'),
+            'hotkey': re.compile('^[0-9a-z]{48}$',re.IGNORECASE)
+        }
         tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
         # include bool flag for remaining tags
         tags.update({k: k in run.tags for k in ('mock','disable_set_weights')})
@@ -44,16 +58,18 @@ def load_runs(project, filters, min_steps=10):
             'state': run.state,
             'num_steps': step,
             'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
-            '
+            'duration': pd.to_timedelta(duration, unit="s").round('T'), # round to nearest minute
+            'start_time': pd.to_datetime(end_time-duration, unit="s").round('T'),
+            'end_time': pd.to_datetime(end_time, unit="s").round('T'),
+            'netuid': run.config.get('netuid'),
+            **tags,
+            'username': run.user.username,
             'run_id': run.id,
             'run_name': run.name,
-            'project': run.project,
             'url': run.url,
+            # 'entity': run.entity,
+            # 'project': run.project,
             'run_path': os.path.join(run.entity, run.project, run.id),
-            'start_time': pd.to_datetime(end_time-duration, unit="s"),
-            'end_time': pd.to_datetime(end_time, unit="s"),
-            'duration': pd.to_timedelta(duration, unit="s").round('s'),
-            **tags
         })
         n_events += step
         successful += 1
@@ -61,8 +77,8 @@ def load_runs(project, filters, min_steps=10):
     progress.empty()
     msg.empty()
     frame = pd.DataFrame(runs)
-
-    return frame.astype({k:v for k,v in
+    frame.to_csv(local_path, index=False)
+    return frame.astype({k:v for k,v in dtypes.items() if k in frame.columns})
 
 
 @st.cache_data
@@ -84,7 +100,7 @@ def load_data(selected_runs, load=True, save=False):
         if load and os.path.exists(file_path):
             progress.progress(i/len(selected_runs),f'{prog_msg}... **reading** `{file_path}`')
             try:
-                df = utils.
+                df = utils.read_data(file_path)
             except Exception as e:
                 info.warning(f'Failed to load history from `{file_path}`')
                 st.exception(e)
@@ -97,7 +113,7 @@ def load_data(selected_runs, load=True, save=False):
 
         print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
         df.info()
-
+
         if save and run.state != 'running':
             df.to_csv(file_path, index=False)
             # st.info(f'Saved history to {file_path}')
@@ -137,7 +153,7 @@ def filter_dataframe(df: pd.DataFrame, demo_selection=None) -> pd.DataFrame:
         df = df.loc[demo_selection]
         run_msg.info(f"Selected {len(df)} runs")
         return df
-
+
     df = df.copy()
 
     # Try to convert datetimes into a standarrd format (datetime, no timezone)
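The hand-rolled cache that replaces @st.cache_data in load_runs is a file-age check: serve the local CSV while it is fresh, otherwise refetch and rewrite it. Below is a minimal sketch of the same pattern (cached_fetch, fetch_fn, and the file name are illustrative placeholders, not names from the repo). Note that, as the in-code TODO acknowledges, this keys on nothing but the file's age, so changed arguments do not invalidate the cache:

import os
import time
import pandas as pd

def cached_fetch(fetch_fn, local_path='cache.csv', stale_time=3600):
    # Serve the local CSV if it exists and is younger than stale_time seconds
    if os.path.exists(local_path) and (time.time() - os.path.getmtime(local_path)) < stale_time:
        return pd.read_csv(local_path)
    # Otherwise fetch fresh data and refresh the file cache
    frame = fetch_fn()
    frame.to_csv(local_path, index=False)
    return frame

# Usage: any zero-argument callable returning a DataFrame
df = cached_fetch(lambda: pd.DataFrame({'x': [1, 2, 3]}))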
opendashboards/assets/plot.py
CHANGED
@@ -2,6 +2,12 @@
 import streamlit as st
 import opendashboards.utils.plotting as plotting
 
+def timeline(df_runs, color='state'):
+    return st.plotly_chart(
+        plotting.plot_gantt(df_runs, color=color),
+        use_container_width=True
+    )
+
 # @st.cache_data
 def uid_diversty(df, rm_failed=True):
     return st.plotly_chart(
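The new timeline helper follows the same convention as the other helpers in this module: the figure is built in opendashboards/utils/plotting.py and rendered full-width with st.plotly_chart. A minimal sketch of that convention, using a placeholder figure and an illustrative function name:

import plotly.express as px
import streamlit as st

def render_full_width(fig):
    # Chart helpers in assets/plot.py all hand a ready-made figure to Streamlit this way
    return st.plotly_chart(fig, use_container_width=True)

render_full_width(px.scatter(x=[1, 2, 3], y=[3, 1, 2]))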
opendashboards/utils/plotting.py
CHANGED
@@ -20,6 +20,7 @@ import tqdm
 import pandas as pd
 import numpy as np
 import networkx as nx
+import streamlit as st
 
 import plotly.express as px
 import plotly.graph_objects as go
@@ -28,21 +29,22 @@ from typing import List, Union
 
 plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}
 
-def plot_gantt(df_runs: pd.DataFrame, y='username'):
-
-
+def plot_gantt(df_runs: pd.DataFrame, y='username', color="state"):
+    color_discrete_map={'running': 'green', 'finished': 'grey', 'killed':'blue', 'crashed':'orange', 'failed': 'red'}
+    fig = px.timeline(df_runs.astype({color: str}),
+                x_start="start_time", x_end="end_time", y=y, color=color,
                 title="Timeline of WandB Runs",
                 category_orders={'run_name': df_runs.run_name.unique()},
                 hover_name="run_name",
                 hover_data=[col for col in ['hotkey','user','username','run_id','num_steps','num_completions'] if col in df_runs],
-                color_discrete_map={
+                color_discrete_map={k: v for k, v in color_discrete_map.items() if k in df_runs[color].unique()},
                 opacity=0.3,
                 width=1200,
                 height=800,
                 template="plotly_white",
     )
     # remove y axis ticks
-    fig.update_yaxes(
+    fig.update_yaxes(title='')
     return fig
 
 def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
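For context: px.timeline draws one horizontal bar per row, spanning x_start to x_end, which is why load_runs now emits start_time and end_time columns. A standalone sketch with made-up run data (in the dashboard, the frame comes from io.load_runs):

import pandas as pd
import plotly.express as px

# Made-up runs for illustration only
df = pd.DataFrame({
    'username': ['alice', 'bob', 'alice'],
    'state': ['finished', 'running', 'crashed'],
    'start_time': pd.to_datetime(['2023-06-01 10:00', '2023-06-01 11:00', '2023-06-01 12:00']),
    'end_time': pd.to_datetime(['2023-06-01 12:30', '2023-06-01 14:00', '2023-06-01 12:45']),
})

# One bar per run, grouped by user on the y axis, coloured by run state
fig = px.timeline(df, x_start='start_time', x_end='end_time', y='username', color='state',
                  color_discrete_map={'finished': 'grey', 'running': 'green', 'crashed': 'orange'})
fig.show()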
opendashboards/utils/utils.py
CHANGED
@@ -144,10 +144,12 @@ def read_data(path: str, nrows: int = None):
     """Load data from csv."""
     df = pd.read_csv(path, nrows=nrows)
     # filter out events with missing step length
-    df = df.loc[df.step_length.notna()]
+    # df = df.loc[df.step_length.notna()]
 
     # detect list columns which as stored as strings
-
+    def is_list_col(x):
+        return isinstance(x, str) and x[0]=='[' and x[-1]==']'
+    list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].apply(is_list_col).all()]
     # convert string representation of list to list
     df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
 
@@ -161,6 +163,7 @@ def load_data(selected_runs, load=True, save=False, explode=True, datadir='data/
     if not os.path.exists(datadir):
         os.makedirs(datadir)
 
+    st.write(selected_runs)
     pbar = tqdm.tqdm(selected_runs.index, desc="Loading runs", total=len(selected_runs), unit="run")
     for i, idx in enumerate(pbar):
         run = selected_runs.loc[idx]
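read_data's list-column handling relies on the fact that a Python list written out by to_csv round-trips as the string "[...]"; detecting such columns and eval-ing the strings restores real lists. A standalone sketch of the same idea, with illustrative data:

import pandas as pd
from io import StringIO

# Simulate a CSV where a list column was serialized as a quoted string
csv = StringIO('uids,score\n"[1, 2, 3]",0.5\n"[4, 5]",0.9\n')
df = pd.read_csv(csv)

def is_list_col(x):
    # A list written by to_csv comes back as a string like "[1, 2, 3]"
    return isinstance(x, str) and x[0] == '[' and x[-1] == ']'

# Only object columns whose every value looks like a list qualify
list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].apply(is_list_col).all()]
df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
print(df['uids'][0])  # [1, 2, 3] -- a real Python list again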