steffenc committed on
Commit
101093d
·
1 Parent(s): 44a0b65

Add new files for pulling data and template for metagraph dashboard

Browse files
Files changed (5) hide show
  1. meta_plotting.py +48 -0
  2. meta_utils.py +48 -0
  3. metagraph.py +169 -0
  4. multigraph.py +112 -0
  5. multistats.py +237 -0
meta_plotting.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.express as px
3
+
4
def plot_trace(df, col='emission', agg='mean', ntop=10, hotkeys=None, hotkey_regex=None, abbrev=8, type='Miners'):
    """Line chart of `col` over time for the top `ntop` hotkeys.

    Hotkeys are ranked by `agg` applied to `col`; an explicit hotkey list
    and/or a regex filter on the hotkey string may be applied first.
    Returns a plotly express Figure with one line per hotkey, coloured by
    (abbreviated) coldkey.
    """
    # Narrow the frame before ranking, if any hotkey filters were given.
    if hotkeys is not None:
        df = df.loc[df.hotkey.isin(hotkeys)]
    if hotkey_regex is not None:
        df = df.loc[df.hotkey.str.contains(hotkey_regex)]

    # Best-to-worst ordering of hotkeys by the aggregated metric.
    ranking = df.groupby('hotkey')[col].agg(agg).sort_values(ascending=False)

    # Keep only rows belonging to the ntop best hotkeys, in time order.
    frame = df.loc[df.hotkey.isin(ranking.index[:ntop])].sort_values(by=['timestamp'])

    # Abbreviated key columns give compact legends; 1-based rank feeds hover info.
    frame['hotkey_abbrev'] = frame.hotkey.str[:abbrev]
    frame['coldkey_abbrev'] = frame.coldkey.str[:abbrev]
    frame['rank'] = frame.hotkey.map(dict(zip(ranking.index, range(1, len(ranking.index) + 1))))

    return px.line(frame.sort_values(by=['timestamp', 'rank']),
                   x='timestamp', y=col, color='coldkey_abbrev', line_group='hotkey_abbrev',
                   hover_data=['hotkey', 'rank'],
                   labels={col: col.title(), 'timestamp': '', 'coldkey_abbrev': f'Coldkey (first {abbrev} chars)', 'hotkey_abbrev': f'Hotkey (first {abbrev} chars)'},
                   title=f'Top {ntop} {type}, by {col.title()}',
                   template='plotly_white', width=800, height=600,
                   ).update_traces(opacity=0.7)
26
+
27
+
28
def plot_cabals(df, sel_col='coldkey', count_col='hotkey', values=None, ntop=10, abbr=8):
    """Plot, per timestamp, how many distinct `count_col` values each selected
    `sel_col` value is associated with.

    When `values` is None the `ntop` most frequent `sel_col` values are
    selected automatically (and printed). Returns a plotly express Figure.
    """
    if values is None:
        # Default to the most common sel_col values in the frame.
        values = df[sel_col].value_counts().sort_values(ascending=False).index[:ntop].tolist()
        print(f'Automatically selected {sel_col!r} = {values!r}')

    df = df.loc[df[sel_col].isin(values)]
    # Unique count_col entries per (timestamp, sel_col) pair.
    counts = df.groupby(['timestamp', sel_col])[count_col].nunique().reset_index()
    abbr_col = f'{sel_col} (first {abbr} chars)'
    counts[abbr_col] = counts[sel_col].str[:abbr]

    long_form = counts.melt(id_vars=['timestamp', sel_col, abbr_col])
    return px.line(long_form,
                   x='timestamp', y='value', color=abbr_col,
                   labels={'value': f'Number of Unique {count_col.title()}s per {sel_col.title()}', 'timestamp': ''},
                   category_orders={abbr_col: [v[:abbr] for v in values]},
                   title=f'Impact of Validators Update on Cabal',
                   width=800, height=600, template='plotly_white',
                   )
47
+
48
+
meta_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import pickle
5
+ import subprocess
6
+ import pandas as pd
7
+
8
+
9
def run_subprocess(*args):
    """Run the multigraph.py snapshot script as a child process.

    Extra positional string arguments are forwarded to the script's CLI.
    Returns the `subprocess.CompletedProcess` with captured text output.
    """
    # BUG FIX: shell must be False here. With shell=True and a *list* argv,
    # POSIX passes only the first element ('python') to the shell and the
    # script name plus all flags are silently dropped.
    return subprocess.run(['python', 'multigraph.py', *args],
                          shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          universal_newlines=True)
13
+
14
def load_metagraph(path, extra_cols=None, rm_cols=None):
    """Load one pickled metagraph snapshot and flatten it into a DataFrame.

    Parameters
    ----------
    path : str
        Path to a pickle file produced by multigraph.py.
    extra_cols : list[str], optional
        Extra metagraph attributes to copy into columns (default: none).
    rm_cols : list[str], optional
        Axon columns to drop from the result (default: none).

    Returns
    -------
    pd.DataFrame with one row per axon, annotated with snapshot-level scalars.
    """
    with open(path, 'rb') as f:
        metagraph = pickle.load(f)

    df = pd.DataFrame(metagraph.axons)
    df['block'] = metagraph.block.item()
    # Snapshots saved without --difficulty carry no such attribute.
    df['difficulty'] = getattr(metagraph, 'difficulty', None)

    # BUG FIX: both optional arguments previously crashed on their own
    # defaults (iterating None / df.drop(columns=None)); treat None as "no-op".
    for c in (extra_cols or []):
        df[c] = getattr(metagraph, c)

    return df.drop(columns=rm_cols or [])
27
+
28
def load_metagraphs(block_start, block_end, block_step=1000, datadir='data/metagraph/1/', extra_cols=None):
    """Load and concatenate metagraph snapshots for blocks in [block_start, block_end).

    Only pickle files in `datadir` whose numeric stem lies on the requested
    block grid are read. Returns one concatenated DataFrame (empty if no
    matching snapshot files exist).
    """
    if extra_cols is None:
        extra_cols = ['total_stake','ranks','incentive','emission','consensus','trust','validator_trust','dividends']

    blocks = range(block_start, block_end, block_step)
    filenames = sorted(path for path in os.listdir(datadir) if int(path.split('.')[0]) in blocks)

    metagraphs = []

    pbar = tqdm.tqdm(filenames)
    for filename in pbar:
        # BUG FIX: the description was the dead literal 'Processing (unknown)';
        # show the snapshot actually being parsed.
        pbar.set_description(f'Processing {filename}')

        metagraph = load_metagraph(os.path.join(datadir, filename), extra_cols=extra_cols, rm_cols=['protocol','placeholder1','placeholder2'])

        metagraphs.append(metagraph)

    # BUG FIX: pd.concat raises ValueError on an empty list; return an empty
    # frame instead when no snapshots matched the requested range.
    if not metagraphs:
        return pd.DataFrame()

    return pd.concat(metagraphs)
47
if __name__ == '__main__':
    # BUG FIX: this call previously ran at import time, so merely importing
    # meta_utils (as metagraph.py does) triggered a full 100k-block load.
    # Only run the bulk load when executed as a script.
    load_metagraphs(block_start=700_000, block_end=800_000, block_step=1000)
metagraph.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit entry point for the metagraph analysis dashboard.
#
# NOTE(review): large parts of this page appear to be copied from the
# wandb/validator dashboard and reference names never defined in this file
# (`df_runs`, `DEFAULT_SELECTED_RUNS`, `DEFAULT_COMPLETION_NTOP`, and the
# commented-out `io`/`inspect`/`metric`/`plot` helpers). As written, the
# tabs below will raise NameError at runtime — confirm intent before use.
import streamlit as st
from meta_utils import run_subprocess, load_metagraphs
# from opendashboards.assets import io, inspect, metric, plot
from meta_plotting import plot_trace, plot_cabals

DEFAULT_SRC = 'miner'
DEFAULT_NTOP = 10
DEFAULT_UID_NTOP = 10

# Set app config
st.set_page_config(
    page_title='Validator Dashboard',
    menu_items={
        'Report a bug': "https://github.com/opentensor/dashboards/issues",
        'About': """
        This dashboard is part of the OpenTensor project. \n
        """
    },
    layout = "centered"
)

st.title('Metagraph :red[Analysis] Dashboard :eyes:')
# add vertical space
st.markdown('#')
st.markdown('#')


with st.spinner(text=f'Loading data...'):
    # NOTE(review): load_metagraphs() requires block_start/block_end
    # arguments (see meta_utils.py) — this call will raise TypeError.
    df = load_metagraphs()

blocks = df.block.unique()

# metric.wandb(df_runs)

# add vertical space
st.markdown('#')
st.markdown('#')

tab1, tab2, tab3, tab4 = st.tabs(["Health", "Miners", "Validators", "Block"])


### Wandb Runs ###
with tab1:

    st.markdown('#')
    st.header(":violet[Wandb] Runs")

    run_msg = st.info("Select a single run or compare multiple runs")
    # NOTE(review): `df_runs` and `DEFAULT_SELECTED_RUNS` are not defined
    # anywhere in this file.
    selected_runs = st.multiselect(f'Runs ({len(df_runs)})', df_runs.id, default=DEFAULT_SELECTED_RUNS, key='runs')

    # Load data if new runs selected
    if not selected_runs:
        # open a dialog to select runs
        run_msg.error("Please select at least one run")
        st.snow()
        st.stop()

    # NOTE(review): `io`, `inspect`, `metric` come from the commented-out
    # opendashboards import above — undefined as written.
    df = io.load_data(df_runs.loc[df_runs.id.isin(selected_runs)], load=True, save=True)
    df_long = inspect.explode_data(df)
    df_weights = inspect.weights(df)

    metric.runs(df, df_long, selected_runs)

    with st.expander(f'Show :violet[raw] data for {len(selected_runs)} selected runs'):
        inspect.run_event_data(df_runs,df, selected_runs)


### UID Health ###
with tab2:

    st.markdown('#')
    st.header("UID :violet[Health]")
    st.info(f"Showing UID health metrics for **{len(selected_runs)} selected runs**")

    uid_src = st.radio('Select one:', ['followup', 'answer'], horizontal=True, key='uid_src')

    metric.uids(df_long, uid_src)

    with st.expander(f'Show UID **{uid_src}** weights data for **{len(selected_runs)} selected runs**'):

        uids = st.multiselect('UID:', sorted(df_long[f'{uid_src}_uids'].unique()), key='uid')
        st.markdown('#')
        st.subheader(f"UID {uid_src.title()} :violet[Weights]")

        plot.weights(
            df_weights,
            uids=uids,
        )

    with st.expander(f'Show UID **{uid_src}** leaderboard data for **{len(selected_runs)} selected runs**'):

        st.markdown('#')
        st.subheader(f"UID {uid_src.title()} :violet[Leaderboard]")
        uid_col1, uid_col2 = st.columns(2)
        uid_ntop = uid_col1.slider('Number of UIDs:', min_value=1, max_value=50, value=DEFAULT_UID_NTOP, key='uid_ntop')
        uid_agg = uid_col2.selectbox('Aggregation:', ('mean','min','max','size','nunique'), key='uid_agg')

        plot.leaderboard(
            df,
            ntop=uid_ntop,
            group_on=f'{uid_src}_uids',
            agg_col=f'{uid_src}_rewards',
            agg=uid_agg
        )


    with st.expander(f'Show UID **{uid_src}** diversity data for **{len(selected_runs)} selected runs**'):

        st.markdown('#')
        st.subheader(f"UID {uid_src.title()} :violet[Diversity]")
        rm_failed = st.checkbox(f'Remove failed **{uid_src}** completions', value=True)
        # NOTE(review): `uid_diversty` looks like a typo for `uid_diversity`
        # in the (external) plot helper — confirm against opendashboards.
        plot.uid_diversty(df, rm_failed)


### Completions ###
with tab3:

    st.markdown('#')
    st.subheader('Completion :violet[Leaderboard]')
    completion_info = st.empty()

    msg_col1, msg_col2 = st.columns(2)
    completion_src = msg_col1.radio('Select one:', ['followup', 'answer'], horizontal=True, key='completion_src')
    completion_info.info(f"Showing **{completion_src}** completions for **{len(selected_runs)} selected runs**")

    # NOTE(review): DEFAULT_COMPLETION_NTOP is not defined in this file.
    completion_ntop = msg_col2.slider('Top k:', min_value=1, max_value=50, value=DEFAULT_COMPLETION_NTOP, key='completion_ntop')

    completion_col = f'{completion_src}_completions'
    reward_col = f'{completion_src}_rewards'
    uid_col = f'{completion_src}_uids'

    completions = inspect.completions(df_long, completion_col)

    # Get completions with highest average rewards
    plot.leaderboard(
        df,
        ntop=completion_ntop,
        group_on=completion_col,
        agg_col=reward_col,
        agg='mean',
        alias=True
    )

    with st.expander(f'Show **{completion_src}** completion rewards data for **{len(selected_runs)} selected runs**'):

        st.markdown('#')
        st.subheader('Completion :violet[Rewards]')

        completion_select = st.multiselect('Completions:', completions.index, default=completions.index[:3].tolist())
        # completion_regex = st.text_input('Completion regex:', value='', key='completion_regex')

        plot.completion_rewards(
            df,
            completion_col=completion_col,
            reward_col=reward_col,
            uid_col=uid_col,
            ntop=completion_ntop,
            completions=completion_select,
        )


### Prompt-based scoring ###
with tab4:
    # coming soon
    st.info('Prompt-based scoring coming soon')

    # st.dataframe(df_long_long.filter(regex=prompt_src).head())
multigraph.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ from traceback import print_exc
5
+ import pickle
6
+ import tqdm
7
+ import pandas as pd
8
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
9
+
10
+ import torch
11
+ import bittensor
12
+
13
+ #TODO: make line charts and other cool stuff for each metagraph snapshot
14
+
15
def process(block, netuid=1, lite=True, difficulty=False, prune_weights=False, return_graph=False, half=True, subtensor=None):
    """Snapshot the metagraph at `block` to data/metagraph/<netuid>/<block>.pkl.

    Returns the graph itself when `return_graph` is True, True on success
    otherwise, and None (after printing the error) on any failure.
    """
    if subtensor is None:
        subtensor = bittensor.subtensor(network='finney')

    try:
        graph = subtensor.metagraph(block=block, netuid=netuid, lite=lite)
        if difficulty:
            graph.difficulty = subtensor.difficulty(block=block, netuid=netuid)

        if not lite:
            # Non-lite graphs carry the full weight matrix: optionally halve
            # its precision and/or drop all-zero rows to shrink the pickle.
            if half:
                graph.weights = torch.nn.Parameter(graph.weights.half(), requires_grad=False)
            if prune_weights:
                graph.weights = graph.weights[graph.weights.sum(axis=1) > 0]

        outfile = f'data/metagraph/{netuid}/{block}.pkl'
        with open(outfile, 'wb') as fh:
            pickle.dump(graph, fh)

        return graph if return_graph else True

    except Exception as e:
        print(f'Error processing block {block}: {e}')
38
+
39
+
40
def parse_arguments():
    """Build and parse the command-line options for a snapshot run."""
    p = argparse.ArgumentParser(description='Process metagraphs for a given network.')
    p.add_argument('--netuid', type=int, default=1, help='Network UID to use.')
    p.add_argument('--difficulty', action='store_true', help='Include difficulty in metagraph.')
    p.add_argument('--prune_weights', action='store_true', help='Prune weights in metagraph.')
    p.add_argument('--return_graph', action='store_true', help='Return metagraph instead of True.')
    p.add_argument('--max_workers', type=int, default=32, help='Max workers to use.')
    p.add_argument('--start_block', type=int, default=1_000_000, help='Start block.')
    p.add_argument('--end_block', type=int, default=600_000, help='End block.')
    p.add_argument('--step_size', type=int, default=100, help='Step size.')
    return p.parse_args()
51
+
52
if __name__ == '__main__':

    subtensor = bittensor.subtensor(network='finney')
    print(f'Current block: {subtensor.block}')

    args = parse_arguments()

    netuid=args.netuid
    difficulty=args.difficulty
    overwrite=False
    return_graph=args.return_graph

    step_size = args.step_size
    start_block = args.start_block
    start_block = (min(subtensor.block, start_block)//step_size)*step_size # round to nearest step_size
    end_block = args.end_block
    # Walk backwards from start_block towards end_block.
    blocks = range(start_block, end_block, -step_size)

    # only get weights for multiple of 500 blocks
    lite=lambda x: x%500!=0

    max_workers = min(args.max_workers, len(blocks))

    os.makedirs(f'data/metagraph/{netuid}', exist_ok=True)
    if not overwrite:
        # Skip blocks whose snapshot file already exists on disk.
        blocks = [block for block in blocks if not os.path.exists(f'data/metagraph/{netuid}/{block}.pkl')]

    metagraphs = []

    if len(blocks)==0:
        print(f'No blocks to process. Current block: {subtensor.block}')
        quit()

    print(f'Processing {len(blocks)} blocks from {blocks[0]}-{blocks[-1]} using {max_workers} workers.')

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # BUG FIX: forward return_graph — previously the summary loop below
        # iterated over booleans because workers always returned True.
        futures = [
            executor.submit(process, block, lite=lite(block), netuid=netuid,
                            difficulty=difficulty, return_graph=return_graph)
            for block in blocks
        ]

        success = 0
        with tqdm.tqdm(total=len(futures)) as pbar:
            for block, future in zip(blocks,futures):
                try:
                    result = future.result()
                    # BUG FIX: process() returns None when it caught an
                    # error; don't count those as successes.
                    if result is not None:
                        metagraphs.append(result)
                        success += 1
                except Exception:
                    # BUG FIX: traceback.print_exc() takes no exception
                    # argument; print_exc(e) itself raised inside the handler.
                    print('generated an exception:')
                    print_exc()
                pbar.update(1)
                pbar.set_description(f'Processed {success} blocks. Current block: {block}')

    if not success:
        raise ValueError('No blocks were successfully processed.')

    print(f'Processed {success} blocks.')
    if return_graph:
        for metagraph in metagraphs:
            print(f'{metagraph.block}: {metagraph.n.item()} nodes, difficulty={getattr(metagraph, "difficulty", None)}, weights={metagraph.weights.shape if hasattr(metagraph, "weights") else None}')

    print(metagraphs[-1])
multistats.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import re
4
+ import tqdm
5
+ import wandb
6
+ from traceback import print_exc
7
+ import plotly.express as px
8
+ import pandas as pd
9
+ from concurrent.futures import ProcessPoolExecutor
10
+
11
+ import opendashboards.utils.utils as utils
12
+
13
+ from IPython.display import display
14
+
15
+ api= wandb.Api(timeout=60)
16
+ wandb.login(anonymous="allow")
17
+
18
def pull_wandb_runs(project='openvalidators', filters=None, min_steps=50, max_steps=100_000, ntop=10, summary_filters=None ):
    """Fetch metadata for up to `ntop` wandb runs from `project`.

    Runs are skipped when `summary_filters` rejects their summary, or when
    their step count lies outside [min_steps, max_steps]. Returns a
    DataFrame with one row per accepted run: state, step/completion counts,
    identity fields, timing, and tag-derived columns.

    NOTE(review): the final astype() assumes at least one accepted run
    carried a 'hotkey', 'version' and 'spec_version' tag; if none did,
    those columns are absent and astype raises KeyError — confirm.
    """
    # TODO: speed this up by storing older runs

    all_runs = api.runs(project, filters=filters)
    print(f'Using {ntop}/{len(all_runs)} runs with more than {min_steps} events')
    pbar = tqdm.tqdm(all_runs)
    runs = []          # accumulated per-run metadata dicts
    n_events = 0       # total steps across accepted runs (for progress text)
    successful = 0     # number of accepted runs so far
    for i, run in enumerate(pbar):

        summary = run.summary
        if summary_filters is not None and not summary_filters(summary):
            continue
        step = summary.get('_step',0)
        if step < min_steps or step > max_steps:
            # warnings.warn(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue

        prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
        pbar.set_description(f'{prog_msg}... **fetching** `{run.name}`')

        # Epoch-second duration/end time from the wandb summary.
        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')
        # extract values for selected tags
        rules = {'hotkey': re.compile('^[0-9a-z]{48}$',re.IGNORECASE), 'version': re.compile('^\\d\.\\d+\.\\d+$'), 'spec_version': re.compile('\\d{4}$')}
        tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
        # include bool flag for remaining tags
        tags.update({k: True for k in run.tags if k not in tags.keys() and k not in tags.values()})

        runs.append({
            'state': run.state,
            'num_steps': step,
            # steps * completions-per-step, counting only list-valued
            # '*completions' summary entries
            'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
            'entity': run.entity,
            'user': run.user.name,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'project': run.project,
            'run_url': run.url,
            'run_path': os.path.join(run.entity, run.project, run.id),
            # start time is reconstructed as end minus runtime
            'start_time': pd.to_datetime(end_time-duration, unit="s"),
            'end_time': pd.to_datetime(end_time, unit="s"),
            'duration': pd.to_timedelta(duration, unit="s").round('s'),
            **tags
        })
        n_events += step
        successful += 1
        # Stop as soon as enough runs were accepted.
        if successful >= ntop:
            break

    return pd.DataFrame(runs).astype({'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category'})
71
+
72
def plot_gantt(df_runs):
    """Render a Gantt-style timeline of wandb runs, coloured by run state."""
    state_colours = {'running': 'green', 'finished': 'grey', 'killed':'blue', 'crashed':'orange', 'failed': 'red'}
    fig = px.timeline(
        df_runs,
        x_start="start_time", x_end="end_time", y="username", color="state",
        title="Timeline of Runs",
        category_orders={'run_name': df_runs.run_name.unique()},
        hover_name="run_name",
        hover_data=['hotkey','user','username','run_id','num_steps','num_completions'],
        color_discrete_map=state_colours,
        opacity=0.3,
        width=1200,
        height=800,
        template="plotly_white",
    )
    # Usernames can be numerous: shrink tick labels and drop the axis title.
    fig.update_yaxes(tickfont_size=8, title='')
    fig.show()
87
+
88
def load_data(run_id, run_path=None, load=True, save=False, timeout=30):
    """Load the event history for a wandb run, preferring a local CSV cache.

    Parameters
    ----------
    run_id : str
        Run identifier; the cache file is data/runs/history-<run_id>.csv.
    run_path : str, optional
        Full wandb path (entity/project/id); only needed on a cache miss.
    load : bool
        If True, read from the cache file when it exists.
    save : bool
        If True, write freshly-downloaded history back to the cache.
    timeout : int
        Unused here; kept for interface compatibility.

    Returns
    -------
    pd.DataFrame sorted by `_timestamp` (converted to datetime).
    """
    import ast  # local import: only needed for the cache-parsing branch

    file_path = os.path.join('data/runs/',f'history-{run_id}.csv')

    if load and os.path.exists(file_path):
        df = pd.read_csv(file_path, nrows=None)
        # filter out events with missing step length
        df = df.loc[df.step_length.notna()]

        # detect list columns which as stored as strings
        list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
        # SECURITY FIX: the CSV cells originate from external wandb runs;
        # parse them as literals instead of executing them with eval().
        df[list_cols] = df[list_cols].applymap(ast.literal_eval, na_action='ignore')

    else:
        # Download the history from wandb and add metadata
        run = api.run(run_path)
        df = pd.DataFrame(list(run.scan_history()))

        print(f'Downloaded {df.shape[0]} events from {run_path!r} with id {run_id!r}')

        if save:
            df.to_csv(file_path, index=False)

    # Convert timestamp to datetime.
    df._timestamp = pd.to_datetime(df._timestamp, unit="s")
    return df.sort_values("_timestamp")
115
+
116
+
117
def calculate_stats(df_long, rm_failed=True, rm_zero_reward=True, freq='H', save_path=None ):
    """Aggregate completion/reward statistics per (`freq` time bucket, run_id).

    Optionally drops failed (empty-string) completions and zero-reward rows,
    then computes unique/total completion counts, reward sum/mean/std and a
    diversity ratio per bucket. Writes the result to `save_path` when given.

    NOTE(review): mutates `df_long` in place (timestamp conversion and, when
    reshaping, set_index(inplace=True)) — callers should not reuse the input.
    """
    df_long._timestamp = pd.to_datetime(df_long._timestamp)
    # if dataframe has columns such as followup_completions and answer_completions, convert to multiple rows
    if 'completions' not in df_long.columns:
        # Stack the followup_* and answer_* column pairs into a single
        # completions/rewards schema, keyed by (_timestamp, run_id).
        df_long.set_index(['_timestamp','run_id'], inplace=True)
        df_schema = pd.concat([
            df_long[['followup_completions','followup_rewards']].rename(columns={'followup_completions':'completions', 'followup_rewards':'rewards'}),
            df_long[['answer_completions','answer_rewards']].rename(columns={'answer_completions':'completions', 'answer_rewards':'rewards'})
        ])
        df_long = df_schema.reset_index()

    if rm_failed:
        # Empty completion strings are treated as failed generations.
        df_long = df_long.loc[ df_long.completions.str.len()>0 ]

    if rm_zero_reward:
        df_long = df_long.loc[ df_long.rewards>0 ]

    print(f'Calculating stats for dataframe with shape {df_long.shape}')

    # Bucket rows by wall-clock time (freq) within each run.
    g = df_long.groupby([pd.Grouper(key='_timestamp', axis=0, freq=freq), 'run_id'])

    stats = g.agg({'completions':['nunique','count'], 'rewards':['sum','mean','std']})

    # Flatten the (column, aggregation) MultiIndex into single-level names.
    stats.columns = ['_'.join(c) for c in stats.columns]
    # Fraction of completions in each bucket that are unique.
    stats['completions_diversity'] = stats['completions_nunique'] / stats['completions_count']
    stats = stats.reset_index()

    if save_path:
        stats.to_csv(save_path, index=False)

    return stats
149
+
150
+
151
def clean_data(df):
    """Drop rows missing any completion/reward field, then all-empty columns."""
    required = df.filter(regex='completions|rewards').columns
    trimmed = df.dropna(subset=required, how='any')
    return trimmed.dropna(axis=1, how='all')
153
+
154
def explode_data(df):
    """Explode every list-valued column into long format, coercing numerics."""
    lengths = utils.get_list_col_lengths(df)
    long_df = utils.explode_data(df, list(lengths.keys()))
    # Numeric-looking object columns become proper numeric dtypes.
    return long_df.apply(pd.to_numeric, errors='ignore')
157
+
158
+
159
def process(run, load=True, save=False, freq='H'):
    """Compute (or load cached) time-bucketed stats for a single wandb run.

    `run` is one row of the dataframe produced by pull_wandb_runs. Returns
    the stats dataframe, or None (after printing) if anything fails.
    """
    try:

        stats_path = f'data/aggs/stats-{run["run_id"]}.csv'
        if os.path.exists(stats_path):
            print(f'Loaded stats file {stats_path}')
            return pd.read_csv(stats_path)

        # BUG FIX: the original passed the keyword `save` twice to
        # load_data(), which is a SyntaxError. Intended semantics (from the
        # overriding expression): only cache runs that are finished and have
        # an end time recorded, so a live run never leaves a stale cache.
        save = save and run['state'] != 'running' and bool(run['end_time'])

        # Load data and add extra columns from wandb run
        df = load_data(run_id=run['run_id'],
                       run_path=run['run_path'],
                       load=load,
                       save=save,
                       ).assign(**run.to_dict())
        # Clean and explode dataframe
        df_long = explode_data(clean_data(df))
        # Remove original dataframe from memory
        del df
        # Get and save stats
        return calculate_stats(df_long, freq=freq, save_path=stats_path)

    except Exception as e:
        print(f'Error processing run {run["run_id"]}: {e}')
184
+
185
if __name__ == '__main__':

    # TODO: flag to overwrite runs that were running when downloaded and saved: check if file date is older than run end time.

    filters = None# {"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
    # filters={'tags': {'$in': ['5F4tQyWrhfGVcNhoqeiNsR6KjD4wMZ2kfhLj4oHYuyHbZAc3']}} # Is foundation validator
    df_runs = pull_wandb_runs(ntop=500, filters=filters)#summary_filters=lambda s: s.get('augment_prompt'))

    # Ensure cache directories exist before worker processes write to them.
    os.makedirs('data/runs/', exist_ok=True)
    os.makedirs('data/aggs/', exist_ok=True)
    df_runs.to_csv('data/wandb.csv', index=False)

    display(df_runs)
    plot_gantt(df_runs)

    with ProcessPoolExecutor(max_workers=min(32, df_runs.shape[0])) as executor:
        futures = [executor.submit(process, run, load=True, save=True) for _, run in df_runs.iterrows()]

    # Use tqdm to add a progress bar
    results = []
    with tqdm.tqdm(total=len(futures)) as pbar:
        for future in futures:
            try:
                result = future.result()
                # BUG FIX: process() returns None on failure; collecting the
                # Nones previously risked a pd.concat over no real frames.
                if result is not None:
                    results.append(result)
            except Exception:
                # BUG FIX: traceback.print_exc() takes no exception argument;
                # print_exc(e) itself raised inside the handler.
                print('generated an exception:')
                print_exc()
            pbar.update(1)

    if not results:
        raise ValueError('No runs were successfully processed.')

    # Concatenate the results into a single dataframe
    df = pd.concat(results, ignore_index=True)

    df.to_csv('data/processed.csv', index=False)

    display(df)

    fig = px.line(df.astype({'_timestamp':str}),
                  x='_timestamp',
                  y='completions_diversity',
                  # y=['Unique','Total'],
                  line_group='run_id',
                  # color='hotkey',
                  # color_discrete_sequence=px.colors.sequential.YlGnBu,
                  title='Completion Diversity over Time',
                  labels={'_timestamp':'', 'completions_diversity':'Diversity', 'uids':'UID','value':'counts', 'variable':'Completions'},
                  width=800, height=600,
                  template='plotly_white',
                  ).update_traces(opacity=0.3)
    fig.show()
+