justheuristic committed on
Commit
f3aa1a2
·
unverified ·
2 Parent(s): 08e475f b9cce0e

Merge pull request #2 from training-transformers-together/LS/add-leaderboard

Browse files
app.py CHANGED
@@ -2,23 +2,97 @@ import pandas as pd
2
  import streamlit as st
3
  import wandb
4
 
5
- from dashboard_utils.bubbles import get_new_bubble_data
6
  from dashboard_utils.main_metrics import get_main_metrics
7
  from streamlit_observable import observable
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Only need to set these here as we are adding controls outside of Hydralit, to customise a Hydralit run!
10
- st.set_page_config(page_title="Dashboard", layout="centered")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  wandb.login(anonymous="must")
13
 
14
- st.markdown("<h1 style='text-align: center;'>Dashboard</h1>", unsafe_allow_html=True)
15
- st.caption("Training Loss")
16
 
17
  steps, dates, losses, alive_peers = get_main_metrics()
18
  source = pd.DataFrame({"steps": steps, "loss": losses, "alive participants": alive_peers, "date": dates})
19
 
20
 
21
- st.vega_lite_chart(
22
  source,
23
  {
24
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
@@ -30,8 +104,7 @@ st.vega_lite_chart(
30
  use_container_width=True,
31
  )
32
 
33
- st.caption("Number of alive runs over time")
34
- st.vega_lite_chart(
35
  source,
36
  {
37
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
@@ -45,8 +118,7 @@ st.vega_lite_chart(
45
  },
46
  use_container_width=True,
47
  )
48
- st.caption("Number of steps")
49
- st.vega_lite_chart(
50
  source,
51
  {
52
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
@@ -58,11 +130,26 @@ st.vega_lite_chart(
58
  use_container_width=True,
59
  )
60
 
61
- st.header("Collaborative training participants")
62
  serialized_data, profiles = get_new_bubble_data()
 
63
  observable(
64
- "Participants",
65
  notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
66
  targets=["c_noaws"],
67
- redefine={"serializedData": serialized_data, "profileSimple": profiles},
68
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
  import wandb
4
 
5
+ from dashboard_utils.bubbles import get_global_metrics, get_new_bubble_data, get_leaderboard
6
  from dashboard_utils.main_metrics import get_main_metrics
7
  from streamlit_observable import observable
8
+ import time
9
+ import requests
10
+
11
+ import streamlit as st
12
+ from streamlit_lottie import st_lottie
13
+
14
+
15
def load_lottieurl(url: str, timeout: float = 10.0):
    """Download a Lottie animation and return it as parsed JSON.

    Args:
        url: Address of the Lottie JSON document.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server answers with a status other than 200, or the body is not
        valid JSON.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Keep the existing "None means unavailable" contract instead of
        # letting a network error propagate to the caller.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # Raised when the body cannot be decoded as JSON.
        return None
20
+
21
 
22
  # Only need to set these here as we are adding controls outside of Hydralit, to customise a Hydralit run!
23
+ st.set_page_config(page_title="Dashboard", layout="wide")
24
+
25
+ st.markdown("<h1 style='text-align: center;'>Dashboard</h1>", unsafe_allow_html=True)
26
+
27
+ key_figures_margin_left, key_figures_c1, key_figures_c2, key_figures_c3, key_figures_margin_right = st.columns(
28
+ (2, 1, 1, 1, 2)
29
+ )
30
+ chart_c1, chart_c2 = st.columns((3, 2))
31
+
32
+ lottie_url_loading = "https://assets5.lottiefiles.com/packages/lf20_OdNgAj.json"
33
+ lottie_loading = load_lottieurl(lottie_url_loading)
34
+
35
+
36
+ with key_figures_c1:
37
+ st.caption("\# of contributing users")
38
+ placeholder_key_figures_c1 = st.empty()
39
+ with placeholder_key_figures_c1:
40
+ st_lottie(lottie_loading, height=100, key="loading_key_figure_c1")
41
+
42
+ with key_figures_c2:
43
+ st.caption("\# active users")
44
+ placeholder_key_figures_c2 = st.empty()
45
+ with placeholder_key_figures_c2:
46
+ st_lottie(lottie_loading, height=100, key="loading_key_figure_c2")
47
+
48
+ with key_figures_c3:
49
+ st.caption("Total runtime")
50
+ placeholder_key_figures_c3 = st.empty()
51
+ with placeholder_key_figures_c3:
52
+ st_lottie(lottie_loading, height=100, key="loading_key_figure_c3")
53
+
54
+ with chart_c1:
55
+ st.subheader("Metrics over time")
56
+ st.caption("Training Loss")
57
+ placeholder_chart_c1_1 = st.empty()
58
+ with placeholder_chart_c1_1:
59
+ st_lottie(lottie_loading, height=100, key="loading_c1_1")
60
+
61
+ st.caption("Number of alive runs over time")
62
+ placeholder_chart_c1_2 = st.empty()
63
+ with placeholder_chart_c1_2:
64
+ st_lottie(lottie_loading, height=100, key="loading_c1_2")
65
+
66
+ st.caption("Number of steps")
67
+ placeholder_chart_c1_3 = st.empty()
68
+ with placeholder_chart_c1_3:
69
+ st_lottie(lottie_loading, height=100, key="loading_c1_3")
70
+
71
+ with chart_c2:
72
+ st.subheader("Global metrics")
73
+ st.caption("Collaborative training participants")
74
+ placeholder_chart_c2_1 = st.empty()
75
+ with placeholder_chart_c2_1:
76
+ st_lottie(lottie_loading, height=100, key="loading_c2_1")
77
+
78
+ st.write("Chart showing participants of the collaborative-training. Circle radius is relative to the total number of "
79
+ "processed batches, the circle is greyed if the participant is not active. Every purple square represents an "
80
+ "active device, darker color corresponds to higher performance.")
81
+
82
+ st.caption("Leaderboard")
83
+ placeholder_chart_c2_3 = st.empty()
84
+ with placeholder_chart_c2_3:
85
+ st_lottie(lottie_loading, height=100, key="loading_c2_2")
86
+
87
 
88
  wandb.login(anonymous="must")
89
 
 
 
90
 
91
  steps, dates, losses, alive_peers = get_main_metrics()
92
  source = pd.DataFrame({"steps": steps, "loss": losses, "alive participants": alive_peers, "date": dates})
93
 
94
 
95
+ placeholder_chart_c1_1.vega_lite_chart(
96
  source,
97
  {
98
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
 
104
  use_container_width=True,
105
  )
106
 
107
+ placeholder_chart_c1_2.vega_lite_chart(
 
108
  source,
109
  {
110
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
 
118
  },
119
  use_container_width=True,
120
  )
121
+ placeholder_chart_c1_3.vega_lite_chart(
 
122
  source,
123
  {
124
  "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
 
130
  use_container_width=True,
131
  )
132
 
 
133
  serialized_data, profiles = get_new_bubble_data()
134
+ df_leaderboard = get_leaderboard(serialized_data)
135
  observable(
136
+ "_",
137
  notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
138
  targets=["c_noaws"],
139
+ redefine={"serializedData": serialized_data, "profileSimple": profiles, "width": 0},
140
  )
141
+ placeholder_chart_c2_3.dataframe(df_leaderboard[["User", "Total time contributed"]])
142
+
143
+ global_metrics = get_global_metrics(serialized_data)
144
+
145
+ placeholder_key_figures_c1.write(f"<b>{global_metrics['num_contributing_users']}</b>", unsafe_allow_html=True)
146
+ placeholder_key_figures_c2.write(f"<b>{global_metrics['num_active_users']}</b>", unsafe_allow_html=True)
147
+ placeholder_key_figures_c3.write(f"<b>{global_metrics['total_runtime']}</b>", unsafe_allow_html=True)
148
+
149
+ with placeholder_chart_c2_1:
150
+ observable(
151
+ "Participants",
152
+ notebook="d/9ae236a507f54046", # "@huggingface/participants-bubbles-chart",
153
+ targets=["c_noaws"],
154
+ redefine={"serializedData": serialized_data, "profileSimple": profiles},
155
+ )
dashboard_utils/bubbles.py CHANGED
@@ -2,6 +2,8 @@ import datetime
2
  from concurrent.futures import as_completed
3
  from urllib import parse
4
 
 
 
5
  import streamlit as st
6
  import wandb
7
  from requests_futures.sessions import FuturesSession
@@ -11,9 +13,10 @@ from dashboard_utils.time_tracker import _log, simple_time_tracker
11
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
12
  WANDB_REPO = "learning-at-home/Worker_logs"
13
  CACHE_TTL = 100
 
14
 
15
 
16
- @st.cache(ttl=CACHE_TTL)
17
  @simple_time_tracker(_log)
18
  def get_new_bubble_data():
19
  serialized_data_points, latest_timestamp = get_serialized_data_points()
@@ -28,7 +31,7 @@ def get_new_bubble_data():
28
  return serialized_data, profiles
29
 
30
 
31
- @st.cache(ttl=CACHE_TTL)
32
  @simple_time_tracker(_log)
33
  def get_profiles(usernames):
34
  profiles = []
@@ -60,7 +63,7 @@ def get_profiles(usernames):
60
  return profiles
61
 
62
 
63
- @st.cache(ttl=CACHE_TTL)
64
  @simple_time_tracker(_log)
65
  def get_serialized_data_points():
66
 
@@ -108,7 +111,7 @@ def get_serialized_data_points():
108
  return serialized_data_points, latest_timestamp
109
 
110
 
111
- @st.cache(ttl=CACHE_TTL)
112
  @simple_time_tracker(_log)
113
  def get_serialized_data(serialized_data_points, latest_timestamp):
114
  serialized_data_points_v2 = []
@@ -138,3 +141,46 @@ def get_serialized_data(serialized_data_points, latest_timestamp):
138
  serialized_data_points_v2.append(new_item)
139
  serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity}
140
  return serialized_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from concurrent.futures import as_completed
3
  from urllib import parse
4
 
5
+ import pandas as pd
6
+
7
  import streamlit as st
8
  import wandb
9
  from requests_futures.sessions import FuturesSession
 
13
  URL_QUICKSEARCH = "https://huggingface.co/api/quicksearch?"
14
  WANDB_REPO = "learning-at-home/Worker_logs"
15
  CACHE_TTL = 100
16
+ MAX_DELTA_ACTIVE_RUN_SEC = 60 * 5
17
 
18
 
19
+ @st.cache(ttl=CACHE_TTL, show_spinner=False)
20
  @simple_time_tracker(_log)
21
  def get_new_bubble_data():
22
  serialized_data_points, latest_timestamp = get_serialized_data_points()
 
31
  return serialized_data, profiles
32
 
33
 
34
+ @st.cache(ttl=CACHE_TTL, show_spinner=False)
35
  @simple_time_tracker(_log)
36
  def get_profiles(usernames):
37
  profiles = []
 
63
  return profiles
64
 
65
 
66
+ @st.cache(ttl=CACHE_TTL, show_spinner=False)
67
  @simple_time_tracker(_log)
68
  def get_serialized_data_points():
69
 
 
111
  return serialized_data_points, latest_timestamp
112
 
113
 
114
+ @st.cache(ttl=CACHE_TTL, show_spinner=False)
115
  @simple_time_tracker(_log)
116
  def get_serialized_data(serialized_data_points, latest_timestamp):
117
  serialized_data_points_v2 = []
 
141
  serialized_data_points_v2.append(new_item)
142
  serialized_data = {"points": [serialized_data_points_v2], "maxVelocity": max_velocity}
143
  return serialized_data
144
+
145
+
146
def get_leaderboard(serialized_data):
    """Build the leaderboard table, ranked by contributed runtime.

    Args:
        serialized_data: Mapping whose ``"points"`` entry is a list; its first
            element is a list of per-user dicts, each providing
            ``"profileId"`` and ``"runtime"`` (a number of seconds).

    Returns:
        A ``pandas.DataFrame`` indexed by a 1-based ``"Rank"`` with the
        columns ``"User"`` and ``"Total time contributed"`` (the runtime
        rendered via ``str(datetime.timedelta(...))``), ordered from the
        longest to the shortest runtime.
    """
    # ``sorted`` is stable, also with ``reverse=True``: users with the same
    # runtime keep their input order, so ranks do not flip between refreshes.
    ranked = sorted(
        serialized_data["points"][0],
        key=lambda user_item: user_item["runtime"],
        reverse=True,
    )

    return pd.DataFrame(
        {
            "User": [user_item["profileId"] for user_item in ranked],
            "Total time contributed": [
                str(datetime.timedelta(seconds=user_item["runtime"])) for user_item in ranked
            ],
        },
        index=pd.RangeIndex(start=1, stop=len(ranked) + 1, name="Rank"),
    )
163
+
164
+
165
def _as_utc(moment):
    """Return *moment* as an aware UTC datetime; naive values are taken to be UTC."""
    if moment.utcoffset() is None:
        return moment.replace(tzinfo=datetime.timezone.utc)
    return moment.astimezone(datetime.timezone.utc)


def get_global_metrics(serialized_data, *, max_delta_active_run_sec=None, current_time=None):
    """Compute the headline figures shown at the top of the dashboard.

    Args:
        serialized_data: Mapping whose ``"points"`` entry is a list; its first
            element is a list of per-user dicts, each providing ``"runtime"``
            (seconds) and ``"activeRuns"`` (dicts with an ISO-format
            ``"date"`` string).
        max_delta_active_run_sec: A user counts as active when at least one
            of their runs is more recent than this many seconds. Defaults to
            the module-level ``MAX_DELTA_ACTIVE_RUN_SEC``.
        current_time: Reference "now"; defaults to the current UTC time.
            A naive value is interpreted as UTC.

    Returns:
        A dict with ``"num_contributing_users"`` (int),
        ``"num_active_users"`` (int) and ``"total_runtime"``
        (``datetime.timedelta``).
    """
    if max_delta_active_run_sec is None:
        max_delta_active_run_sec = MAX_DELTA_ACTIVE_RUN_SEC
    if current_time is None:
        current_time = datetime.datetime.now(datetime.timezone.utc)
    else:
        current_time = _as_utc(current_time)

    user_items = serialized_data["points"][0]
    num_active_users = 0
    total_runtime = 0

    for user_item in user_items:
        # Both operands are normalised to aware UTC so the subtraction cannot
        # fail on a naive/aware mix; ``any`` stops at the first recent run.
        has_recent_run = any(
            (current_time - _as_utc(datetime.datetime.fromisoformat(run["date"]))).total_seconds()
            < max_delta_active_run_sec
            for run in user_item["activeRuns"]
        )
        if has_recent_run:
            num_active_users += 1

        total_runtime += user_item["runtime"]

    return {
        "num_contributing_users": len(user_items),
        "num_active_users": num_active_users,
        "total_runtime": datetime.timedelta(seconds=total_runtime),
    }
dashboard_utils/main_metrics.py CHANGED
@@ -9,7 +9,7 @@ WANDB_REPO = "learning-at-home/Main_metrics"
9
  CACHE_TTL = 100
10
 
11
 
12
- @st.cache(ttl=CACHE_TTL)
13
  @simple_time_tracker(_log)
14
  def get_main_metrics():
15
  api = wandb.Api()
 
9
  CACHE_TTL = 100
10
 
11
 
12
+ @st.cache(ttl=CACHE_TTL, show_spinner=False)
13
  @simple_time_tracker(_log)
14
  def get_main_metrics():
15
  api = wandb.Api()
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  streamlit
2
  wandb
3
- requests_futures
 
 
1
  streamlit
2
  wandb
3
+ requests_futures
4
+ streamlit-lottie