steffenc commited on
Commit
39972c9
·
1 Parent(s): 4dd30e1

Add gantt chart and more loaders for wandb

Browse files
opendashboards/utils/plotting.py CHANGED
@@ -28,6 +28,22 @@ from typing import List, Union
28
 
29
  plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
33
  """Plot throughput of event log.
 
28
 
29
  plotly_config = {"width": 800, "height": 600, "template": "plotly_white"}
30
 
31
+ def plot_gantt(df_runs: pd.DataFrame, y='username'):
32
+ fig = px.timeline(df_runs,
33
+ x_start="start_time", x_end="end_time", y=y, color="state",
34
+ title="Timeline of WandB Runs",
35
+ category_orders={'run_name': df_runs.run_name.unique()},
36
+ hover_name="run_name",
37
+ hover_data=[col for col in ['hotkey','user','username','run_id','num_steps','num_completions'] if col in df_runs],
38
+ color_discrete_map={'running': 'green', 'finished': 'grey', 'killed':'blue', 'crashed':'orange', 'failed': 'red'},
39
+ opacity=0.3,
40
+ width=1200,
41
+ height=800,
42
+ template="plotly_white",
43
+ )
44
+ # remove y axis ticks
45
+ fig.update_yaxes(tickfont_size=8, title='')
46
+ return fig
47
 
48
  def plot_throughput(df: pd.DataFrame, n_minutes: int = 10) -> go.Figure:
49
  """Plot throughput of event log.
opendashboards/utils/utils.py CHANGED
@@ -16,14 +16,72 @@
16
  # DEALINGS IN THE SOFTWARE.
17
 
18
  import os
 
19
  import tqdm
20
  import wandb
21
  import pandas as pd
 
 
22
  from pandas.api.types import is_list_like
23
 
24
  from typing import List, Dict, Any, Union
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def get_runs(project: str = "openvalidators", filters: Dict[str, Any] = None, return_paths: bool = False, api_key: str = None) -> List:
28
  """Download runs from wandb.
29
 
@@ -82,7 +140,7 @@ def download_data(run_path: Union[str, List] = None, timeout: float = 600, api_k
82
  return df
83
 
84
 
85
- def load_data(path: str, nrows: int = None):
86
  """Load data from csv."""
87
  df = pd.read_csv(path, nrows=nrows)
88
  # filter out events with missing step length
@@ -92,9 +150,56 @@ def load_data(path: str, nrows: int = None):
92
  list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
93
  # convert string representation of list to list
94
  df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
95
-
96
  return df
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def explode_data(df: pd.DataFrame, list_cols: List[str] = None, list_len: int = None) -> pd.DataFrame:
100
  """Explode list columns in dataframe so that each element in the list is a separate row.
 
16
  # DEALINGS IN THE SOFTWARE.
17
 
18
  import os
19
+ import re
20
  import tqdm
21
  import wandb
22
  import pandas as pd
23
+
24
+ from traceback import format_exc
25
  from pandas.api.types import is_list_like
26
 
27
  from typing import List, Dict, Any, Union
28
 
29
 
30
+ def pull_wandb_runs(project='openvalidators', filters=None, min_steps=50, ntop=10, summary_filters=None ):
31
+ all_runs = get_runs(project, filters)
32
+ print(f'Using {ntop}/{len(all_runs)} runs with more than {min_steps} events')
33
+ pbar = tqdm.tqdm(all_runs)
34
+ runs = []
35
+ n_events = 0
36
+ successful = 0
37
+ for i, run in enumerate(pbar):
38
+
39
+ summary = run.summary
40
+ if summary_filters is not None and not summary_filters(summary):
41
+ continue
42
+ step = summary.get('_step',0)
43
+ if step < min_steps:
44
+ # warnings.warn(f'Skipped run `{run.name}` because it contains {step} events (<{min_steps})')
45
+ continue
46
+
47
+ prog_msg = f'Loading data {i/len(all_runs)*100:.0f}% ({successful}/{len(all_runs)} runs, {n_events} events)'
48
+ pbar.set_description(f'{prog_msg}... **fetching** `{run.name}`')
49
+
50
+ duration = summary.get('_runtime')
51
+ end_time = summary.get('_timestamp')
52
+ # extract values for selected tags
53
+ rules = {'hotkey': re.compile('^[0-9a-z]{48}$',re.IGNORECASE), 'version': re.compile('^\\d\.\\d+\.\\d+$'), 'spec_version': re.compile('\\d{4}$')}
54
+ tags = {k: tag for k, rule in rules.items() for tag in run.tags if rule.match(tag)}
55
+ # include bool flag for remaining tags
56
+ tags.update({k: True for k in run.tags if k not in tags.keys() and k not in tags.values()})
57
+
58
+ runs.append({
59
+ 'state': run.state,
60
+ 'num_steps': step,
61
+ 'num_completions': step*sum(len(v) for k, v in run.summary.items() if k.endswith('completions') and isinstance(v, list)),
62
+ 'entity': run.entity,
63
+ 'user': run.user.name,
64
+ 'username': run.user.username,
65
+ 'run_id': run.id,
66
+ 'run_name': run.name,
67
+ 'project': run.project,
68
+ 'run_url': run.url,
69
+ 'run_path': os.path.join(run.entity, run.project, run.id),
70
+ 'start_time': pd.to_datetime(end_time-duration, unit="s"),
71
+ 'end_time': pd.to_datetime(end_time, unit="s"),
72
+ 'duration': pd.to_timedelta(duration, unit="s").round('s'),
73
+ **tags
74
+ })
75
+ n_events += step
76
+ successful += 1
77
+ if successful >= ntop:
78
+ break
79
+
80
+ cat_cols = ['state', 'hotkey', 'version', 'spec_version']
81
+ return pd.DataFrame(runs).astype({k: 'category' for k in cat_cols if k in runs[0]})
82
+
83
+
84
+
85
  def get_runs(project: str = "openvalidators", filters: Dict[str, Any] = None, return_paths: bool = False, api_key: str = None) -> List:
86
  """Download runs from wandb.
87
 
 
140
  return df
141
 
142
 
143
+ def read_data(path: str, nrows: int = None):
144
  """Load data from csv."""
145
  df = pd.read_csv(path, nrows=nrows)
146
  # filter out events with missing step length
 
150
  list_cols = [c for c in df.columns if df[c].dtype == "object" and df[c].str.startswith("[").all()]
151
  # convert string representation of list to list
152
  df[list_cols] = df[list_cols].applymap(eval, na_action='ignore')
153
+
154
  return df
155
 
156
+ def load_data(selected_runs, load=True, save=False, explode=True, datadir='data/'):
157
+
158
+ frames = []
159
+ n_events = 0
160
+ successful = 0
161
+ if not os.path.exists(datadir):
162
+ os.makedirs(datadir)
163
+
164
+ pbar = tqdm.tqdm(selected_runs.index, desc="Loading runs", total=len(selected_runs), unit="run")
165
+ for i, idx in enumerate(pbar):
166
+ run = selected_runs.loc[idx]
167
+ prog_msg = f'Loading data {i/len(selected_runs)*100:.0f}% ({successful}/{len(selected_runs)} runs, {n_events} events)'
168
+
169
+ file_path = os.path.join(datadir,f'history-{run.run_id}.csv')
170
+
171
+ if (load is True and os.path.exists(file_path)) or (callable(load) and load(run.to_dict())):
172
+ pbar.set_description(f'{prog_msg}... **reading** `{file_path}`')
173
+ try:
174
+ df = read_data(file_path)
175
+ except Exception as e:
176
+ print(f'Failed to load history from `{file_path}`: {format_exc(e)}')
177
+ continue
178
+ else:
179
+ pbar.set_description(f'{prog_msg}... **downloading** `{run.run_path}`')
180
+ try:
181
+ # Download the history from wandb and add metadata
182
+ df = download_data(run.run_path).assign(**run.to_dict())
183
+ if explode:
184
+ df = explode_data(df)
185
+
186
+ print(f'Downloaded {df.shape[0]} events from `{run.run_path}`. Columns: {df.columns}')
187
+
188
+ if save is True or (callable(save) and save(run.to_dict())):
189
+ df.to_csv(file_path, index=False)
190
+ print(f'Saved {df.shape[0]} events to `{file_path}`')
191
+
192
+ except Exception as e:
193
+ print(f'Failed to download history for `{run.run_path}`: {e}')
194
+ continue
195
+
196
+ frames.append(df)
197
+ n_events += df.shape[0]
198
+ successful += 1
199
+
200
+ # Remove rows which contain chain weights as it messes up schema
201
+ return pd.concat(frames)
202
+
203
 
204
  def explode_data(df: pd.DataFrame, list_cols: List[str] = None, list_len: int = None) -> pd.DataFrame:
205
  """Explode list columns in dataframe so that each element in the list is a separate row.