sherzod-hakimov committed on
Commit
99fbf22
·
verified ·
1 Parent(s): e7fee79

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/plot_utils.py +3 -4
  2. src/trend_utils.py +88 -66
  3. src/version_utils.py +1 -1
src/plot_utils.py CHANGED
@@ -4,7 +4,7 @@ import requests
4
  import json
5
  import gradio as gr
6
 
7
- from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME
8
  from src.leaderboard_utils import get_github_data
9
 
10
 
@@ -131,8 +131,7 @@ def split_models(model_list: list):
131
  commercial_models = []
132
 
133
  # Load model registry data from main repo
134
- model_registry_url = "https://raw.githubusercontent.com/clp-research/clembench/main/backends/model_registry.json"
135
- response = requests.get(model_registry_url)
136
 
137
  if response.status_code == 200:
138
  json_data = json.loads(response.text)
@@ -149,7 +148,7 @@ def split_models(model_list: list):
149
  break
150
 
151
  else:
152
- print(f"Failed to read JSON file: Status Code : {response.status_code}")
153
 
154
  open_models.sort(key=lambda o: o.upper())
155
  commercial_models.sort(key=lambda c: c.upper())
 
4
  import json
5
  import gradio as gr
6
 
7
+ from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME, REGISTRY_URL
8
  from src.leaderboard_utils import get_github_data
9
 
10
 
 
131
  commercial_models = []
132
 
133
  # Load model registry data from main repo
134
+ response = requests.get(REGISTRY_URL)
 
135
 
136
  if response.status_code == 200:
137
  json_data = json.loads(response.text)
 
148
  break
149
 
150
  else:
151
+ print(f"Failed to read JSON file: {REGISTRY_URL} Status Code : {response.status_code}")
152
 
153
  open_models.sort(key=lambda o: o.upper())
154
  commercial_models.sort(key=lambda c: c.upper())
src/trend_utils.py CHANGED
@@ -13,6 +13,10 @@ from src.leaderboard_utils import get_github_data
13
  # Cut-off date from where to start the trendgraph
14
  START_DATE = '2023-06-01'
15
 
 
 
 
 
16
  def get_param_size(params: str) -> float:
17
  """Convert parameter size from string to float.
18
 
@@ -109,21 +113,40 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
109
  comm_models = populate_list(comm_model_df, comm_dip)
110
  return open_models, comm_models
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
114
  """Process text data frames to extract model information.
115
 
116
  Args:
117
- text_dfs (list): List of DataFrames containing model information.
118
  model_registry_data (list): List of dictionaries containing model registry data.
119
 
120
  Returns:
121
  pd.DataFrame: DataFrame containing processed model data.
122
  """
123
  visited = set() # Track models that have been processed
124
- result_df = pd.DataFrame(columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag'])
125
 
126
- for df in text_dfs:
 
 
 
127
  for i in range(len(df)):
128
  model_name = df['Model'].iloc[i]
129
  if model_name not in visited:
@@ -138,10 +161,12 @@ def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
138
  est_flag = False
139
 
140
  param_size = get_param_size(params)
 
141
  new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i], 'open_weight':dict_obj['open_weight'],
142
- 'release_date': dict_obj['release_date'], 'parameters': param_size, 'est_flag': est_flag}
143
  result_df.loc[len(result_df)] = new_data
144
  break
 
145
  return result_df # Return the compiled DataFrame
146
 
147
 
@@ -175,12 +200,12 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
175
 
176
  max_clemscore = df['clemscore'].max()
177
  # Convert 'release_date' to datetime
178
- df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
 
179
  # Filter out data before April 2023/START_DATE
180
- df = df[df['Release date'] >= pd.to_datetime(start_date)]
181
  open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
182
  models_to_display = open_model_list + comm_model_list
183
- print(f"open_model_list: {open_model_list}, comm_model_list: {comm_model_list}")
184
 
185
  # Create a column to indicate if the model should be labeled
186
  df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
@@ -189,38 +214,72 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
189
  if mobile_view:
190
  df = df[df['model'].isin(models_to_display)]
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  # Add an identifier column to each DataFrame
193
- df['Model Type'] = df['open_weight'].map({True: 'Open-Weight', False: 'Commercial'})
 
 
 
194
 
195
- marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float) # Arbitrary sqrt value to scale marker size based on parameter size
 
 
 
 
 
 
196
 
197
- open_color = 'red'
198
- comm_color = 'blue'
199
 
200
  # Create the scatter plot
201
  fig = px.scatter(df,
202
- x="Release date",
203
  y="clemscore",
204
- color="Model Type", # Differentiates the datasets by color
 
205
  hover_name="model",
206
  size=marker_size,
207
  size_max=40, # Max size of the circles
208
  template="plotly_white",
209
  hover_data={ # Customize hover information
210
- "Release date": True, # Show the release date
211
  "clemscore": True, # Show the clemscore
212
- "Model Type": True # Show the model type
213
  },
214
- custom_data=["model", "Release date", "clemscore"] # Specify custom data columns for hover
 
215
  )
216
 
217
  fig.update_traces(
218
- hovertemplate='Model Name: %{customdata[0]}<br>Release date: %{customdata[1]}<br>Clemscore: %{customdata[2]}<br>'
219
  )
220
 
221
  # Sort dataframes for line plotting
222
- df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release date')
223
- df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release date')
224
 
225
  ## Custom tics for x axis
226
  # Define the start and end dates
@@ -236,43 +295,6 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
236
  custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
237
  custom_tickvals = list(custom_ticks.keys())
238
 
239
-
240
- for date, version in benchmark_ticks.items():
241
- # Find the corresponding update date from benchmark_update based on the version name
242
- update_date = next((update_date for update_date, ver in benchmark_update.items() if version in ver), None)
243
-
244
- if update_date:
245
- # Add vertical black dotted line for each benchmark_tick date
246
- fig.add_shape(
247
- go.layout.Shape(
248
- type='line',
249
- x0=date,
250
- x1=date,
251
- y0=0,
252
- y1=1,
253
- yref='paper',
254
- line=dict(color='#A9A9A9', dash='dash'), # Black dotted line
255
- )
256
- )
257
-
258
- # Add hover information across the full y-axis range
259
- fig.add_trace(
260
- go.Scatter(
261
- x=[date]*100,
262
- y=list(range(0,100)), # Covers full y-axis range
263
- mode='markers',
264
- line=dict(color='rgba(255,255,255,0)', width=0), # Fully transparent line
265
- hovertext=[
266
- f"Version: {version} released on {date.strftime('%d %b %Y')}, last updated on: {update_date.strftime('%d %b %Y')}"
267
- for _ in range(100)
268
- ], # Unique hovertext for all points
269
- hoverinfo="text",
270
- hoveron='points',
271
- showlegend=False
272
- )
273
- )
274
-
275
-
276
  if mobile_view:
277
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
278
  one_month = pd.DateOffset(months=1)
@@ -313,20 +335,20 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
313
 
314
 
315
  # Add lines connecting the points for open models
316
- fig.add_trace(go.Scatter(x=df_open['Release date'], y=df_open['clemscore'],
317
  mode=display_mode, # Include 'text' in the mode
318
  name='Open Models Trendline',
319
  text=df_open['label_model'], # Use label_model for text labels
320
  textposition='top center', # Position of the text labels
321
- line=dict(color=open_color), showlegend=False))
322
 
323
  # Add lines connecting the points for commercial models
324
- fig.add_trace(go.Scatter(x=df_commercial['Release date'], y=df_commercial['clemscore'],
325
  mode=display_mode, # Include 'text' in the mode
326
  name='Commercial Models Trendline',
327
  text=df_commercial['label_model'], # Use label_model for text labels
328
  textposition='top center', # Position of the text labels
329
- line=dict(color=comm_color), showlegend=False))
330
 
331
 
332
  # Update layout to ensure text labels are visible
@@ -367,7 +389,7 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
367
 
368
  # Check if the JSON file request was successful
369
  if response.status_code != 200:
370
- print(f"Failed to read JSON file: Status Code: {response.status_code}")
371
 
372
  json_data = response.json()
373
  versions = json_data['versions']
@@ -385,8 +407,8 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
385
  benchmark_ticks = {}
386
  benchmark_update = {}
387
  if benchmark == "Text":
388
- text_dfs = get_github_data()['text']['dataframes']
389
- text_result_df = get_trend_data(text_dfs, model_registry_data)
390
  ## Get benchmark tickvalues as dates for X-axis
391
  for ver in versions:
392
  if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
@@ -398,8 +420,8 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
398
 
399
  fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
400
  else:
401
- mm_dfs = get_github_data()['multimodal']['dataframes']
402
- result_df = get_trend_data(mm_dfs, model_registry_data)
403
  df = result_df
404
  for ver in versions:
405
  if 'multimodal' in ver['version']:
 
13
  # Cut-off date from where to start the trendgraph
14
  START_DATE = '2023-06-01'
15
 
16
+ # Graph colours
17
+ COLOUR_OPEN = 'red'
18
+ COLOUR_COMM = 'blue'
19
+
20
  def get_param_size(params: str) -> float:
21
  """Convert parameter size from string to float.
22
 
 
113
  comm_models = populate_list(comm_model_df, comm_dip)
114
  return open_models, comm_models
115
 
116
+ # Function to interpolate between two colors
117
def interpolate_color(rank_val, start_color):
    """Build an HSV colour string for a given rank value and base colour.

    Args:
        rank_val: Number that is multiplied by 100 to give the saturation
            component. A value equal to 1 additionally selects the darker
            value component (70 instead of 100).
        start_color: Base colour name; 'red' maps to hue 0 and 'blue' maps
            to hue 240.

    Returns:
        str: Colour formatted as "hsv(<hue>,<saturation>,<value>)".

    Raises:
        KeyError: If start_color is neither 'red' nor 'blue'.
    """
    if start_color == 'red':
        hue = 0
    elif start_color == 'blue':
        hue = 240
    else:
        # The message names this function so the reader knows where to add a new hue.
        raise KeyError(f"Invalid color selected for trend graph: {start_color}. Please set either red or blue. Alternatively, set hue value in src.trend_utils.interpolate_color")

    saturation = rank_val * 100
    # A rank value of exactly 1 is drawn darker than all other ranks.
    value = 70 if rank_val == 1 else 100

    return f"hsv({hue},{saturation},{value})"
131
+
132
 
133
+ def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
134
  """Process text data frames to extract model information.
135
 
136
  Args:
137
+ text_data (dict): Dict containing DataFrames and version details.
138
  model_registry_data (list): List of dictionaries containing model registry data.
139
 
140
  Returns:
141
  pd.DataFrame: DataFrame containing processed model data.
142
  """
143
  visited = set() # Track models that have been processed
144
+ result_df = pd.DataFrame(columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag', 'version'])
145
 
146
+ text_dfs = text_data['dataframes']
147
+ for i in range(len(text_dfs)):
148
+ df = text_dfs[i]
149
+ version = text_data['version_data'][i]['name']
150
  for i in range(len(df)):
151
  model_name = df['Model'].iloc[i]
152
  if model_name not in visited:
 
161
  est_flag = False
162
 
163
  param_size = get_param_size(params)
164
+
165
  new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i], 'open_weight':dict_obj['open_weight'],
166
+ 'release_date': dict_obj['release_date'], 'parameters': param_size, 'est_flag': est_flag, 'version': version}
167
  result_df.loc[len(result_df)] = new_data
168
  break
169
+
170
  return result_df # Return the compiled DataFrame
171
 
172
 
 
200
 
201
  max_clemscore = df['clemscore'].max()
202
  # Convert 'release_date' to datetime
203
+ df['Release Date (Model and & Benchmark Version)'] = pd.to_datetime(df['release_date'], format='ISO8601')
204
+
205
  # Filter out data before start_date (defaults to June 2023, see START_DATE)
206
+ df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
207
  open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
208
  models_to_display = open_model_list + comm_model_list
 
209
 
210
  # Create a column to indicate if the model should be labeled
211
  df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
 
214
  if mobile_view:
215
  df = df[df['model'].isin(models_to_display)]
216
 
217
+ versions = df['version'].unique()
218
+ version_names = sorted(
219
+ [ver for ver in versions],
220
+ key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
221
+ reverse=True
222
+ )
223
+
224
+ version_names = version_names[:3] # Select 3 latest benchmark versions
225
+ df = df[df['version'].isin(tuple(version_names))]
226
+
227
+ rank = 2
228
+ max_rank = len(version_names)
229
+ rank_value = {version_names[0]: 1}
230
+ for ver in version_names:
231
+ if ver not in rank_value:
232
+ rank_value[ver] = 1 - (rank-1-(max_rank/15))/(max_rank-1)
233
+ rank += 1
234
+
235
+ df['color_value'] = df.apply(
236
+ lambda row: rank_value[row['version']],
237
+ axis=1
238
+ )
239
+
240
  # Add an identifier column to each DataFrame
241
+ df['Model Type & Benchmark Version'] = df.apply(
242
+ lambda row: f"Open-Weight {row['version']}" if row['open_weight'] else f"Commercial {row['version']}",
243
+ axis=1
244
+ )
245
 
246
+ color_map = {}
247
+ for i in range(len(df)):
248
+ if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
249
+ if df.iloc[i]['open_weight']:
250
+ color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_OPEN)
251
+ else:
252
+ color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_COMM)
253
 
254
+
255
+ marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float) # Arbitrary sqrt value to scale marker size based on parameter size
256
 
257
  # Create the scatter plot
258
  fig = px.scatter(df,
259
+ x="Release Date (Model and & Benchmark Version)",
260
  y="clemscore",
261
+ color="Model Type & Benchmark Version", # Differentiates the datasets by color
262
+ color_discrete_map=color_map, # Map colors to the defined subclasses
263
  hover_name="model",
264
  size=marker_size,
265
  size_max=40, # Max size of the circles
266
  template="plotly_white",
267
  hover_data={ # Customize hover information
268
+ "Release Date (Model and & Benchmark Version)": True, # Show the Release Date (Model and & Benchmark Version)
269
  "clemscore": True, # Show the clemscore
270
+ "version": True
271
  },
272
+ custom_data=["model", "Release Date (Model and & Benchmark Version)", "clemscore", "version"], # Specify custom data columns for hover
273
+ opacity=0.8
274
  )
275
 
276
  fig.update_traces(
277
+ hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model and & Benchmark Version): %{customdata[1]}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
278
  )
279
 
280
  # Sort dataframes for line plotting
281
+ df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
282
+ df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
283
 
284
  ## Custom tics for x axis
285
  # Define the start and end dates
 
295
  custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
296
  custom_tickvals = list(custom_ticks.keys())
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  if mobile_view:
299
  # Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
300
  one_month = pd.DateOffset(months=1)
 
335
 
336
 
337
  # Add lines connecting the points for open models
338
+ fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
339
  mode=display_mode, # Include 'text' in the mode
340
  name='Open Models Trendline',
341
  text=df_open['label_model'], # Use label_model for text labels
342
  textposition='top center', # Position of the text labels
343
+ line=dict(color='red'), showlegend=False))
344
 
345
  # Add lines connecting the points for commercial models
346
+ fig.add_trace(go.Scatter(x=df_commercial['Release Date (Model and & Benchmark Version)'], y=df_commercial['clemscore'],
347
  mode=display_mode, # Include 'text' in the mode
348
  name='Commercial Models Trendline',
349
  text=df_commercial['label_model'], # Use label_model for text labels
350
  textposition='top center', # Position of the text labels
351
+ line=dict(color='blue'), showlegend=False))
352
 
353
 
354
  # Update layout to ensure text labels are visible
 
389
 
390
  # Check if the JSON file request was successful
391
  if response.status_code != 200:
392
+ print(f"Failed to read JSON file {json_url}: Status Code: {response.status_code}")
393
 
394
  json_data = response.json()
395
  versions = json_data['versions']
 
407
  benchmark_ticks = {}
408
  benchmark_update = {}
409
  if benchmark == "Text":
410
+ text_data = get_github_data()['text']
411
+ text_result_df = get_trend_data(text_data, model_registry_data)
412
  ## Get benchmark tickvalues as dates for X-axis
413
  for ver in versions:
414
  if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
 
420
 
421
  fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
422
  else:
423
+ mm_data = get_github_data()['multimodal']
424
+ result_df = get_trend_data(mm_data, model_registry_data)
425
  df = result_df
426
  for ver in versions:
427
  if 'multimodal' in ver['version']:
src/version_utils.py CHANGED
@@ -27,7 +27,7 @@ def get_version_data():
27
 
28
  # Check if the JSON file request was successful
29
  if response.status_code != 200:
30
- print(f"Failed to read JSON file: Status Code: {response.status_code}")
31
  return None, None, None, None
32
 
33
  json_data = response.json()
 
27
 
28
  # Check if the JSON file request was successful
29
  if response.status_code != 200:
30
+ print(f"Failed to read JSON file {json_url}: Status Code: {response.status_code}")
31
  return None, None, None, None
32
 
33
  json_data = response.json()