# Wandb overview utilities (originally added by steffenc, commit 7980ef4)
import os
import re
import time
import pandas as pd
import streamlit as st
import opendashboards.utils.utils as utils
from pandas.api.types import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_numeric_dtype,
is_object_dtype,
)
# @st.cache_data
def load_runs(project, filters, min_steps=10, max_recent=100, local_path='wandb_runs.csv', local_stale_time=3600):
    """Fetch run metadata from wandb (or a fresh local CSV cache) as a DataFrame.

    Args:
        project: wandb project path, forwarded to utils.get_runs.
        filters: wandb API filters, forwarded to utils.get_runs.
        min_steps: runs with fewer logged steps than this are skipped.
        max_recent: maximum number of runs to fetch.
        local_path: CSV cache location; falsy disables caching entirely.
        local_stale_time: cache lifetime in seconds.

    Returns:
        pd.DataFrame: one row per run (state, step counts, times, tags, ...).
    """
    # TODO: clean up the caching logic (e.g. take into account the args)
    dtypes = {'state': 'category', 'hotkey': 'category', 'version': 'category', 'spec_version': 'category', 'start_time': 'datetime64[s]', 'end_time': 'datetime64[s]', 'duration': 'timedelta64[s]'}
    # Serve from the local CSV cache while it is still fresh.
    if local_path and os.path.exists(local_path) and (time.time() - float(os.path.getmtime(local_path))) < local_stale_time:
        frame = pd.read_csv(local_path)
        return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})

    # Tag-classification regexes: raw strings, compiled once (hoisted out of the loop).
    rules = {
        'version': re.compile(r'^\d\.\d+\.\d+$'),
        'spec_version': re.compile(r'\d{4}$'),
        'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
    }

    runs = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Fetching runs from wandb')
    msg = st.empty()

    all_runs = utils.get_runs(project, filters)
    for i, run in enumerate(all_runs):
        # Was `i > max_recent`, which processed max_recent + 1 runs (off-by-one).
        if i >= max_recent:
            break
        summary = run.summary
        # '_step' is the last logged step index; +1 converts it to an event count.
        step = summary.get('_step', -1) + 1
        if step < min_steps:
            msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue

        prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
        progress.progress(min(i/len(all_runs), 1), f'{prog_msg}... **fetching** `{run.name}`')

        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')
        # Extract values for selected tags; if several tags match a rule, the last one wins.
        tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
        # Include bool presence flags for the remaining tags.
        tags.update({k: k in run.tags for k in ('mock', 'disable_set_weights')})

        runs.append({
            'state': run.state,
            'num_steps': step,
            'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
            'duration': pd.to_timedelta(duration, unit="s").round('T'),  # round to nearest minute
            'start_time': pd.to_datetime(end_time-duration, unit="s").round('T'),
            'end_time': pd.to_datetime(end_time, unit="s").round('T'),
            'netuid': run.config.get('netuid'),
            **tags,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'url': run.url,
            # 'entity': run.entity,
            # 'project': run.project,
            'run_path': os.path.join(run.entity, run.project, run.id),
        })
        n_events += step
        successful += 1

    progress.empty()
    msg.empty()

    frame = pd.DataFrame(runs)
    # Guard: original unconditionally wrote the CSV and crashed when local_path was falsy.
    if local_path:
        frame.to_csv(local_path, index=False)
    return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})
@st.cache_data
def load_data(selected_runs, load=True, save=False):
    """Load the event history for each selected run, from local CSV or wandb.

    Args:
        selected_runs: DataFrame of run metadata (must provide run_id, run_path, state).
        load: if True, prefer an existing local CSV over re-downloading.
        save: if True, persist freshly downloaded history for non-running runs.

    Returns:
        pd.DataFrame: concatenation of all successfully loaded run histories.
        Calls st.stop() (halting the script) when nothing could be loaded.
    """
    frames = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Loading data')
    info = st.empty()
    # exist_ok avoids the check-then-create race of the original exists()/makedirs() pair.
    os.makedirs('data/', exist_ok=True)
    for i, idx in enumerate(selected_runs.index):
        run = selected_runs.loc[idx]
        prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({successful}/{len(selected_runs)} runs, {n_events} events)'
        file_path = os.path.join('data', f'history-{run.run_id}.csv')

        if load and os.path.exists(file_path):
            progress.progress(i/len(selected_runs), f'{prog_msg}... **reading** `{file_path}`')
            try:
                df = utils.read_data(file_path)
            except Exception as e:
                # Best-effort: surface the failure in the UI and keep going with the rest.
                info.warning(f'Failed to load history from `{file_path}`')
                st.exception(e)
                continue
        else:
            progress.progress(i/len(selected_runs), f'{prog_msg}... **downloading** `{run.run_path}`')
            try:
                # Download the history from wandb and attach the run metadata as columns.
                df = utils.download_data(run.run_path).assign(**run.to_dict())
                print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
                df.info()  # NOTE(review): debug output to stdout — consider removing
                # Only cache finished runs; a 'running' run's history is still growing.
                if save and run.state != 'running':
                    df.to_csv(file_path, index=False)
                    # st.info(f'Saved history to {file_path}')
            except Exception as e:
                info.warning(f'Failed to download history for `{run.run_path}`')
                st.exception(e)
                continue

        frames.append(df)
        n_events += df.shape[0]
        successful += 1

    progress.empty()
    if not frames:
        info.error('No data loaded')
        st.stop()
    # NOTE(review): original comment said rows with chain weights should be removed
    # ("messes up schema"), but no such filtering is implemented here — confirm intent.
    return pd.concat(frames)
def filter_dataframe(df: pd.DataFrame, demo_selection=None) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns
    Args:
        df (pd.DataFrame): Original dataframe
        demo_selection (pd.Index): Index of runs to select (if demo)
    Returns:
        pd.DataFrame: Filtered dataframe
    """
    filter_mode = st.sidebar.radio("Filter mode", ("Use demo", "Add filters"), index=0)
    run_msg = st.info("Select a single wandb run or compare multiple runs")

    # Demo mode: short-circuit with the preselected runs, no filter widgets.
    if filter_mode == "Use demo":
        df = df.loc[demo_selection]
        run_msg.info(f"Selected {len(df)} runs")
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Column isn't datetime-like; leave it untouched.
                pass
        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            # Treat columns with < 10 unique values as categorical
            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                # Guard: original used (_max - _min) / 100 unconditionally, which
                # produced step=0 (slider error) for single-valued columns.
                step = (_max - _min) / 100 if _max > _min else 1.0
                user_num_input = right.slider(
                    f"Values for {column}",
                    min_value=_min,
                    max_value=_max,
                    value=(_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # date_input returns a 1-tuple while the user is mid-selection; only
                # apply the filter once both endpoints are chosen.
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df = df[df[column].astype(str).str.contains(user_text_input)]

    # Load data if new runs selected
    if len(df):
        run_msg.info(f"Selected {len(df)} runs")
    else:
        # open a dialog to select runs
        run_msg.error("Please select at least one run")
        # st.snow()
        # st.stop()

    return df