steffenc committed
Commit f98fb68 · 1 Parent(s): 6a167cd
Files changed (5)
  1. dashboard.py +317 -0
  2. hello_world.py +11 -0
  3. plotting.py +362 -0
  4. requirements.txt +7 -0
  5. utils.py +127 -0
dashboard.py ADDED
@@ -0,0 +1,317 @@
+ import streamlit as st
+ import pandas as pd
+ import os
+ import re
+ import time
+ from utils import get_runs, download_data, get_list_col_lengths, explode_data
+ import plotting
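+ # Streamlit dashboard for openvalidators wandb runs: loads run metadata, downloads
+ # event histories, and renders UID health, completion and reward plots.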
+
+
+ # dendrite time versus completion length
+ # prompt-based completion score stats
+
+
+ DEFAULT_PROJECT = "openvalidators"
+ DEFAULT_FILTERS = {"tags": {"$in": ["1.0.0", "1.0.1", "1.0.2", "1.0.3", "1.0.4"]}}
+
+ @st.cache_data
+ def load_runs(project=DEFAULT_PROJECT, filters=DEFAULT_FILTERS, min_steps=10):
+     runs = []
+     msg = st.empty()
+     for run in get_runs(project, filters):
+         step = run.summary.get('_step', 0)
+         if step < min_steps:
+             msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
+             continue
+
+         duration = run.summary.get('_runtime')
+         end_time = run.summary.get('_timestamp')
+         # extract values for selected tags
+         rules = {'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE), 'version': re.compile(r'^\d\.\d+\.\d+$'), 'spec_version': re.compile(r'\d{4}$')}
+         # tags = {k: tag for k,tag in zip(('hotkey','version','spec_version'),run.tags)}
+         tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
+         # include bool flag for remaining tags
+         tags.update({k: k in run.tags for k in ('mock', 'custom_gating_model', 'nsfw_filter', 'outsource_scoring', 'disable_set_weights')})
+
+         runs.append({
+             'state': run.state,
+             'num_steps': step,
+             'entity': run.entity,
+             'id': run.id,
+             'name': run.name,
+             'project': run.project,
+             'url': run.url,
+             'path': os.path.join(run.entity, run.project, run.id),
+             'start_time': pd.to_datetime(end_time - duration, unit="s"),
+             'end_time': pd.to_datetime(end_time, unit="s"),
+             'duration': pd.to_timedelta(duration, unit="s"),
+             # 'tags': run.tags,
+             **tags
+         })
+     msg.empty()
+     return pd.DataFrame(runs).astype({'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category'})
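+ # NOTE: st.cache_data memoises load_runs on its arguments, so newly logged wandb
+ # runs only appear after the cache is cleared or the arguments change.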
+
+
+ @st.cache_data
+ def load_data(selected_runs, load=True, save=False):
+
+     frames = []
+     n_events = 0
+     progress = st.progress(0, 'Loading data')
+     for i, idx in enumerate(selected_runs.index):
+         run = selected_runs.loc[idx]
+         prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({i}/{len(selected_runs)} runs, {n_events} events)'
+
+         file_path = f'data/history-{run.id}.csv'
+
+         if load and os.path.exists(file_path):
+             progress.progress(i/len(selected_runs), f'{prog_msg}... reading {file_path}')
+             df = pd.read_csv(file_path)
+             # filter out events with missing step length
+             df = df.loc[df.step_length.notna()]
+
+             # detect list columns which are stored as strings
+             list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
+             # convert string representation of list to list
+             df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
+
+         else:
+             try:
+                 # Download the history from wandb
+                 progress.progress(i/len(selected_runs), f'{prog_msg}... downloading `{run.path}`')
+                 df = download_data(run.path)
+                 df = df.assign(**run.to_dict())
+                 if not os.path.exists('data/'):
+                     os.makedirs('data/')
+
+                 if save and run.state != 'running':
+                     df.to_csv(file_path, index=False)
+                     # st.info(f'Saved history to {file_path}')
+             except Exception as e:
+                 st.error(f'Failed to download history for `{run.path}`')
+                 st.exception(e)
+                 continue
+
+         frames.append(df)
+         n_events += df.shape[0]
+
+     progress.empty()
+     # TODO: remove rows which contain chain weights as they mess up the schema
+     return pd.concat(frames)
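+ # NOTE: the CSV cache path rebuilds list columns with eval(), which executes the
+ # stored strings; this is only safe for locally written, trusted files.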
+
+ @st.cache_data
+ def get_exploded_data(df):
+     list_cols = get_list_col_lengths(df)
+     return explode_data(df, list(list_cols))
+
+ @st.cache_data
+ def get_completions(df_long, col):
+     return df_long[col].value_counts()
+
+ @st.cache_data
+ def plot_uid_diversity(df, remove_unsuccessful=True):
+     return plotting.plot_uid_diversity(df, remove_unsuccessful=remove_unsuccessful)
+
+ @st.cache_data
+ def plot_leaderboard(df, ntop, group_on, agg_col, agg, alias=False):
+     return plotting.plot_leaderboard(df, ntop=ntop, group_on=group_on, agg_col=agg_col, agg=agg, alias=alias)
+
+ @st.cache_data
+ def plot_completion_rewards(df, completion_col, reward_col, uid_col, ntop, completions=None, completion_regex=None):
+     return plotting.plot_completion_rewards(df, msg_col=completion_col, reward_col=reward_col, uid_col=uid_col, ntop=ntop, completions=completions, completion_regex=completion_regex)
+
+ @st.cache_data
+ def uid_metrics(df_long, src, uid=None):
+
+     uid_col = f'{src}_uids'
+     completion_col = f'{src}_completions'
+     nsfw_col = f'{src}_nsfw_scores'
+     reward_col = f'{src}_rewards'
+
+     if uid is not None:
+         df_long = df_long.loc[df_long[uid_col] == uid]
+
+     col1, col2, col3 = st.columns(3)
+     col1.metric(
+         label="Success %",
+         value=f'{df_long.loc[df_long[completion_col].str.len() > 0].shape[0]/df_long.shape[0] * 100:.1f}'
+     )
+     col2.metric(
+         label="Diversity %",
+         value=f'{df_long[completion_col].nunique()/df_long.shape[0] * 100:.1f}'
+     )
+     col3.metric(
+         label="Toxicity %",
+         value=f'{df_long[nsfw_col].mean() * 100:.1f}' if nsfw_col in df_long.columns else 'N/A'
+     )
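+ # Metric definitions: Success = share of events with a non-empty completion,
+ # Diversity = unique completions / events, Toxicity = mean nsfw score (if logged).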
+
+ st.title('Validator :red[Analysis] Dashboard :eyes:')
+ # add vertical space
+ st.markdown('#')
+ st.markdown('#')
+
+
+ with st.sidebar:
+     st.sidebar.header('Pages')
+
+ with st.spinner(text='Checking wandb...'):
+     df_runs = load_runs()
+     # get rows where start time is older than 24h ago
+     df_runs_old = df_runs.loc[df_runs.start_time < pd.to_datetime(time.time()-24*60*60, unit='s')]
+
+ col1, col2, col3 = st.columns(3)
+
+ col1.metric('Runs', df_runs.shape[0], delta=f'{df_runs.shape[0]-df_runs_old.shape[0]} (24h)')
+ col2.metric('Hotkeys', df_runs.hotkey.nunique(), delta=f'{df_runs.hotkey.nunique()-df_runs_old.hotkey.nunique()} (24h)')
+ col3.metric('Events', df_runs.num_steps.sum(), delta=f'{df_runs.num_steps.sum()-df_runs_old.num_steps.sum()} (24h)')
+
+ # https://wandb.ai/opentensor-dev/openvalidators/runs/kt9bzxii/overview?workspace=
+ # all_run_paths = ['opentensor-dev/openvalidators/kt9bzxii'] # pedro long run
+
+ run_ids = df_runs.id
+ default_selected_runs = ['kt9bzxii']
+ selected_runs = default_selected_runs
+
+ # add vertical space
+ st.markdown('#')
+ st.markdown('#')
+
+
+ tab1, tab2, tab3, tab4 = st.tabs(["Wandb Runs", "UID Health", "Completions", "Prompt-based scoring"])
+
+ # src = st.radio('Choose data source:', ['followup', 'answer'], horizontal=True, key='src')
+ # list_list_cols = get_list_col_lengths(df_long)
+ # df_long_long = explode_data(df_long, list(list_list_cols))
+
+ with tab1:
+
+     st.markdown('#')
+     st.subheader(":violet[Wandb] Runs")
+
+     # Load data
+     df = load_data(df_runs.loc[run_ids.isin(selected_runs)], load=True, save=True)
+     df_long = get_exploded_data(df)
+
+     col1, col2, col3, col4 = st.columns(4)
+     col1.metric(label="Selected runs", value=len(selected_runs))
+     col2.metric(label="Events", value=df.shape[0])
+     col3.metric(label="UIDs", value=df_long.followup_uids.nunique())
+     col4.metric(label="Unique completions", value=df_long.followup_completions.nunique())
+
+     selected_runs = st.multiselect(f'Runs ({len(df_runs)})', run_ids, default=selected_runs)
+
+     st.markdown('#')
+     st.subheader("View :violet[Data]")
+
+     show_col1, show_col2 = st.columns(2)
+     show_runs = show_col1.checkbox('Show runs', value=True)
+     show_events = show_col2.checkbox('Show events', value=False)
+     if show_runs:
+         st.markdown(f'Wandb info for **{len(selected_runs)} selected runs**:')
+         st.dataframe(df_runs.loc[run_ids.isin(selected_runs)],
+             column_config={
+                 "url": st.column_config.LinkColumn("URL"),
+             }
+         )
+
+     if show_events:
+         st.markdown(f'Raw events for **{len(selected_runs)} selected runs**:')
+         st.dataframe(df.head(50),
+             column_config={
+                 "url": st.column_config.LinkColumn("URL"),
+             }
+         )
+
+ default_src = 'followup'
+ with tab2:
+
+     st.markdown('#')
+     st.subheader("UID :violet[Health]")
+     uid_src = default_src
+
+     # uid = st.selectbox('UID:', sorted(df_long[uid_col].unique()), key='uid')
+
+     uid_metrics(df_long, uid_src)
+     uid_src = st.radio('Select one:', ['followup', 'answer'], horizontal=True, key='uid_src')
+     uid_col = f'{uid_src}_uids'
+     reward_col = f'{uid_src}_rewards'
+
+     st.markdown('#')
+     st.subheader("UID :violet[Leaderboard]")
+     uid_ntop_default = 10
+
+     uid_col1, uid_col2 = st.columns(2)
+     uid_ntop = uid_col1.slider('Number of UIDs:', min_value=1, max_value=50, value=uid_ntop_default, key='uid_ntop')
+     uid_agg = uid_col2.selectbox('Aggregation:', ('mean','min','max','size','nunique'), key='uid_agg')
+
+     st.plotly_chart(
+         plot_leaderboard(
+             df,
+             ntop=uid_ntop,
+             group_on=uid_col,
+             agg_col=reward_col,
+             agg=uid_agg
+         )
+     )
+     remove_unsuccessful = st.checkbox('Remove failed completions', value=True)
+     st.plotly_chart(
+         plot_uid_diversity(
+             df,
+             remove_unsuccessful=remove_unsuccessful
+         )
+     )
+
+
+ completion_ntop_default = 10
+ with tab3:
+
+     st.markdown('#')
+     st.subheader('Completion :violet[Leaderboard]')
+     completion_src = default_src
+
+     msg_col1, msg_col2 = st.columns(2)
+     completion_src = msg_col1.radio('Select one:', ['followup', 'answer'], horizontal=True, key='completion_src')
+     completion_ntop = msg_col2.slider('Top k:', min_value=1, max_value=50, value=completion_ntop_default, key='completion_ntop')
+
+     completion_col = f'{completion_src}_completions'
+     reward_col = f'{completion_src}_rewards'
+     uid_col = f'{completion_src}_uids'
+
+     completions = get_completions(df_long, completion_col)
+
+     # completion_sel = st.radio('Select input method:', ['ntop', 'select','regex'], horizontal=True, key='completion_sel')
+     # Get completions with highest average rewards
+     st.plotly_chart(
+         plot_leaderboard(
+             df,
+             ntop=completion_ntop,
+             group_on=completion_col,
+             agg_col=reward_col,
+             agg='mean',
+             alias=True
+         )
+     )
+     st.markdown('#')
+     st.subheader('Completion :violet[Rewards]')
+
+     completion_select = st.multiselect('Completions:', completions.index, default=completions.index[:3].tolist())
+     # completion_regex = st.text_input('Completion regex:', value='', key='completion_regex')
+
+     st.plotly_chart(
+         plot_completion_rewards(
+             df,
+             completion_col=completion_col,
+             reward_col=reward_col,
+             uid_col=uid_col,
+             ntop=completion_ntop,
+             completions=completion_select,
+         )
+     )
+
+ with tab4:
+     st.subheader(':pink[Prompt-based scoring]')
+     prompt_src = st.radio('Select one:', ['followup', 'answer'], key='prompt')
+
+
+     # st.dataframe(df_long_long.filter(regex=prompt_src).head())
hello_world.py ADDED
@@ -0,0 +1,11 @@
+ import streamlit as st
+ import pandas as pd
+ import os
+ import re
+ import time
+ from utils import get_runs, download_data, get_list_col_lengths, explode_data
+ import plotting
+
+
+
+ st.write('HELLO WORLD')
plotting.py ADDED
@@ -0,0 +1,362 @@
+ # The MIT License (MIT)
+ # Copyright © 2021 Yuma Rao
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+ # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+ # the Software.
+
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+ # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ # DEALINGS IN THE SOFTWARE.
+
+ import tqdm
+
+ import pandas as pd
+ import numpy as np
+ import networkx as nx
+
+ import plotly.express as px
+ import plotly.graph_objects as go
+
+ from typing import List, Union
+
+ plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}
+
+
+ def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
+     """Plot throughput of event log.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         n_minutes (int, optional): Number of minutes to aggregate. Defaults to 10.
+     """
+
+     rate = df.resample(rule=f"{n_minutes}T", on="_timestamp").size()
+     return px.line(
+         x=rate.index, y=rate, title="Event Log Throughput", labels={"x": "", "y": f"Logs / {n_minutes} min"}, **plotly_config
+     )
+
+
+ def plot_weights(scores: pd.DataFrame, ntop: int = 20, uids: List[Union[str, int]] = None) -> go.Figure:
+     """Plot moving-averaged scores for a subset of uids over time.
+
+     Args:
+         scores (pd.DataFrame): Dataframe of scores. Should be indexed by timestamp and have one column per uid.
+         ntop (int, optional): Number of uids to plot. Defaults to 20.
+         uids (List[Union[str, int]], optional): List of uids to plot, should match column names. Defaults to None.
+     """
+
+     # Select subset of columns for plotting
+     if uids is None:
+         uids = scores.columns[:ntop]
+         print(f"Using first {ntop} uids for plotting: {uids}")
+
+     return px.line(
+         scores, y=uids, title="Moving Averaged Scores", labels={"_timestamp": "", "value": "Score"}, **plotly_config
+     ).update_traces(opacity=0.7)
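+ # Usage sketch (assumes a hypothetical `scores_df` indexed by timestamp with one
+ # column per uid): plot_weights(scores_df, ntop=10).show()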
+
+
+ def plot_uid_diversity(df: pd.DataFrame, remove_unsuccessful: bool = False) -> go.Figure:
+     """Plot uid diversity as measured by the ratio of unique to total completions.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+     """
+     uid_cols = ["followup_uids", "answer_uids"]
+     completion_cols = ["followup_completions", "answer_completions"]
+     reward_cols = ["followup_rewards", "answer_rewards"]
+     list_cols = uid_cols + completion_cols + reward_cols
+
+     df = df[list_cols].explode(column=list_cols)
+     if remove_unsuccessful:
+         # remove unsuccessful completions, as indicated by empty completions
+         for col in completion_cols:
+             df = df[df[col].str.len() > 0]
+
+     frames = []
+     for uid_col, completion_col, reward_col in zip(uid_cols, completion_cols, reward_cols):
+         frame = df.groupby(uid_col).agg({completion_col: ["nunique", "size"], reward_col: "mean"})
+         # flatten multiindex columns
+         frame.columns = ["_".join(col) for col in frame.columns]
+         frame["diversity"] = frame[f"{completion_col}_nunique"] / frame[f"{completion_col}_size"]
+         frames.append(frame)
+
+     merged = pd.merge(*frames, left_index=True, right_index=True, suffixes=("_followup", "_answer"))
+     merged["reward_mean"] = merged.filter(regex="rewards_mean").mean(axis=1)
+
+     merged.index.name = "UID"
+     merged.reset_index(inplace=True)
+
+     return px.scatter(
+         merged,
+         x="diversity_followup",
+         y="diversity_answer",
+         opacity=0.3,
+         size="followup_completions_size",
+         color="reward_mean",
+         hover_data=["UID"] + merged.columns.tolist(),
+         marginal_x="histogram",
+         marginal_y="histogram",
+         color_continuous_scale=px.colors.sequential.Bluered,
+         labels={"diversity_followup": "Followup diversity", "diversity_answer": "Answer diversity"},
+         title="Diversity of completions by UID",
+         **plotly_config,
+     )
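+ # A low unique-to-total ratio flags UIDs that repeat the same completions; pairing
+ # it with mean reward helps spot highly rewarded but repetitive responders.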
+
+
+ def plot_completion_rates(
+     df: pd.DataFrame,
+     msg_col: str = "all_completions",
+     time_interval: str = "H",
+     time_col: str = "_timestamp",
+     ntop: int = 20,
+     completions: List[str] = None,
+     completion_regex: str = None,
+ ) -> go.Figure:
+     """Plot completion rates. Useful for identifying common completions and attacks.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         msg_col (str, optional): List-like column containing completions. Defaults to 'all_completions'.
+         time_interval (str, optional): Pandas time interval. Defaults to 'H'. See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
+         time_col (str, optional): Column containing timestamps as pd.Datetime. Defaults to '_timestamp'.
+         ntop (int, optional): Number of completions to plot. Defaults to 20.
+         completions (List[str], optional): List of completions to plot. Defaults to None.
+         completion_regex (str, optional): Regex to match completions. Defaults to None.
+     """
+
+     df = df[[time_col, msg_col]].explode(column=msg_col)
+
+     if completions is None:
+         completion_counts = df[msg_col].value_counts()
+         if completion_regex is not None:
+             completions = completion_counts[completion_counts.index.str.contains(completion_regex)].index[:ntop]
+             print(f"Using {len(completions)} completions which match {completion_regex!r}: \n{completions}")
+         else:
+             completions = completion_counts.index[:ntop]
+             print(f"Using top {len(completions)} completions: \n{completions}")
+
+     period = df[time_col].dt.to_period(time_interval)
+
+     counts = df.groupby([msg_col, period]).size()
+     top_counts = counts.loc[completions].reset_index().rename(columns={0: "Size"})
+     top_counts["Completion ID"] = top_counts[msg_col].map({k: f"{i}" for i, k in enumerate(completions, start=1)})
+
+     return px.line(
+         top_counts.astype({time_col: str}),
+         x=time_col,
+         y="Size",
+         color="Completion ID",
+         hover_data=[top_counts[msg_col].str.replace("\n", "<br>")],
+         labels={time_col: f"Time, {time_interval}", "Size": f"Occurrences / {time_interval}"},
+         title=f"Completion Rates for {len(completions)} Messages",
+         **plotly_config,
+     ).update_traces(opacity=0.7)
+
+
+ def plot_completion_rewards(
+     df: pd.DataFrame,
+     msg_col: str = "followup_completions",
+     reward_col: str = "followup_rewards",
+     time_col: str = "_timestamp",
+     uid_col: str = "followup_uids",
+     ntop: int = 3,
+     completions: List[str] = None,
+     completion_regex: str = None,
+ ) -> go.Figure:
+     """Plot completion rewards. Useful for tracking common completions and their rewards.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         msg_col (str, optional): List-like column containing completions. Defaults to 'followup_completions'.
+         reward_col (str, optional): List-like column containing rewards. Defaults to 'followup_rewards'.
+         time_col (str, optional): Column containing timestamps as pd.Datetime. Defaults to '_timestamp'.
+         ntop (int, optional): Number of completions to plot. Defaults to 3.
+         completions (List[str], optional): List of completions to plot. Defaults to None.
+         completion_regex (str, optional): Regex to match completions. Defaults to None.
+     """
+
+     df = (
+         df[[time_col, uid_col, msg_col, reward_col]]
+         .explode(column=[msg_col, uid_col, reward_col])
+         .rename(columns={uid_col: "UID"})
+     )
+     completion_counts = df[msg_col].value_counts()
+
+     if completions is None:
+         if completion_regex is not None:
+             completions = completion_counts[completion_counts.index.str.contains(completion_regex)].index[:ntop]
+             print(f"Using {len(completions)} completions which match {completion_regex!r}: \n{completions}")
+         else:
+             completions = completion_counts.index[:ntop]
+             print(f"Using top {len(completions)} completions: \n{completions}")
+
+     # Get ranks of completions in terms of number of occurrences
+     ranks = completion_counts.rank(method="dense", ascending=False).loc[completions].astype(int)
+
+     # Filter to only the selected completions
+     df = df.loc[df[msg_col].isin(completions)]
+     df["rank"] = df[msg_col].map(ranks).astype(str)
+     df["Total"] = df[msg_col].map(completion_counts)
+
+     return px.scatter(
+         df,
+         x=time_col,
+         y=reward_col,
+         color="rank",
+         hover_data=[msg_col, "UID", "Total"],
+         category_orders={"rank": sorted(df["rank"].unique())},
+         marginal_x="histogram",
+         marginal_y="violin",
+         labels={"rank": "Rank", reward_col: "Reward", time_col: ""},
+         title=f"Rewards for {len(completions)} Messages",
+         **plotly_config,
+         opacity=0.3,
+     )
+
+
+ def plot_leaderboard(
+     df: pd.DataFrame,
+     group_on: str = "answer_uids",
+     agg_col: str = "answer_rewards",
+     agg: str = "mean",
+     ntop: int = 10,
+     alias: bool = False,
+ ) -> go.Figure:
+     """Plot leaderboard for a given column. By default plots the top 10 UIDs by mean reward.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         group_on (str, optional): Entities to use for grouping. Defaults to 'answer_uids'.
+         agg_col (str, optional): Column to aggregate. Defaults to 'answer_rewards'.
+         agg (str, optional): Aggregation function. Defaults to 'mean'.
+         ntop (int, optional): Number of entities to plot. Defaults to 10.
+         alias (bool, optional): Whether to use aliases for indices. Defaults to False.
+     """
+     df = df[[group_on, agg_col]].explode(column=[group_on, agg_col])
+
+     rankings = df.groupby(group_on)[agg_col].agg(agg).sort_values(ascending=False).head(ntop)
+     if alias:
+         index = rankings.index.map({name: str(i) for i, name in enumerate(rankings.index)})
+     else:
+         index = rankings.index.astype(str)
+
+     return px.bar(
+         x=rankings,
+         y=index,
+         color=rankings,
+         orientation="h",
+         labels={"x": f"{agg_col.title()}", "y": group_on, "color": ""},
+         title=f"Leaderboard for {agg_col}, top {ntop} {group_on}",
+         color_continuous_scale="BlueRed",
+         opacity=0.5,
+         hover_data=[rankings.index.astype(str)],
+         **plotly_config,
+     )
+
+
+ def plot_dendrite_rates(
+     df: pd.DataFrame, uid_col: str = "answer_uids", reward_col: str = "answer_rewards", ntop: int = 20, uids: List[int] = None
+ ) -> go.Figure:
+     """Makes a bar chart of the success rate of dendrite calls for a given set of uids.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         uid_col (str, optional): Column containing uids. Defaults to 'answer_uids'.
+         reward_col (str, optional): Column containing rewards. Defaults to 'answer_rewards'.
+         ntop (int, optional): Number of uids to plot. Defaults to 20.
+         uids (List[int], optional): List of uids to plot. Defaults to None.
+     """
+
+     df = df[[uid_col, reward_col]].explode(column=[uid_col, reward_col]).rename(columns={uid_col: "UID"})
+     df["success"] = df[reward_col] != 0
+
+     if uids is None:
+         uids = df["UID"].value_counts().head(ntop).index
+     df = df.loc[df["UID"].isin(uids)]
+
+     # get total and successful dendrite calls
+     rates = df.groupby("UID").success.agg(["sum", "count"]).rename(columns={"sum": "Success", "count": "Total"})
+     rates = rates.melt(ignore_index=False).reset_index()
+     return px.bar(
+         rates.astype({"UID": str}),
+         x="value",
+         y="UID",
+         color="variable",
+         labels={"value": "Number of Calls", "variable": ""},
+         barmode="group",
+         title="Dendrite Calls by UID",
+         color_continuous_scale="Blues",
+         opacity=0.5,
+         **plotly_config,
+     )
+
+
+ def plot_network_embedding(
+     df: pd.DataFrame,
+     uid_col: str = "followup_uids",
+     completion_col: str = "followup_completions",
+     ntop: int = 1,
+     uids: List[int] = None,
+ ) -> go.Figure:
+     """Plots a network embedding of the most common completions for a given set of uids.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         uid_col (str, optional): Column containing uids. Defaults to 'followup_uids'.
+         completion_col (str, optional): Column containing completions. Defaults to 'followup_completions'.
+         ntop (int, optional): Number of top completions per uid to compare. Defaults to 1.
+         uids (List[int], optional): List of uids to plot. Defaults to None.
+
+     # TODO: use value counts to use weighted similarity instead of a simple set intersection
+     """
+     top_completions = {}
+     df = df[[uid_col, completion_col]].explode(column=[uid_col, completion_col])
+
+     if uids is None:
+         uids = df[uid_col].unique()
+     # loop over UIDs and compute ntop most common completions
+     for uid in tqdm.tqdm(uids, unit="UID"):
+         c = df.loc[df[uid_col] == uid, completion_col].value_counts()
+         top_completions[uid] = set(c.index[:ntop])
+
+     a = np.zeros((len(uids), len(uids)))
+     # now compute similarity matrix as a set intersection
+     for i, uid in enumerate(uids):
+         for j, uid2 in enumerate(uids[i + 1 :], start=i + 1):
+             a[i, j] = a[j, i] = len(top_completions[uid].intersection(top_completions[uid2])) / ntop
+
+     # make a graph from the similarity matrix
+     g = nx.from_numpy_array(a)
+     z = pd.DataFrame(nx.spring_layout(g)).T.rename(columns={0: "x", 1: "y"})
+     z["UID"] = uids
+     # use .values so the per-uid sets align positionally rather than by index label
+     z["top_completions"] = pd.Series(top_completions).apply(list).values
+
+     # assign groups based on cliques (fully connected subgraphs)
+     cliques = {
+         uids[cc]: f"Group-{i}" if len(c) > 1 else "Other" for i, c in enumerate(nx.find_cliques(g), start=1) for cc in c
+     }
+     z["Group"] = z["UID"].map(cliques)
+
+     return px.scatter(
+         z.reset_index(),
+         x="x",
+         y="y",
+         color="Group",
+         title=f"Graph for Top {ntop} Completion Similarities",
+         color_continuous_scale="BlueRed",
+         hover_data=["UID", "top_completions"],
+         opacity=0.5,
+         **plotly_config,
+     )
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ wandb==0.15.3
+ datasets==2.12.0
+ plotly==5.14.1
+ networkx==3.1
+ scipy==1.10.1
+ pre-commit==3.3.2
+ click==8.1.3
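+ # NOTE: streamlit, pandas and tqdm are not pinned here; they are presumably
+ # provided by the Streamlit Space runtime and as transitive dependencies.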
utils.py ADDED
@@ -0,0 +1,127 @@
+ # The MIT License (MIT)
+ # Copyright © 2021 Yuma Rao
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+ # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+ # the Software.
+
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+ # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ # DEALINGS IN THE SOFTWARE.
+
+ import os
+ import tqdm
+ import wandb
+ import pandas as pd
+ from pandas.api.types import is_list_like
+
+ from typing import List, Dict, Any, Union
+
+
+ def get_runs(project: str = "openvalidators", filters: Dict[str, Any] = None, return_paths: bool = False) -> List:
+     """Download runs from wandb.
+
+     Args:
+         project (str): Name of the project. Defaults to 'openvalidators' (community project).
+         filters (Dict[str, Any], optional): Optional run filters for wandb api. Defaults to None.
+         return_paths (bool, optional): Return only run paths. Defaults to False.
+
+     Returns:
+         List[wandb.apis.public.Run]: List of runs or run paths (List[str]).
+     """
+     api = wandb.Api()
+     wandb.login()
+
+     runs = api.runs(project, filters=filters)
+     if return_paths:
+         return [os.path.join(run.entity, run.project, run.id) for run in runs]
+     else:
+         return runs
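+ # Filters use wandb's MongoDB-style query syntax, e.g. {"tags": {"$in": ["1.0.0"]}}
+ # as in DEFAULT_FILTERS in dashboard.py.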
+
+
+ def download_data(run_path: Union[str, List] = None, timeout: float = 600) -> pd.DataFrame:
+     """Download data from wandb.
+
+     Args:
+         run_path (Union[str, List], optional): Path to run or list of paths. Defaults to None.
+         timeout (float, optional): Timeout for wandb api. Defaults to 600.
+
+     Returns:
+         pd.DataFrame: Dataframe of event log.
+     """
+     api = wandb.Api(timeout=timeout)
+     wandb.login()
+
+     if isinstance(run_path, str):
+         run_path = [run_path]
+
+     frames = []
+     total_events = 0
+     pbar = tqdm.tqdm(sorted(run_path), desc="Loading history from wandb", total=len(run_path), unit="run")
+     for path in pbar:
+         run = api.run(path)
+
+         frame = pd.DataFrame(list(run.scan_history()))
+         frames.append(frame)
+         total_events += len(frame)
+
+         pbar.set_postfix({"total_events": total_events})
+
+     df = pd.concat(frames)
+     # Convert timestamp to datetime.
+     df._timestamp = pd.to_datetime(df._timestamp, unit="s")
+     df.sort_values("_timestamp", inplace=True)
+
+     return df
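+ # NOTE: run.scan_history() streams every logged row, so this can be slow for long
+ # runs; the dashboard caches results to CSV under data/ to avoid repeat downloads.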
+
+
+ def load_data(path: str, nrows: int = None):
+     """Load data from csv."""
+     df = pd.read_csv(path, nrows=nrows)
+     # filter out events with missing step length
+     df = df.loc[df.step_length.notna()]
+
+     # detect list columns which are stored as strings
+     list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
+     # convert string representation of list to list
+     df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
+
+     return df
+
+
+ def explode_data(df: pd.DataFrame, list_cols: List[str] = None, list_len: int = None) -> pd.DataFrame:
+     """Explode list columns in dataframe so that each element in the list is a separate row.
+
+     Args:
+         df (pd.DataFrame): Dataframe of event log.
+         list_cols (List[str], optional): List of columns to explode. Defaults to None.
+         list_len (int, optional): Length of list. Defaults to None.
+
+     Returns:
+         pd.DataFrame: Dataframe with exploded list columns.
+     """
+     if list_cols is None:
+         list_cols = [c for c in df.columns if df[c].apply(is_list_like).all()]
+         print(f"Exploding {len(list_cols)} list columns: {list_cols}")
+     if list_len:
+         list_cols = [c for c in list_cols if df[c].apply(len).unique()[0] == list_len]
+         print(f"Exploding {len(list_cols)} list columns with {list_len} elements: {list_cols}")
+
+     return df.explode(column=list_cols)
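+ # NOTE: when several columns are exploded together, pandas requires the lists in
+ # each row to have matching lengths, otherwise explode() raises a ValueError.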
+
+
+ def get_list_col_lengths(df: pd.DataFrame) -> Dict[str, int]:
+     """Helper function to get the length of list columns."""
+     list_col_lengths = {c: sorted(df[c].apply(len).unique()) for c in df.columns if df[c].apply(is_list_like).all()}
+     varying_lengths = {c: v for c, v in list_col_lengths.items() if len(v) > 1}
+
+     if len(varying_lengths) > 0:
+         print(f"The following columns have varying lengths: {varying_lengths}")
+
+     return {c: v[0] for c, v in list_col_lengths.items()}