Jellyfish042 commited on
Commit
3dc582a
β€’
2 Parent(s): 41539a4 31bf9ae

Merge remote-tracking branch 'origin/main'

Browse files
Files changed (2) hide show
  1. app.py +19 -13
  2. data/2024-07/14b.xlsx +0 -0
app.py CHANGED
@@ -164,7 +164,8 @@ def update_table(period: str,
164
  if len(combined_data) > 0:
165
  sorted_data = combined_data.sort_values(by=sort_by, ascending=ascending)
166
  sorted_data = sorted_data.rename(columns={'Average (The lower the better)': 'Average (lower=better)'})
167
- visible_columns = ['Name', 'Parameters Count (B)', 'Average (lower=better)'] + visible_columns
 
168
  filtered_data = sorted_data[visible_columns]
169
 
170
  filtered_data.columns = [col.replace('_', ' ') for col in filtered_data.columns]
@@ -178,7 +179,7 @@ def update_table(period: str,
178
  vmin = {}
179
  vmax = {}
180
  for column in filtered_data.columns:
181
- if column in ['Name', 'Parameters Count (B)']:
182
  continue
183
  col_values = filtered_data[column]
184
  if len(col_values) > 1:
@@ -191,9 +192,11 @@ def update_table(period: str,
191
  target_color_columns.append('Average (lower=better)')
192
  if 'Individual Tests' in color_columns:
193
  target_color_columns.extend([col for col in filtered_data.columns if
194
- col not in ['Name', 'Parameters Count (B)', 'Average (lower=better)']])
 
 
 
195
 
196
- styler = filtered_data.style.format(formatter)
197
  for column in target_color_columns:
198
  if column in vmin and column in vmax: # Ensure that the vmin and vmax dicts contain the column
199
  styler = styler.background_gradient(cmap=cmap, subset=[column], vmin=vmin[column], vmax=vmax[column])
@@ -271,7 +274,8 @@ def create_scaling_plot(all_data, period):
271
  names_to_connect = ['Meta-Llama-3-8B',
272
  'stablelm-3b-4e1t',
273
  'Qwen2-1.5B',
274
- 'TinyLlama-1.1B-intermediate-step-1431k-3T']
 
275
  connection_points = new_df[new_df['Name'].isin(names_to_connect)]
276
 
277
  new_df['Color'] = new_df['Name'].apply(lambda name: '#39C5BB' if name in names_to_connect else '#636efa')
@@ -284,7 +288,7 @@ def create_scaling_plot(all_data, period):
284
 
285
  x_min = connection_points['Log Params(B)'].min()
286
  x_max = connection_points['Log Params(B)'].max()
287
- extended_x = np.linspace(x_min, x_max * 1.25, 100)
288
  extended_x_original = np.exp(extended_x)
289
  trend_line_y = model.predict(extended_x.reshape(-1, 1))
290
  trend_line_y_original = np.exp(trend_line_y)
@@ -347,8 +351,11 @@ def read_all_data(folder_name):
347
 
348
  all_data, time_list = read_all_data('data')
349
 
350
- initial_fig = create_scaling_plot(all_data, time_list[-1])
351
- initial_period = time_list[-1]
 
 
 
352
  initial_models = model_size_list
353
  initial_metric = metric_list[0]
354
  initial_columns = get_unique_column_names(all_data)
@@ -379,7 +386,7 @@ with gr.Blocks(css=css) as demo:
379
  with gr.Tab("πŸ† Leaderboard"):
380
  with gr.Row():
381
  with gr.Column():
382
- period_selector = gr.Dropdown(label="Period", choices=time_list, value=time_list[-1])
383
  model_selector = gr.CheckboxGroup(label="Model", choices=model_size_list, value=model_size_list)
384
  metric_selector = gr.Dropdown(label="Metric", choices=metric_list, value=metric_list[0])
385
  with gr.Column():
@@ -390,7 +397,7 @@ with gr.Blocks(css=css) as demo:
390
  choices=get_unique_column_names(all_data),
391
  value=get_unique_column_names(all_data))
392
 
393
- table = gr.Dataframe(initial_data, column_widths=[130, 60, 60, 35, 35, 35, 35, 35, 35, 35],
394
  wrap=True,
395
  height=800,
396
  )
@@ -414,14 +421,13 @@ with gr.Blocks(css=css) as demo:
414
  with gr.Tab("🌍 MultiLang"):
415
  gr.Markdown("## Coming soon...")
416
  with gr.Tab("πŸ“ˆ Scaling Law"):
417
- period_selector_2 = gr.Dropdown(label="Period", choices=time_list, value=time_list[0])
418
-
419
 
420
  def update_plot(period):
421
  new_fig = create_scaling_plot(all_data, period)
422
  return new_fig
423
 
424
-
425
  plot = gr.Plot(initial_fig)
426
  period_selector_2.change(update_plot, inputs=period_selector_2, outputs=plot)
427
 
 
164
  if len(combined_data) > 0:
165
  sorted_data = combined_data.sort_values(by=sort_by, ascending=ascending)
166
  sorted_data = sorted_data.rename(columns={'Average (The lower the better)': 'Average (lower=better)'})
167
+ sorted_data = sorted_data.rename(columns={'Parameters Count (B)': 'Params (B)'})
168
+ visible_columns = ['Name', 'Params (B)', 'Average (lower=better)'] + visible_columns
169
  filtered_data = sorted_data[visible_columns]
170
 
171
  filtered_data.columns = [col.replace('_', ' ') for col in filtered_data.columns]
 
179
  vmin = {}
180
  vmax = {}
181
  for column in filtered_data.columns:
182
+ if column in ['Name', 'Params (B)']:
183
  continue
184
  col_values = filtered_data[column]
185
  if len(col_values) > 1:
 
192
  target_color_columns.append('Average (lower=better)')
193
  if 'Individual Tests' in color_columns:
194
  target_color_columns.extend([col for col in filtered_data.columns if
195
+ col not in ['Name', 'Params (B)', 'Average (lower=better)']])
196
+
197
+
198
+ styler = filtered_data.style.format(formatter).applymap(color_cell, subset=['Params (B)'])
199
 
 
200
  for column in target_color_columns:
201
  if column in vmin and column in vmax: # Ensure that the vmin and vmax dicts contain the column
202
  styler = styler.background_gradient(cmap=cmap, subset=[column], vmin=vmin[column], vmax=vmax[column])
 
274
  names_to_connect = ['Meta-Llama-3-8B',
275
  'stablelm-3b-4e1t',
276
  'Qwen2-1.5B',
277
+ 'TinyLlama-1.1B-intermediate-step-1431k-3T',
278
+ 'Mistral-Nemo-Base-2407']
279
  connection_points = new_df[new_df['Name'].isin(names_to_connect)]
280
 
281
  new_df['Color'] = new_df['Name'].apply(lambda name: '#39C5BB' if name in names_to_connect else '#636efa')
 
288
 
289
  x_min = connection_points['Log Params(B)'].min()
290
  x_max = connection_points['Log Params(B)'].max()
291
+ extended_x = np.linspace(x_min, x_max * 1.5, 100)
292
  extended_x_original = np.exp(extended_x)
293
  trend_line_y = model.predict(extended_x.reshape(-1, 1))
294
  trend_line_y_original = np.exp(trend_line_y)
 
351
 
352
  all_data, time_list = read_all_data('data')
353
 
354
+ time_list.sort()
355
+ last_period = time_list[-1]
356
+
357
+ initial_fig = create_scaling_plot(all_data, last_period)
358
+ initial_period = last_period
359
  initial_models = model_size_list
360
  initial_metric = metric_list[0]
361
  initial_columns = get_unique_column_names(all_data)
 
386
  with gr.Tab("πŸ† Leaderboard"):
387
  with gr.Row():
388
  with gr.Column():
389
+ period_selector = gr.Dropdown(label="Period", choices=time_list, value=last_period)
390
  model_selector = gr.CheckboxGroup(label="Model", choices=model_size_list, value=model_size_list)
391
  metric_selector = gr.Dropdown(label="Metric", choices=metric_list, value=metric_list[0])
392
  with gr.Column():
 
397
  choices=get_unique_column_names(all_data),
398
  value=get_unique_column_names(all_data))
399
 
400
+ table = gr.Dataframe(initial_data, column_widths=[130, 50, 50, 35, 35, 35, 35, 35, 35, 35],
401
  wrap=True,
402
  height=800,
403
  )
 
421
  with gr.Tab("🌍 MultiLang"):
422
  gr.Markdown("## Coming soon...")
423
  with gr.Tab("πŸ“ˆ Scaling Law"):
424
+ print(time_list)
425
+ period_selector_2 = gr.Dropdown(label="Period", choices=time_list, value=last_period)
426
 
427
  def update_plot(period):
428
  new_fig = create_scaling_plot(all_data, period)
429
  return new_fig
430
 
 
431
  plot = gr.Plot(initial_fig)
432
  period_selector_2.change(update_plot, inputs=period_selector_2, outputs=plot)
433
 
data/2024-07/14b.xlsx CHANGED
Binary files a/data/2024-07/14b.xlsx and b/data/2024-07/14b.xlsx differ