# sn1 / multistats.py
# Last commit b45152a by steffenc: "Major changes for efficiency, detail and presentation" (14.3 kB)
import argparse
import ast
import os
import re
from concurrent.futures import ProcessPoolExecutor
from traceback import format_exc

import pandas as pd
import plotly.express as px
import tqdm
import wandb
from IPython.display import display

import opendashboards.utils.aggregate as aggregate
import opendashboards.utils.utils as utils
# Module-level wandb public-API client, shared by pull_wandb_runs() (api.runs) and load_data() (api.run).
# NOTE(review): timeout=60 is presumably seconds per request — confirm against the wandb Api docs.
api= wandb.Api(timeout=60)
# Executed at import time. anonymous="allow" presumably lets wandb fall back to anonymous access when
# no credentials are configured — confirm. NOTE(review): this login happens after `api` is constructed,
# and depending on the multiprocessing start method it may re-run in every ProcessPoolExecutor worker.
wandb.login(anonymous="allow")
# Patterns that map structured wandb run tags to named columns. Compiled once at import,
# and written as raw strings so regex escapes are not also (invalid) Python string escapes.
_TAG_RULES = {
    'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
    'version': re.compile(r'^\d\.\d+\.\d+$'),
    'spec_version': re.compile(r'\d{4}$'),
}

# Columns cast to ``category`` in the returned frame, when they are present.
_CATEGORY_COLUMNS = ('state', 'hotkey', 'version', 'spec_version')


def pull_wandb_runs(project='openvalidators', filters=None, min_steps=50, max_steps=100_000, ntop=10, netuid=None, summary_filters=None):
    """Collect metadata for up to ``ntop`` wandb runs that pass the given filters.

    Args:
        project: wandb project passed to ``api.runs``.
        filters: server-side filter dict forwarded to ``api.runs``.
        min_steps, max_steps: inclusive bounds on the run summary's ``_step``;
            runs outside this range are skipped.
        ntop: stop once this many runs have been accepted.
        netuid: if given, keep only runs whose summary ``netuid`` equals it.
        summary_filters: optional predicate on ``run.summary``; runs for which it
            returns a falsy value are skipped.

    Returns:
        DataFrame with one row per accepted run (state, step/completion counts, user,
        ids, url, ``entity/project/id`` path, start/end time, duration, netuid, plus one
        column per tag). ``state``, ``hotkey``, ``version`` and ``spec_version`` are
        categorical when present. The frame is empty when no run matches.
    """
    # TODO: speed this up by storing older runs
    all_runs = api.runs(project, filters=filters)
    n_total = len(all_runs)
    print(f'Using {ntop}/{n_total} runs with more than {min_steps} events')

    pbar = tqdm.tqdm(all_runs)
    runs = []
    n_events = 0
    successful = 0
    for i, run in enumerate(pbar):
        summary = run.summary
        if summary_filters is not None and not summary_filters(summary):
            continue
        if netuid is not None and summary.get('netuid') != netuid:
            continue
        step = summary.get('_step', 0)
        if step < min_steps or step > max_steps:
            # warnings.warn(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue

        prog_msg = f'Loading data {i/n_total*100:.0f}% ({successful}/{n_total} runs, {n_events} events)'
        pbar.set_description(f'{prog_msg}... **fetching** `{run.name}`')

        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')

        # Extract values for the structured tags; if several tags match one rule, the last one wins.
        tags = {key: tag for key, rule in _TAG_RULES.items() for tag in run.tags if rule.match(tag)}
        # Every remaining tag becomes a boolean flag column.
        tags.update({k: True for k in run.tags if k not in tags and k not in tags.values()})

        runs.append({
            'state': run.state,
            'num_steps': step,
            # Estimate: number of steps times the total length of list-valued *completions summary fields.
            'num_completions': step * sum(len(v) for k, v in summary.items() if k.endswith('completions') and isinstance(v, list)),
            'entity': run.entity,
            'user': run.user.name,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'project': run.project,
            'run_url': run.url,
            # wandb run paths are always '/'-separated; os.path.join would use '\\' on Windows.
            'run_path': '/'.join((run.entity, run.project, run.id)),
            'start_time': pd.to_datetime(end_time - duration, unit="s"),
            'end_time': pd.to_datetime(end_time, unit="s"),
            'duration': pd.to_timedelta(duration, unit="s").round('s'),
            'netuid': run.config.get('netuid'),
            **tags,
        })
        n_events += step
        successful += 1
        if successful >= ntop:
            break

    df = pd.DataFrame(runs)
    # Cast only the columns that exist. Tag columns are missing when no run has such a tag,
    # and an empty result has no columns at all; astype would raise KeyError on either.
    return df.astype({c: 'category' for c in _CATEGORY_COLUMNS if c in df.columns})
def plot_gantt(df_runs):
    """Show a timeline (Gantt) chart of runs, one row per username, coloured by run state.

    Args:
        df_runs: runs table with at least ``start_time``, ``end_time``, ``username``,
            ``state`` and ``run_name`` columns (as produced by ``pull_wandb_runs``).
    """
    # Tag-derived columns such as ``hotkey`` only exist when some run carries that tag.
    # plotly raises ValueError for hover_data names missing from the frame, so keep only present ones.
    hover_cols = [
        c for c in ('hotkey', 'user', 'username', 'run_id', 'num_steps', 'num_completions')
        if c in df_runs.columns
    ]
    fig = px.timeline(
        df_runs,
        x_start="start_time", x_end="end_time", y="username", color="state",
        title="Timeline of Runs",
        # NOTE(review): run_name is only used as hover_name here, so this ordering likely has no visible effect.
        category_orders={'run_name': df_runs.run_name.unique()},  # ,'username': sorted(df_runs.username.unique())},
        hover_name="run_name",
        hover_data=hover_cols,
        color_discrete_map={'running': 'green', 'finished': 'grey', 'killed': 'blue', 'crashed': 'orange', 'failed': 'red'},
        opacity=0.3,
        width=1200,
        height=800,
        template="plotly_white",
    )
    fig.update_yaxes(tickfont_size=8, title='')
    fig.show()
def clean_data(df):
    """Drop incomplete events and empty columns.

    A row is removed if any column whose name matches ``completions`` or ``rewards``
    is null in it. Afterwards, every column that is entirely null is removed.
    """
    required = df.filter(regex='completions|rewards').columns
    complete_rows = df.dropna(subset=required, how='any')
    return complete_rows.dropna(axis=1, how='all')
def _to_numeric_or_keep(col):
    """Return ``col`` as numeric if pandas can parse it, otherwise return it unchanged.

    Same result as the deprecated ``pd.to_numeric(col, errors='ignore')`` (pandas 2.2+).
    """
    try:
        return pd.to_numeric(col)
    except (ValueError, TypeError):
        return col


def explode_data(df):
    """Explode list-valued columns into rows, then coerce each column to numeric where possible.

    The column names are the keys of ``utils.get_list_col_lengths(df)``, and they are
    passed to ``utils.explode_data``; both are project helpers (presumably the keys are
    the list-valued columns — confirm in opendashboards.utils). Columns that cannot be
    parsed as numbers (e.g. completion text) are left as they are.
    """
    list_cols = utils.get_list_col_lengths(df)
    exploded = utils.explode_data(df, list(list_cols.keys()))
    return exploded.apply(_to_numeric_or_keep)
def _parse_list_string_columns(df, source):
    """Return ``df`` with string-encoded list columns parsed back into Python lists.

    An ``object`` column counts as list-encoded when every non-null value is a string
    starting with ``'['``. Values are parsed with ``ast.literal_eval``, which accepts only
    Python literals, so the contents of cached files are never executed as code. A column
    that fails to parse is left unchanged and the error is printed.

    Args:
        df: DataFrame read from a cached history file.
        source: where ``df`` came from; used in error messages.
    """
    converted = {}
    for col in df.columns:
        series = df[col]
        if series.dtype != 'object':
            continue
        if not all(isinstance(v, str) and v.startswith('[') for v in series.dropna()):
            continue
        try:
            converted[col] = series.map(ast.literal_eval, na_action='ignore')
        except (ValueError, TypeError, SyntaxError, RecursionError) as e:
            print(f'Error loading {source!r} when converting column {col!r} to list: {e}')
    # assign() returns a new frame, so the caller's frame is not modified.
    return df.assign(**converted) if converted else df


def load_data(run_id, run_path=None, load=True, save=False, explode=True):
    """Return the event history of one wandb run, sorted by ``_timestamp``.

    If ``load`` is true and ``data/runs/history-{run_id}.parquet`` exists, that cached file
    is read. Otherwise the history is downloaded with ``api.run(run_path)``, cleaned
    (``clean_data``), exploded (``explode_data``) and, if ``save`` is true, written to
    that file.

    Args:
        run_id: wandb run id; used to name the cache file.
        run_path: ``entity/project/run_id`` path; required when downloading.
        load: read the cached parquet file when it exists.
        save: write the downloaded, processed history to the cache file.
        explode: currently unused; kept for backward compatibility.

    Returns:
        DataFrame with ``_timestamp`` converted from epoch seconds to datetime.

    Raises:
        ValueError: if a download is needed but ``run_path`` is None.
    """
    file_path = os.path.join('data/runs/', f'history-{run_id}.parquet')

    if load and os.path.exists(file_path):
        df = pd.read_parquet(file_path)
        # Filter out events with missing step length (only possible when the column exists).
        if 'step_length' in df.columns:
            df = df.loc[df['step_length'].notna()]
        # List columns may have been stored as their string representation.
        df = _parse_list_string_columns(df, file_path)
    else:
        if run_path is None:
            raise ValueError(f'run_path is required to download run {run_id!r} (no cached file at {file_path!r})')
        # Download the history from wandb
        run = api.run(run_path)
        df = pd.DataFrame(list(run.scan_history()))
        # Remove rows with missing completions or rewards, which will be stuff related to weights
        df = df.dropna(subset=df.filter(regex='completions|rewards').columns, how='any')
        print(f'Downloaded {df.shape[0]} events from {run_path!r} with id {run_id!r}')

        float_cols = df.filter(regex='reward').columns
        # Clean and explode; rebinding ``df`` lets the raw frame be freed.
        df = explode_data(clean_data(df))
        # clean_data drops all-null columns, so only cast the reward columns that survived
        # (astype raises KeyError for a missing column).
        float_cols = [c for c in float_cols if c in df.columns]
        df = df.astype({c: float for c in float_cols}).fillna({c: 0 for c in float_cols})

        if save:
            df.to_parquet(file_path, index=False)

    # Convert timestamp (epoch seconds) to datetime.
    df['_timestamp'] = pd.to_datetime(df['_timestamp'], unit="s")
    return df.sort_values("_timestamp")
def calculate_stats(df_long, freq='H', save_path=None, ntop=3):
    """Aggregate event-level completions/rewards into per-(time bucket, run) statistics.

    Args:
        df_long: event-level DataFrame with ``_timestamp``, ``run_id`` and either a
            ``completions`` column or ``followup_*`` / ``answer_*`` completion and reward
            columns. Note: its ``_timestamp`` column is converted in place, and helper
            columns (``completion_num_tokens``, ``tokens_per_sec``) may be added to it.
        freq: pandas offset alias used to bucket ``_timestamp`` (default ``'H'``).
        save_path: if given, the stats are also written there as CSV.
        ntop: forwarded to ``aggregate.top_stats``.

    Returns:
        DataFrame with one row per (``_timestamp`` bucket, ``run_id``); multi-level
        aggregation column names are flattened by joining them with ``'_'``.
    """
    df_long['_timestamp'] = pd.to_datetime(df_long['_timestamp'])

    # If the frame has followup_/answer_ column pairs instead of a single completions column,
    # stack them into one completions/rewards pair (one row per event and schema).
    if 'completions' not in df_long.columns:
        # Non-inplace set_index, so the caller's frame keeps its _timestamp/run_id columns.
        indexed = df_long.set_index(['_timestamp', 'run_id'])
        df_schema = pd.concat([
            indexed[['followup_completions', 'followup_rewards']].rename(columns={'followup_completions': 'completions', 'followup_rewards': 'rewards'}),
            indexed[['answer_completions', 'answer_rewards']].rename(columns={'answer_completions': 'completions', 'answer_rewards': 'rewards'}),
        ])
        df_long = df_schema.reset_index()

    print(f'Calculating stats for dataframe with shape {df_long.shape}')

    # Approximate number of tokens in each completion (heuristic: ~0.75 words per token).
    df_long['completion_num_tokens'] = (df_long['completions'].str.split().str.len() / 0.75).round()

    # TODO: use named aggregations
    reward_aggs = ['sum', 'mean', 'std', 'median', 'max', aggregate.nonzero_rate, aggregate.nonzero_mean, aggregate.nonzero_std, aggregate.nonzero_median]
    aggs = {
        'completions': ['nunique', 'count', aggregate.diversity, aggregate.successful_diversity, aggregate.success_rate],
        'completion_num_tokens': ['mean', 'std', 'median', 'max'],
        **{k: reward_aggs for k in df_long.filter(regex='reward')},
    }

    # Calculate tokens per second
    if 'completion_times' in df_long.columns:
        df_long['tokens_per_sec'] = df_long['completion_num_tokens'] / df_long['completion_times']
        aggs.update({
            'completion_times': ['mean', 'std', 'median', 'min', 'max'],
            'tokens_per_sec': ['mean', 'std', 'median', 'max'],
        })

    # Group only after all helper columns exist. Grouper's ``axis`` argument is deprecated
    # (pandas 2.1); grouping rows is the default.
    g = df_long.groupby([pd.Grouper(key='_timestamp', freq=freq), 'run_id'])

    stats = g.agg(aggs)
    stats = stats.merge(g.apply(aggregate.top_stats, exclude='', ntop=ntop).reset_index(level=1, drop=True), left_index=True, right_index=True)
    # flatten multiindex columns
    stats.columns = ['_'.join(c) for c in stats.columns]
    stats = stats.reset_index()

    if save_path:
        stats.to_csv(save_path, index=False)

    return stats
def process(run, load=True, save=False, load_stats=True, freq='H', ntop=3):
    """Compute (or load cached) aggregated stats for a single run.

    Args:
        run: one row of the runs table (a pandas Series from ``iterrows``). It must provide
            ``run_id``, and ``run_path`` when the data has to be loaded. Every field is also
            attached as a constant column to the event data.
        load, save: forwarded to ``load_data`` (use / write the cached history parquet).
        load_stats: if ``data/aggs/stats-{run_id}.csv`` exists, return it directly.
        freq, ntop: forwarded to ``calculate_stats``.

    Returns:
        The stats DataFrame, or None if anything failed (the traceback is printed).
        Freshly computed stats are always written to the stats CSV.
    """
    try:
        stats_path = f'data/aggs/stats-{run["run_id"]}.csv'
        if load_stats and os.path.exists(stats_path):
            print(f'Loaded stats file {stats_path!r}')
            # Parse the bucket timestamps so cached stats have the same datetime dtype as
            # freshly computed ones (otherwise mixed str/Timestamp values cannot be sorted together).
            return pd.read_csv(stats_path, parse_dates=['_timestamp'])

        # Load data and add extra columns from wandb run
        df_long = load_data(
            run_id=run['run_id'],
            run_path=run['run_path'],
            load=load,
            save=save,
            # save = (run['state'] != 'running') & run['end_time']
        ).assign(**run.to_dict())
        assert isinstance(df_long, pd.DataFrame), f'Expected dataframe, but got {type(df_long)}'

        # Get and save stats
        return calculate_stats(df_long, freq=freq, save_path=stats_path, ntop=ntop)

    except Exception:
        # Best effort per run: report and return None so the other runs keep going.
        # format_exc() takes no exception argument; format_exc(e) passed it as `limit`
        # and made this handler itself raise TypeError.
        print(f'Error processing run {run["run_id"]}: {format_exc()}')
        return None
def line_chart(df, col, title=None):
    """Plot ``col`` against ``_timestamp`` (one faint line per run) and save it as PNG and HTML.

    The figure is written to ``data/figures/{col}.png`` and ``data/figures/{col}.html``.

    Args:
        df: table with ``_timestamp``, ``run_id`` and ``col`` columns.
        col: column to plot.
        title: label for the value axis and chart title; when falsy, it is derived
            from ``col`` (underscores become spaces, then title case).

    Returns:
        ``col``, so callers can tell which chart finished.
    """
    if not title:
        title = col.replace('_', ' ').title()

    axis_labels = {
        '_timestamp': '',
        col: title,
        'uids': 'UID',
        'value': 'counts',
        'variable': 'Completions',
    }
    # Timestamps are passed to plotly as strings.
    plot_frame = df.astype({'_timestamp': str})

    figure = px.line(
        plot_frame,
        x='_timestamp',
        y=col,
        line_group='run_id',
        title=f'{title} over Time',
        labels=axis_labels,
        width=800,
        height=600,
        template='plotly_white',
    )
    figure.update_traces(opacity=0.2)

    figure.write_image(f'data/figures/{col}.png')
    figure.write_html(f'data/figures/{col}.html')
    return col
def parse_arguments():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace; every ``store_true`` flag defaults to False.
    """
    parser = argparse.ArgumentParser(description='Process wandb validator runs for a given netuid.')

    # (option, add_argument keyword arguments, help text), in --help order.
    option_specs = [
        ('--load_runs', {'action': 'store_true'}, 'Load runs from file.'),
        ('--repull_unfinished', {'action': 'store_true'}, 'Re-pull runs that were running when downloaded and saved.'),
        ('--netuid', {'type': int, 'default': None}, 'Network UID to use.'),
        ('--ntop', {'type': int, 'default': 1000}, 'Number of runs to process.'),
        ('--min_steps', {'type': int, 'default': 100}, 'Minimum number of steps to include.'),
        ('--max_workers', {'type': int, 'default': 32}, 'Max workers to use.'),
        ('--no_plot', {'action': 'store_true'}, 'Prevent plotting.'),
        ('--no_save', {'action': 'store_true'}, 'Prevent saving data to file.'),
        ('--no_load', {'action': 'store_true'}, 'Prevent loading downloaded data from file.'),
        ('--no_load_stats', {'action': 'store_true'}, 'Prevent loading stats data from file.'),
        ('--freq', {'type': str, 'default': 'H'}, 'Frequency to aggregate data.'),
        ('--completions_ntop', {'type': int, 'default': 3}, 'Number of top completions to include in stats.'),
    ]
    for option, settings, help_text in option_specs:
        parser.add_argument(option, help=help_text, **settings)

    return parser.parse_args()
if __name__ == '__main__':
    # TODO: flag to overwrite runs that were running when downloaded and saved: check if file date is older than run end time.
    # NOTE(review): --repull_unfinished is parsed but not used yet (see TODO above).
    args = parse_arguments()
    print(args)

    filters = None  # {"tags": {"$in": [f'1.1.{i}' for i in range(10)]}}
    # filters={'tags': {'$in': ['5F4tQyWrhfGVcNhoqeiNsR6KjD4wMZ2kfhLj4oHYuyHbZAc3']}} # Is foundation validator

    # Create every output directory first: data/wandb.csv is written below and would
    # fail on a fresh checkout if data/ did not exist yet.
    for directory in ('data/runs/', 'data/aggs/', 'data/figures/'):
        os.makedirs(directory, exist_ok=True)

    # 1. Select runs, either from the cached runs table or freshly from wandb.
    if args.load_runs and os.path.exists('data/wandb.csv'):
        df_runs = pd.read_csv('data/wandb.csv')
        if len(df_runs) < args.ntop:
            raise ValueError(f'Loaded {len(df_runs)} runs, but expected at least {args.ntop}')
        df_runs = df_runs.iloc[:args.ntop]
    else:
        df_runs = pull_wandb_runs(ntop=args.ntop,
                                  min_steps=args.min_steps,
                                  netuid=args.netuid,
                                  filters=filters,
                                  )  # summary_filters=lambda s: s.get('augment_prompt'))
        if df_runs.empty:
            raise ValueError('No wandb runs matched the selection criteria.')
        df_runs.to_csv('data/wandb.csv', index=False)

    display(df_runs)

    if not args.no_plot:
        plot_gantt(df_runs)

    # 2. Compute per-run stats in parallel; results are collected in submission order.
    results = []
    with ProcessPoolExecutor(max_workers=max(1, min(args.max_workers, df_runs.shape[0]))) as executor:
        futures = [
            executor.submit(
                process,
                run,
                load=not args.no_load,
                save=not args.no_save,
                load_stats=not args.no_load_stats,
                freq=args.freq,
                ntop=args.completions_ntop,
            )
            for _, run in df_runs.iterrows()
        ]

        # Use tqdm to add a progress bar
        with tqdm.tqdm(total=len(futures)) as pbar:
            for future in futures:
                try:
                    result = future.result()
                except Exception:
                    # format_exc() takes no exception argument (format_exc(e) raised TypeError here).
                    print(f'generated an exception: {format_exc()}')
                else:
                    # process() returns None for runs that failed; keep them out of pd.concat.
                    if result is not None:
                        # Cached stats may come back from CSV with string timestamps; normalise each
                        # result on its own so mixed str/Timestamp values never reach the sort below.
                        result['_timestamp'] = pd.to_datetime(result['_timestamp'])
                        results.append(result)
                pbar.update(1)

    if not results:
        raise ValueError('No runs were successfully processed.')

    # 3. Concatenate the results into a single dataframe
    df = pd.concat(results, ignore_index=True).sort_values(['_timestamp', 'run_id'], ignore_index=True)
    df.to_csv('data/processed.csv', index=False)
    print(f'Saved {df.shape[0]} rows to data/processed.csv')
    display(df)

    # 4. One line chart per stats column.
    if not args.no_plot:
        plots = []
        cols = df.set_index(['run_id', '_timestamp']).columns
        with ProcessPoolExecutor(max_workers=max(1, min(args.max_workers, len(cols)))) as executor:
            futures = [executor.submit(line_chart, df, c) for c in cols]

            # Use tqdm to add a progress bar
            with tqdm.tqdm(total=len(futures)) as pbar:
                for future in futures:
                    try:
                        plots.append(future.result())
                    except Exception:
                        print(f'generated an exception: {format_exc()}')
                    pbar.update(1)

        print(f'Saved {len(plots)} plots to data/figures/')