"""Data-loading and interactive-filtering helpers for a Streamlit wandb dashboard.

Provides:
    * ``load_runs``      -- fetch run-level metadata from wandb (file-cached).
    * ``load_data``      -- download/read the per-run event history.
    * ``filter_dataframe`` -- sidebar UI for filtering the runs table.
"""

import os
import re
import time

import pandas as pd
import streamlit as st

import opendashboards.utils.utils as utils
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)

# Tag-extraction rules, compiled once at import time (raw strings: the old
# '\.' inside a non-raw literal was an invalid escape sequence).
TAG_RULES = {
    'version': re.compile(r'^\d\.\d+\.\d+$'),
    'spec_version': re.compile(r'\d{4}$'),
    'hotkey': re.compile(r'^[0-9a-z]{48}$', re.IGNORECASE),
}

# Tags that are recorded as plain boolean presence flags.
FLAG_TAGS = ('mock', 'disable_set_weights')


# @st.cache_data
def load_runs(project, filters, min_steps=10, max_recent=100, local_path='wandb_runs.csv', local_stale_time=3600):
    """Fetch run metadata from wandb, with a simple file-based cache.

    Args:
        project: wandb project path passed to ``utils.get_runs``.
        filters: wandb API filters passed to ``utils.get_runs``.
        min_steps (int): runs with fewer logged steps than this are skipped.
        max_recent (int): maximum number of runs to load.
        local_path (str): CSV cache file path; falsy disables the cache.
        local_stale_time (int): cache lifetime in seconds.

    Returns:
        pd.DataFrame: one row of summary metadata per accepted run.
    """
    # TODO: clean up the caching logic (e.g. take into account the args)
    dtypes = {
        'state': 'category',
        'hotkey': 'category',
        'version': 'category',
        'spec_version': 'category',
        'start_time': 'datetime64[s]',
        'end_time': 'datetime64[s]',
        'duration': 'timedelta64[s]',
    }

    # Serve from the local CSV cache while it is still fresh.
    if local_path and os.path.exists(local_path) and (time.time() - float(os.path.getmtime(local_path))) < local_stale_time:
        frame = pd.read_csv(local_path)
        return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})

    runs = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Fetching runs from wandb')
    msg = st.empty()

    all_runs = utils.get_runs(project, filters)
    for i, run in enumerate(all_runs):
        # BUGFIX: was `i > max_recent`, which loaded max_recent + 1 runs.
        if i >= max_recent:
            break

        summary = run.summary
        # `_step` is the last logged step index; +1 converts to an event count.
        step = summary.get('_step', -1) + 1
        if step < min_steps:
            msg.warning(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
            continue

        prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
        progress.progress(min(i / len(all_runs), 1), f'{prog_msg}... **fetching** `{run.name}`')

        # NOTE(review): assumes `_runtime`/`_timestamp` are present in the
        # summary of any run that passed the min_steps gate -- if either is
        # missing, the arithmetic below raises TypeError. Confirm upstream.
        duration = summary.get('_runtime')
        end_time = summary.get('_timestamp')

        # Extract values for selected tags (last matching tag wins per key).
        tags = {k: tag for k, rule in TAG_RULES.items() for tag in run.tags if rule.match(tag)}
        # Include bool flag for remaining tags.
        tags.update({k: k in run.tags for k in FLAG_TAGS})

        runs.append({
            'state': run.state,
            'num_steps': step,
            'num_completions': step * sum(
                len(v) for k, v in run.summary.items()
                if k.endswith('completions') and isinstance(v, list)
            ),
            # 'min' is the non-deprecated spelling of the old 'T' alias.
            'duration': pd.to_timedelta(duration, unit="s").round('min'),  # round to nearest minute
            'start_time': pd.to_datetime(end_time - duration, unit="s").round('min'),
            'end_time': pd.to_datetime(end_time, unit="s").round('min'),
            'netuid': run.config.get('netuid'),
            **tags,
            'username': run.user.username,
            'run_id': run.id,
            'run_name': run.name,
            'url': run.url,
            # 'entity': run.entity,
            # 'project': run.project,
            'run_path': os.path.join(run.entity, run.project, run.id),
        })

        n_events += step
        successful += 1

    progress.empty()
    msg.empty()

    frame = pd.DataFrame(runs)
    # BUGFIX: only write the cache when a cache path was actually given.
    if local_path:
        frame.to_csv(local_path, index=False)
    return frame.astype({k: v for k, v in dtypes.items() if k in frame.columns})


@st.cache_data
def load_data(selected_runs, load=True, save=False):
    """Load the event history for each selected run, from disk or wandb.

    Args:
        selected_runs (pd.DataFrame): rows from ``load_runs`` to fetch.
        load (bool): prefer a previously saved local CSV when present.
        save (bool): persist downloaded histories of finished runs to disk.

    Returns:
        pd.DataFrame: concatenated event histories of all runs that loaded.

    Side effects:
        Creates the ``data/`` directory; calls ``st.stop()`` when nothing
        could be loaded.
    """
    frames = []
    n_events = 0
    successful = 0
    progress = st.progress(0, 'Loading data')
    info = st.empty()

    os.makedirs('data/', exist_ok=True)

    for i, idx in enumerate(selected_runs.index):
        run = selected_runs.loc[idx]
        prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({successful}/{len(selected_runs)} runs, {n_events} events)'

        file_path = os.path.join('data', f'history-{run.run_id}.csv')

        if load and os.path.exists(file_path):
            progress.progress(i / len(selected_runs), f'{prog_msg}... **reading** `{file_path}`')
            try:
                df = utils.read_data(file_path)
            except Exception as e:
                info.warning(f'Failed to load history from `{file_path}`')
                st.exception(e)
                continue
        else:
            progress.progress(i / len(selected_runs), f'{prog_msg}... **downloading** `{run.run_path}`')
            try:
                # Download the history from wandb and add metadata
                df = utils.download_data(run.run_path).assign(**run.to_dict())
                print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
                df.info()

                # Only cache finished runs: a running run's history is still growing.
                if save and run.state != 'running':
                    df.to_csv(file_path, index=False)
                    # st.info(f'Saved history to {file_path}')
            except Exception as e:
                info.warning(f'Failed to download history for `{run.run_path}`')
                st.exception(e)
                continue

        frames.append(df)
        n_events += df.shape[0]
        successful += 1

    progress.empty()
    if not frames:
        info.error('No data loaded')
        st.stop()

    # Remove rows which contain chain weights as it messes up schema
    return pd.concat(frames)


def filter_dataframe(df: pd.DataFrame, demo_selection=None) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns

    Args:
        df (pd.DataFrame): Original dataframe
        demo_selection (pd.Index): Index of runs to select (if demo)

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    filter_mode = st.sidebar.radio("Filter mode", ("Use demo", "Add filters"), index=0)

    run_msg = st.info("Select a single wandb run or compare multiple runs")

    if filter_mode == "Use demo":
        df = df.loc[demo_selection]
        run_msg.info(f"Selected {len(df)} runs")
        return df

    df = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                # Column is simply not datetime-like; leave it unchanged.
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            # Treat columns with < 10 unique values as categorical
            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                # BUGFIX: guard against a zero step (Streamlit rejects it)
                # when the column is constant.
                step = (_max - _min) / 100 or 1.0
                user_num_input = right.slider(
                    f"Values for {column}",
                    min_value=_min,
                    max_value=_max,
                    value=(_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                # date_input returns a single date until the user picks a range.
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df = df[df[column].astype(str).str.contains(user_text_input)]

    # Load data if new runs selected
    if len(df):
        run_msg.info(f"Selected {len(df)} runs")
    else:
        # open a dialog to select runs
        run_msg.error("Please select at least one run")
        # st.snow()
        # st.stop()

    return df