HELLO
- dashboard.py +317 -0
- hello_world.py +11 -0
- plotting.py +362 -0
- requirements.txt +7 -0
- utils.py +127 -0
dashboard.py
ADDED
@@ -0,0 +1,317 @@
+import streamlit as st
+import pandas as pd
+import os
+import re
+import time
+from utils import get_runs, download_data, get_list_col_lengths, explode_data
+import plotting
+
+
+# dendrite time versus completion length
+# prompt-based completion score stats
+
+
+DEFAULT_PROJECT = "openvalidators"
+DEFAULT_FILTERS = {"tags": {"$in": ["1.0.0", "1.0.1", "1.0.2", "1.0.3", "1.0.4"]}}
+
+@st.cache_data
+def load_runs(project=DEFAULT_PROJECT, filters=DEFAULT_FILTERS, min_steps=10):
+    runs = []
+    msg = st.empty()
+    for run in get_runs(project, filters):
+        step = run.summary.get('_step', 0)
+        if step < min_steps:
+            msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
+            continue
+
+        duration = run.summary.get('_runtime')
+        end_time = run.summary.get('_timestamp')
+        # extract values for selected tags
+        rules = {'hotkey': re.compile('^[0-9a-z]{48}$', re.IGNORECASE), 'version': re.compile(r'^\d\.\d+\.\d+$'), 'spec_version': re.compile(r'\d{4}$')}
+        # tags = {k: tag for k,tag in zip(('hotkey','version','spec_version'),run.tags)}
+        tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
+        # include bool flag for remaining tags
+        tags.update({k: k in run.tags for k in ('mock','custom_gating_model','nsfw_filter','outsource_scoring','disable_set_weights')})
+
+        runs.append({
+            'state': run.state,
+            'num_steps': step,
+            'entity': run.entity,
+            'id': run.id,
+            'name': run.name,
+            'project': run.project,
+            'url': run.url,
+            'path': os.path.join(run.entity, run.project, run.id),
+            'start_time': pd.to_datetime(end_time-duration, unit="s"),
+            'end_time': pd.to_datetime(end_time, unit="s"),
+            'duration': pd.to_timedelta(duration, unit="s"),  # a runtime is a timedelta, not a datetime
+            # 'tags': run.tags,
+            **tags
+        })
+    msg.empty()
+    return pd.DataFrame(runs).astype({'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category'})
+
+
+@st.cache_data
+def load_data(selected_runs, load=True, save=False):
+
+    frames = []
+    n_events = 0
+    progress = st.progress(0, 'Loading data')
+    for i, idx in enumerate(selected_runs.index):
+        run = selected_runs.loc[idx]
+        prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({i}/{len(selected_runs)} runs, {n_events} events)'
+
+        file_path = f'data/history-{run.id}.csv'
+
+        if load and os.path.exists(file_path):
+            progress.progress(i/len(selected_runs), f'{prog_msg}... reading {file_path}')
+            df = pd.read_csv(file_path)
+            # filter out events with missing step length
+            df = df.loc[df.step_length.notna()]
+
+            # detect list columns which are stored as strings
+            list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
+            # convert string representation of list to list
+            df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
+
+        else:
+            try:
+                # Download the history from wandb
+                progress.progress(i/len(selected_runs), f'{prog_msg}... downloading `{run.path}`')
+                df = download_data(run.path)
+                df = df.assign(**run.to_dict())  # assign returns a copy, so keep the result
+                if not os.path.exists('data/'):
+                    os.makedirs('data/')  # create the data directory, not the csv path
+
+                if save and run.state != 'running':
+                    df.to_csv(file_path, index=False)
+                    # st.info(f'Saved history to {file_path}')
+            except Exception as e:
+                st.error(f'Failed to download history for `{run.path}`')
+                st.exception(e)
+                continue
+
+        frames.append(df)
+        n_events += df.shape[0]
+
+    progress.empty()
+    # Remove rows which contain chain weights as they mess up the schema
+    return pd.concat(frames)
+
+@st.cache_data
+def get_exploded_data(df):
+    list_cols = get_list_col_lengths(df)
+    return explode_data(df, list(list_cols))
+
+@st.cache_data
+def get_completions(df_long, col):
+    return df_long[col].value_counts()
+
+@st.cache_data
+def plot_uid_diversity(df, remove_unsuccessful=True):
+    return plotting.plot_uid_diversity(df, remove_unsuccessful=remove_unsuccessful)
+
+@st.cache_data
+def plot_leaderboard(df, ntop, group_on, agg_col, agg, alias=False):
+    return plotting.plot_leaderboard(df, ntop=ntop, group_on=group_on, agg_col=agg_col, agg=agg, alias=alias)
+
+@st.cache_data
+def plot_completion_rewards(df, completion_col, reward_col, uid_col, ntop, completions=None, completion_regex=None):
+    return plotting.plot_completion_rewards(df, msg_col=completion_col, reward_col=reward_col, uid_col=uid_col, ntop=ntop, completions=completions, completion_regex=completion_regex)
+
+@st.cache_data
+def uid_metrics(df_long, src, uid=None):
+
+    uid_col = f'{src}_uids'
+    completion_col = f'{src}_completions'
+    nsfw_col = f'{src}_nsfw_scores'
+    reward_col = f'{src}_rewards'
+
+    if uid is not None:
+        df_long = df_long.loc[df_long[uid_col] == uid]
+
+    col1, col2, col3 = st.columns(3)
+    col1.metric(
+        label="Success %",
+        value=f'{df_long.loc[df_long[completion_col].str.len() > 0].shape[0]/df_long.shape[0] * 100:.1f}'
+    )
+    col2.metric(
+        label="Diversity %",
+        value=f'{df_long[completion_col].nunique()/df_long.shape[0] * 100:.1f}'
+    )
+    col3.metric(
+        label="Toxicity %",
+        value=f'{df_long[nsfw_col].mean() * 100:.1f}' if nsfw_col in df_long.columns else 'N/A'
+    )
+
+st.title('Validator :red[Analysis] Dashboard :eyes:')
+# add vertical space
+st.markdown('#')
+st.markdown('#')
+
+
+with st.sidebar:
+    st.sidebar.header('Pages')
+
+with st.spinner(text='Checking wandb...'):
+    df_runs = load_runs()
+    # get rows where start time is older than 24h ago
+    df_runs_old = df_runs.loc[df_runs.start_time < pd.to_datetime(time.time()-24*60*60, unit='s')]
+
+col1, col2, col3 = st.columns(3)
+
+col1.metric('Runs', df_runs.shape[0], delta=f'{df_runs.shape[0]-df_runs_old.shape[0]} (24h)')
+col2.metric('Hotkeys', df_runs.hotkey.nunique(), delta=f'{df_runs.hotkey.nunique()-df_runs_old.hotkey.nunique()} (24h)')
+col3.metric('Events', df_runs.num_steps.sum(), delta=f'{df_runs.num_steps.sum()-df_runs_old.num_steps.sum()} (24h)')
+
+# https://wandb.ai/opentensor-dev/openvalidators/runs/kt9bzxii/overview?workspace=
+# all_run_paths = ['opentensor-dev/openvalidators/kt9bzxii'] # pedro long run
+
+run_ids = df_runs.id
+default_selected_runs = ['kt9bzxii']
+selected_runs = default_selected_runs
+
+# add vertical space
+st.markdown('#')
+st.markdown('#')
+
+
+tab1, tab2, tab3, tab4 = st.tabs(["Wandb Runs", "UID Health", "Completions", "Prompt-based scoring"])
+
+# src = st.radio('Choose data source:', ['followup', 'answer'], horizontal=True, key='src')
+# list_list_cols = get_list_col_lengths(df_long)
+# df_long_long = explode_data(df_long, list(list_list_cols))
+
+with tab1:
+
+    st.markdown('#')
+    st.subheader(":violet[Wandb] Runs")
+
+    # Load data
+    df = load_data(df_runs.loc[run_ids.isin(selected_runs)], load=True, save=True)
+    df_long = get_exploded_data(df)
+
+    col1, col2, col3, col4 = st.columns(4)
+    col1.metric(label="Selected runs", value=len(selected_runs))
+    col2.metric(label="Events", value=df.shape[0])
+    col3.metric(label="UIDs", value=df_long.followup_uids.nunique())
+    col4.metric(label="Unique completions", value=df_long.followup_completions.nunique())  # count completions, not uids
+
+    selected_runs = st.multiselect(f'Runs ({len(df_runs)})', run_ids, default=selected_runs)
+
+    st.markdown('#')
+    st.subheader("View :violet[Data]")
+
+    show_col1, show_col2 = st.columns(2)
+    show_runs = show_col1.checkbox('Show runs', value=True)
+    show_events = show_col2.checkbox('Show events', value=False)
+    if show_runs:
+        st.markdown(f'Wandb info for **{len(selected_runs)} selected runs**:')
+        st.dataframe(df_runs.loc[run_ids.isin(selected_runs)],
+            column_config={
+                "url": st.column_config.LinkColumn("URL"),
+            }
+        )
+
+    if show_events:
+        st.markdown(f'Raw events for **{len(selected_runs)} selected runs**:')
+        st.dataframe(df.head(50),
+            column_config={
+                "url": st.column_config.LinkColumn("URL"),
+            }
+        )
+
+default_src = 'followup'
+with tab2:
+
+    st.markdown('#')
+    st.subheader("UID :violet[Health]")
+    uid_src = default_src
+
+    # uid = st.selectbox('UID:', sorted(df_long[uid_col].unique()), key='uid')
+
+    uid_metrics(df_long, uid_src)
+    uid_src = st.radio('Select one:', ['followup', 'answer'], horizontal=True, key='uid_src')
+    uid_col = f'{uid_src}_uids'
+    reward_col = f'{uid_src}_rewards'
+
+    st.markdown('#')
+    st.subheader("UID :violet[Leaderboard]")
+    uid_ntop_default = 10
+
+    uid_col1, uid_col2 = st.columns(2)
+    uid_ntop = uid_col1.slider('Number of UIDs:', min_value=1, max_value=50, value=uid_ntop_default, key='uid_ntop')
+    uid_agg = uid_col2.selectbox('Aggregation:', ('mean','min','max','size','nunique'), key='uid_agg')
+
+    st.plotly_chart(
+        plot_leaderboard(
+            df,
+            ntop=uid_ntop,
+            group_on=uid_col,
+            agg_col=reward_col,
+            agg=uid_agg
+        )
+    )
+    remove_unsuccessful = st.checkbox('Remove failed completions', value=True)
+    st.plotly_chart(
+        plot_uid_diversity(
+            df,
+            remove_unsuccessful=remove_unsuccessful
+        )
+    )
+
+
+completion_ntop_default = 10
+with tab3:
+
+    st.markdown('#')
+    st.subheader('Completion :violet[Leaderboard]')
+    completion_src = default_src
+
+    msg_col1, msg_col2 = st.columns(2)
+    completion_src = msg_col1.radio('Select one:', ['followup', 'answer'], horizontal=True, key='completion_src')
+    completion_ntop = msg_col2.slider('Top k:', min_value=1, max_value=50, value=completion_ntop_default, key='completion_ntop')
+
+    completion_col = f'{completion_src}_completions'
+    reward_col = f'{completion_src}_rewards'
+    uid_col = f'{completion_src}_uids'
+
+    completions = get_completions(df_long, completion_col)
+
+    # completion_sel = st.radio('Select input method:', ['ntop', 'select','regex'], horizontal=True, key='completion_sel')
+    # Get completions with highest average rewards
+    st.plotly_chart(
+        plot_leaderboard(
+            df,
+            ntop=completion_ntop,
+            group_on=completion_col,
+            agg_col=reward_col,
+            agg='mean',
+            alias=True
+        )
+    )
+    st.markdown('#')
+    st.subheader('Completion :violet[Rewards]')
+
+    completion_select = st.multiselect('Completions:', completions.index, default=completions.index[:3].tolist())
+    # completion_regex = st.text_input('Completion regex:', value='', key='completion_regex')
+
+    st.plotly_chart(
+        plot_completion_rewards(
+            df,
+            completion_col=completion_col,
+            reward_col=reward_col,
+            uid_col=uid_col,
+            ntop=completion_ntop,
+            completions=completion_select,
+        )
+    )
+
+with tab4:
+    st.subheader(':pink[Prompt-based scoring]')
+    prompt_src = st.radio('Select one:', ['followup', 'answer'], key='prompt')
+
+
+    # st.dataframe(df_long_long.filter(regex=prompt_src).head())
+
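Note: every loader and plot in dashboard.py is wrapped in `@st.cache_data`, which hashes the function's arguments and memoizes the returned value across Streamlit reruns, so the wandb API is only hit when the selection actually changes. A minimal, self-contained sketch of that behaviour (the function and values below are invented for illustration):

import streamlit as st

@st.cache_data
def expensive(x):
    print("computed")  # executes once per distinct x; later calls hit the cache
    return x * 2

expensive(1)  # computes and stores the result
expensive(1)  # served from cache: no recomputation, no print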
hello_world.py
ADDED
@@ -0,0 +1,11 @@
+import streamlit as st
+import pandas as pd
+import os
+import re
+import time
+from utils import get_runs, download_data, get_list_col_lengths, explode_data
+import plotting
+
+
+
+st.write('HELLO BOITCHES')
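hello_world.py reads as a minimal smoke test for the Space. Locally, either entry point would presumably be launched with the Streamlit CLI, e.g. `streamlit run hello_world.py` or `streamlit run dashboard.py`.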
plotting.py
ADDED
@@ -0,0 +1,362 @@
+# The MIT License (MIT)
+# Copyright © 2021 Yuma Rao
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import tqdm
+
+import pandas as pd
+import numpy as np
+import networkx as nx
+
+import plotly.express as px
+import plotly.graph_objects as go
+
+from typing import List, Union
+
+plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}
+
+
+def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
+    """Plot throughput of event log.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        n_minutes (int, optional): Number of minutes to aggregate. Defaults to 10.
+    """
+
+    rate = df.resample(rule=f"{n_minutes}T", on="_timestamp").size()
+    return px.line(
+        x=rate.index, y=rate, title="Event Log Throughput", labels={"x": "", "y": f"Logs / {n_minutes} min"}, **plotly_config
+    )
+
+
+def plot_weights(scores: pd.DataFrame, ntop: int = 20, uids: List[Union[str, int]] = None) -> go.Figure:
+    """Plot moving-averaged scores over time for a subset of uids.
+
+    Args:
+        scores (pd.DataFrame): Dataframe of scores. Should be indexed by timestamp and have one column per uid.
+        ntop (int, optional): Number of uids to plot. Defaults to 20.
+        uids (List[Union[str, int]], optional): List of uids to plot, should match column names. Defaults to None.
+    """
+
+    # Select subset of columns for plotting
+    if uids is None:
+        uids = scores.columns[:ntop]
+        print(f"Using first {ntop} uids for plotting: {uids}")
+
+    return px.line(
+        scores, y=uids, title="Moving Averaged Scores", labels={"_timestamp": "", "value": "Score"}, **plotly_config
+    ).update_traces(opacity=0.7)
+
+
+def plot_uid_diversity(df: pd.DataFrame, remove_unsuccessful: bool = False) -> go.Figure:
+    """Plot uid diversity as measured by ratio of unique to total completions.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+    """
+    uid_cols = ["followup_uids", "answer_uids"]
+    completion_cols = ["followup_completions", "answer_completions"]
+    reward_cols = ["followup_rewards", "answer_rewards"]
+    list_cols = uid_cols + completion_cols + reward_cols
+
+    df = df[list_cols].explode(column=list_cols)
+    if remove_unsuccessful:
+        # remove unsuccessful completions, as indicated by empty completions
+        for col in completion_cols:
+            df = df[df[col].str.len() > 0]
+
+    frames = []
+    for uid_col, completion_col, reward_col in zip(uid_cols, completion_cols, reward_cols):
+        frame = df.groupby(uid_col).agg({completion_col: ["nunique", "size"], reward_col: "mean"})
+        # flatten multiindex columns
+        frame.columns = ["_".join(col) for col in frame.columns]
+        frame["diversity"] = frame[f"{completion_col}_nunique"] / frame[f"{completion_col}_size"]
+        frames.append(frame)
+
+    merged = pd.merge(*frames, left_index=True, right_index=True, suffixes=("_followup", "_answer"))
+    merged["reward_mean"] = merged.filter(regex="rewards_mean").mean(axis=1)
+
+    merged.index.name = "UID"
+    merged.reset_index(inplace=True)
+
+    return px.scatter(
+        merged,
+        x="diversity_followup",
+        y="diversity_answer",
+        opacity=0.3,
+        size="followup_completions_size",
+        color="reward_mean",
+        hover_data=["UID"] + merged.columns.tolist(),
+        marginal_x="histogram",
+        marginal_y="histogram",
+        color_continuous_scale=px.colors.sequential.Bluered,
+        labels={"diversity_followup": "Followup diversity", "diversity_answer": "Answer diversity"},  # label keys must be column names
+        title="Diversity of completions by UID",
+        **plotly_config,
+    )
+
+
+def plot_completion_rates(
+    df: pd.DataFrame,
+    msg_col: str = "all_completions",
+    time_interval: str = "H",
+    time_col: str = "_timestamp",
+    ntop: int = 20,
+    completions: List[str] = None,
+    completion_regex: str = None,
+) -> go.Figure:
+    """Plot completion rates. Useful for identifying common completions and attacks.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        msg_col (str, optional): List-like column containing completions. Defaults to 'all_completions'.
+        time_interval (str, optional): Pandas time interval. Defaults to 'H'. See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
+        time_col (str, optional): Column containing timestamps as pd.Datetime. Defaults to '_timestamp'.
+        ntop (int, optional): Number of completions to plot. Defaults to 20.
+        completions (List[str], optional): List of completions to plot. Defaults to None.
+        completion_regex (str, optional): Regex to match completions. Defaults to None.
+
+    """
+
+    df = df[[time_col, msg_col]].explode(column=msg_col)
+
+    if completions is None:
+        completion_counts = df[msg_col].value_counts()
+        if completion_regex is not None:
+            completions = completion_counts[completion_counts.index.str.contains(completion_regex)].index[:ntop]
+            print(f"Using {len(completions)} completions which match {completion_regex!r}: \n{completions}")
+        else:
+            completions = completion_counts.index[:ntop]
+            print(f"Using top {len(completions)} completions: \n{completions}")
+
+    period = df[time_col].dt.to_period(time_interval)
+
+    counts = df.groupby([msg_col, period]).size()
+    top_counts = counts.loc[completions].reset_index().rename(columns={0: "Size"})
+    top_counts["Completion ID"] = top_counts[msg_col].map({k: f"{i}" for i, k in enumerate(completions, start=1)})
+
+    return px.line(
+        top_counts.astype({time_col: str}),
+        x=time_col,
+        y="Size",
+        color="Completion ID",
+        hover_data=[top_counts[msg_col].str.replace("\n", "<br>")],
+        labels={time_col: f"Time, {time_interval}", "Size": f"Occurrences / {time_interval}"},
+        title=f"Completion Rates for {len(completions)} Messages",
+        **plotly_config,
+    ).update_traces(opacity=0.7)
+
+
+def plot_completion_rewards(
+    df: pd.DataFrame,
+    msg_col: str = "followup_completions",
+    reward_col: str = "followup_rewards",
+    time_col: str = "_timestamp",
+    uid_col: str = "followup_uids",
+    ntop: int = 3,
+    completions: List[str] = None,
+    completion_regex: str = None,
+) -> go.Figure:
+    """Plot completion rewards. Useful for tracking common completions and their rewards.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        msg_col (str, optional): List-like column containing completions. Defaults to 'followup_completions'.
+        reward_col (str, optional): List-like column containing rewards. Defaults to 'followup_rewards'.
+        time_col (str, optional): Column containing timestamps as pd.Datetime. Defaults to '_timestamp'.
+        ntop (int, optional): Number of completions to plot. Defaults to 3.
+        completions (List[str], optional): List of completions to plot. Defaults to None.
+        completion_regex (str, optional): Regex to match completions. Defaults to None.
+
+    """
+
+    df = (
+        df[[time_col, uid_col, msg_col, reward_col]]
+        .explode(column=[msg_col, uid_col, reward_col])
+        .rename(columns={uid_col: "UID"})
+    )
+    completion_counts = df[msg_col].value_counts()
+
+    if completions is None:
+        if completion_regex is not None:
+            completions = completion_counts[completion_counts.index.str.contains(completion_regex)].index[:ntop]
+            print(f"Using {len(completions)} completions which match {completion_regex!r}: \n{completions}")
+        else:
+            completions = completion_counts.index[:ntop]
+            print(f"Using top {len(completions)} completions: \n{completions}")
+
+    # Get ranks of completions in terms of number of occurrences
+    ranks = completion_counts.rank(method="dense", ascending=False).loc[completions].astype(int)
+
+    # Filter to only the selected completions
+    df = df.loc[df[msg_col].isin(completions)]
+    df["rank"] = df[msg_col].map(ranks).astype(str)
+    df["Total"] = df[msg_col].map(completion_counts)
+
+    return px.scatter(
+        df,
+        x=time_col,
+        y=reward_col,
+        color="rank",
+        hover_data=[msg_col, "UID", "Total"],
+        category_orders={"rank": sorted(df["rank"].unique())},
+        marginal_x="histogram",
+        marginal_y="violin",
+        labels={"rank": "Rank", reward_col: "Reward", time_col: ""},
+        title=f"Rewards for {len(completions)} Messages",
+        **plotly_config,
+        opacity=0.3,
+    )
+
+
+def plot_leaderboard(
+    df: pd.DataFrame,
+    group_on: str = "answer_uids",
+    agg_col: str = "answer_rewards",
+    agg: str = "mean",
+    ntop: int = 10,
+    alias: bool = False,
+) -> go.Figure:
+    """Plot leaderboard for a given column. By default plots the top 10 UIDs by mean reward.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        group_on (str, optional): Entities to use for grouping. Defaults to 'answer_uids'.
+        agg_col (str, optional): Column to aggregate. Defaults to 'answer_rewards'.
+        agg (str, optional): Aggregation function. Defaults to 'mean'.
+        ntop (int, optional): Number of entities to plot. Defaults to 10.
+        alias (bool, optional): Whether to use aliases for indices. Defaults to False.
+    """
+    df = df[[group_on, agg_col]].explode(column=[group_on, agg_col])
+
+    rankings = df.groupby(group_on)[agg_col].agg(agg).sort_values(ascending=False).head(ntop)
+    if alias:
+        index = rankings.index.map({name: str(i) for i, name in enumerate(rankings.index)})
+    else:
+        index = rankings.index.astype(str)
+
+    return px.bar(
+        x=rankings,
+        y=index,
+        color=rankings,
+        orientation="h",
+        labels={"x": f"{agg_col.title()}", "y": group_on, "color": ""},
+        title=f"Leaderboard for {agg_col}, top {ntop} {group_on}",
+        color_continuous_scale="BlueRed",
+        opacity=0.5,
+        hover_data=[rankings.index.astype(str)],
+        **plotly_config,
+    )
+
+
+def plot_dendrite_rates(
+    df: pd.DataFrame, uid_col: str = "answer_uids", reward_col: str = "answer_rewards", ntop: int = 20, uids: List[int] = None
+) -> go.Figure:
+    """Makes a bar chart of the success rate of dendrite calls for a given set of uids.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        uid_col (str, optional): Column containing uids. Defaults to 'answer_uids'.
+        reward_col (str, optional): Column containing rewards. Defaults to 'answer_rewards'.
+        ntop (int, optional): Number of uids to plot. Defaults to 20.
+        uids (List[int], optional): List of uids to plot. Defaults to None.
+
+    """
+
+    df = df[[uid_col, reward_col]].explode(column=[uid_col, reward_col]).rename(columns={uid_col: "UID"})
+    df["success"] = df[reward_col] != 0
+
+    if uids is None:
+        uids = df["UID"].value_counts().head(ntop).index
+    df = df.loc[df["UID"].isin(uids)]
+
+    # get total and successful dendrite calls
+    rates = df.groupby("UID").success.agg(["sum", "count"]).rename(columns={"sum": "Success", "count": "Total"})
+    rates = rates.melt(ignore_index=False).reset_index()
+    return px.bar(
+        rates.astype({"UID": str}),
+        x="value",
+        y="UID",
+        color="variable",
+        labels={"value": "Number of Calls", "variable": ""},
+        barmode="group",
+        title="Dendrite Calls by UID",
+        color_continuous_scale="Blues",
+        opacity=0.5,
+        **plotly_config,
+    )
+
+
+def plot_network_embedding(
+    df: pd.DataFrame,
+    uid_col: str = "followup_uids",
+    completion_col: str = "followup_completions",
+    ntop: int = 1,
+    uids: List[int] = None,
+) -> go.Figure:
+    """Plots a network embedding of the most common completions for a given set of uids.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+
+        uid_col (str, optional): Column containing uids. Defaults to 'followup_uids'.
+        completion_col (str, optional): Column containing completions. Defaults to 'followup_completions'.
+        ntop (int, optional): Number of top completions per uid to compare. Defaults to 1.
+        uids (List[int], optional): List of uids to plot. Defaults to None.
+
+    # TODO: use value counts to use weighted similarity instead of a simple set intersection
+    """
+    top_completions = {}
+    df = df[[uid_col, completion_col]].explode(column=[uid_col, completion_col])
+
+    if uids is None:
+        uids = df[uid_col].unique()
+    # loop over UIDs and compute ntop most common completions
+    for uid in tqdm.tqdm(uids, unit="UID"):
+        c = df.loc[df[uid_col] == uid, completion_col].value_counts()
+        top_completions[uid] = set(c.index[:ntop])
+
+    a = np.zeros((len(uids), len(uids)))
+    # now compute similarity matrix as a set intersection
+    for i, uid in enumerate(uids):
+        for j, uid2 in enumerate(uids[i + 1 :], start=i + 1):
+            a[i, j] = a[j, i] = len(top_completions[uid].intersection(top_completions[uid2])) / ntop
+
+    # make a graph from the similarity matrix
+    g = nx.from_numpy_array(a)
+    z = pd.DataFrame(nx.spring_layout(g)).T.rename(columns={0: "x", 1: "y"})
+    z["UID"] = uids
+    z["top_completions"] = pd.Series(top_completions).apply(list).values  # positional assignment; the Series index is uids, not node ids
+
+    # assign groups based on cliques (fully connected subgraphs)
+    cliques = {
+        uids[cc]: f"Group-{i}" if len(c) > 1 else "Other" for i, c in enumerate(nx.find_cliques(g), start=1) for cc in c
+    }
+    z["Group"] = z["UID"].map(cliques)
+
+    return px.scatter(
+        z.reset_index(),
+        x="x",
+        y="y",
+        color="Group",
+        title=f"Graph for Top {ntop} Completion Similarities",
+        color_continuous_scale="BlueRed",
+        hover_data=["UID", "top_completions"],
+        opacity=0.5,
+        **plotly_config,
+    )
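All of these helpers take the event-log dataframe and return a `plotly.graph_objects.Figure`. A self-contained sketch of calling two of them on fabricated data; only the `_timestamp` column name and the timestamp-indexed, one-column-per-uid layout are carried over from the code above:

import numpy as np
import pandas as pd
import plotting

# 500 fake events over roughly a day, timestamped like the wandb history
events = pd.DataFrame({"_timestamp": pd.date_range("2023-06-01", periods=500, freq="3min")})
plotting.plot_throughput(events, n_minutes=10).show()

# fake scores: timestamp index, one column per uid
scores = pd.DataFrame(
    np.random.rand(100, 5),
    index=pd.date_range("2023-06-01", periods=100, freq="H"),
    columns=[str(uid) for uid in range(5)],
)
plotting.plot_weights(scores, ntop=3).show()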
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+wandb==0.15.3
+datasets==2.12.0
+plotly==5.14.1
+networkx==3.1
+scipy==1.10.1
+pre-commit==3.3.2
+click==8.1.3
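Note that streamlit and pandas, which every module above imports, are not pinned here; presumably the Space's Streamlit SDK image provides them. A local checkout would likely need `pip install -r requirements.txt streamlit` before `streamlit run dashboard.py` works.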
utils.py
ADDED
@@ -0,0 +1,127 @@
+# The MIT License (MIT)
+# Copyright © 2021 Yuma Rao
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import os
+import tqdm
+import wandb
+import pandas as pd
+from pandas.api.types import is_list_like
+
+from typing import List, Dict, Any, Union
+
+
+def get_runs(project: str = "openvalidators", filters: Dict[str, Any] = None, return_paths: bool = False) -> List:
+    """Download runs from wandb.
+
+    Args:
+        project (str): Name of the project. Defaults to 'openvalidators' (community project).
+        filters (Dict[str, Any], optional): Optional run filters for wandb api. Defaults to None.
+        return_paths (bool, optional): Return only run paths. Defaults to False.
+
+    Returns:
+        List[wandb.apis.public.Run]: List of runs or run paths (List[str]).
+    """
+    api = wandb.Api()
+    wandb.login()
+
+    runs = api.runs(project, filters=filters)
+    if return_paths:
+        return [os.path.join(run.entity, run.project, run.id) for run in runs]
+    else:
+        return runs
+
+
+def download_data(run_path: Union[str, List] = None, timeout: float = 600) -> pd.DataFrame:
+    """Download data from wandb.
+
+    Args:
+        run_path (Union[str, List], optional): Path to run or list of paths. Defaults to None.
+        timeout (float, optional): Timeout for wandb api. Defaults to 600.
+
+    Returns:
+        pd.DataFrame: Dataframe of event log.
+    """
+    api = wandb.Api(timeout=timeout)
+    wandb.login()
+
+    if isinstance(run_path, str):
+        run_path = [run_path]
+
+    frames = []
+    total_events = 0
+    pbar = tqdm.tqdm(sorted(run_path), desc="Loading history from wandb", total=len(run_path), unit="run")
+    for path in pbar:
+        run = api.run(path)
+
+        frame = pd.DataFrame(list(run.scan_history()))
+        frames.append(frame)
+        total_events += len(frame)
+
+        pbar.set_postfix({"total_events": total_events})
+
+    df = pd.concat(frames)
+    # Convert timestamp to datetime.
+    df["_timestamp"] = pd.to_datetime(df["_timestamp"], unit="s")
+    df.sort_values("_timestamp", inplace=True)
+
+    return df
+
+
+def load_data(path: str, nrows: int = None):
+    """Load data from csv."""
+    df = pd.read_csv(path, nrows=nrows)
+    # filter out events with missing step length
+    df = df.loc[df.step_length.notna()]
+
+    # detect list columns which are stored as strings
+    list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
+    # convert string representation of list to list
+    df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
+
+    return df
+
+
+def explode_data(df: pd.DataFrame, list_cols: List[str] = None, list_len: int = None) -> pd.DataFrame:
+    """Explode list columns in dataframe so that each element in the list is a separate row.
+
+    Args:
+        df (pd.DataFrame): Dataframe of event log.
+        list_cols (List[str], optional): List of columns to explode. Defaults to None.
+        list_len (int, optional): Length of list. Defaults to None.
+
+    Returns:
+        pd.DataFrame: Dataframe with exploded list columns.
+    """
+    if list_cols is None:
+        list_cols = [c for c in df.columns if df[c].apply(is_list_like).all()]
+        print(f"Exploding {len(list_cols)} list columns: {list_cols}")
+    if list_len:
+        list_cols = [c for c in list_cols if df[c].apply(len).unique()[0] == list_len]
+        print(f"Exploding {len(list_cols)} list columns with {list_len} elements: {list_cols}")
+
+    return df.explode(column=list_cols)
+
+
+def get_list_col_lengths(df: pd.DataFrame) -> Dict[str, int]:
+    """Helper function to get the length of list columns."""
+    list_col_lengths = {c: sorted(df[c].apply(len).unique()) for c in df.columns if df[c].apply(is_list_like).all()}
+    varying_lengths = {c: v for c, v in list_col_lengths.items() if len(v) > 1}
+
+    if len(varying_lengths) > 0:
+        print(f"The following columns have varying lengths: {varying_lengths}")
+
+    return {c: v[0] for c, v in list_col_lengths.items()}
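For context, a small sketch of how the two explode helpers fit together, using fabricated list-valued columns named after the ones dashboard.py works with:

import pandas as pd
from utils import explode_data, get_list_col_lengths

df = pd.DataFrame({
    "followup_uids": [[1, 2, 3], [4, 5, 6]],
    "followup_rewards": [[0.1, 0.9, 0.0], [0.5, 0.2, 0.7]],
})
lengths = get_list_col_lengths(df)  # {'followup_uids': 3, 'followup_rewards': 3}
df_long = explode_data(df, list(lengths))
print(df_long.shape)  # (6, 2): one row per (uid, reward) pair

Passing `list(lengths)` hands explode_data just the column names, which mirrors how get_exploded_data in dashboard.py chains the pair.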