Spaces:
Running
Running
Upload 3 files
Browse files- src/plot_utils.py +3 -4
- src/trend_utils.py +88 -66
- src/version_utils.py +1 -1
src/plot_utils.py
CHANGED
@@ -4,7 +4,7 @@ import requests
|
|
4 |
import json
|
5 |
import gradio as gr
|
6 |
|
7 |
-
from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME
|
8 |
from src.leaderboard_utils import get_github_data
|
9 |
|
10 |
|
@@ -131,8 +131,7 @@ def split_models(model_list: list):
|
|
131 |
commercial_models = []
|
132 |
|
133 |
# Load model registry data from main repo
|
134 |
-
|
135 |
-
response = requests.get(model_registry_url)
|
136 |
|
137 |
if response.status_code == 200:
|
138 |
json_data = json.loads(response.text)
|
@@ -149,7 +148,7 @@ def split_models(model_list: list):
|
|
149 |
break
|
150 |
|
151 |
else:
|
152 |
-
print(f"Failed to read JSON file: Status Code : {response.status_code}")
|
153 |
|
154 |
open_models.sort(key=lambda o: o.upper())
|
155 |
commercial_models.sort(key=lambda c: c.upper())
|
|
|
4 |
import json
|
5 |
import gradio as gr
|
6 |
|
7 |
+
from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME, REGISTRY_URL
|
8 |
from src.leaderboard_utils import get_github_data
|
9 |
|
10 |
|
|
|
131 |
commercial_models = []
|
132 |
|
133 |
# Load model registry data from main repo
|
134 |
+
response = requests.get(REGISTRY_URL)
|
|
|
135 |
|
136 |
if response.status_code == 200:
|
137 |
json_data = json.loads(response.text)
|
|
|
148 |
break
|
149 |
|
150 |
else:
|
151 |
+
print(f"Failed to read JSON file: {REGISTRY_URL} Status Code : {response.status_code}")
|
152 |
|
153 |
open_models.sort(key=lambda o: o.upper())
|
154 |
commercial_models.sort(key=lambda c: c.upper())
|
src/trend_utils.py
CHANGED
@@ -13,6 +13,10 @@ from src.leaderboard_utils import get_github_data
|
|
13 |
# Cut-off date from where to start the trendgraph
|
14 |
START_DATE = '2023-06-01'
|
15 |
|
|
|
|
|
|
|
|
|
16 |
def get_param_size(params: str) -> float:
|
17 |
"""Convert parameter size from string to float.
|
18 |
|
@@ -109,21 +113,40 @@ def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip
|
|
109 |
comm_models = populate_list(comm_model_df, comm_dip)
|
110 |
return open_models, comm_models
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
def get_trend_data(
|
114 |
"""Process text data frames to extract model information.
|
115 |
|
116 |
Args:
|
117 |
-
|
118 |
model_registry_data (list): List of dictionaries containing model registry data.
|
119 |
|
120 |
Returns:
|
121 |
pd.DataFrame: DataFrame containing processed model data.
|
122 |
"""
|
123 |
visited = set() # Track models that have been processed
|
124 |
-
result_df = pd.DataFrame(columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag'])
|
125 |
|
126 |
-
|
|
|
|
|
|
|
127 |
for i in range(len(df)):
|
128 |
model_name = df['Model'].iloc[i]
|
129 |
if model_name not in visited:
|
@@ -138,10 +161,12 @@ def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
|
|
138 |
est_flag = False
|
139 |
|
140 |
param_size = get_param_size(params)
|
|
|
141 |
new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i], 'open_weight':dict_obj['open_weight'],
|
142 |
-
'release_date': dict_obj['release_date'], 'parameters': param_size, 'est_flag': est_flag}
|
143 |
result_df.loc[len(result_df)] = new_data
|
144 |
break
|
|
|
145 |
return result_df # Return the compiled DataFrame
|
146 |
|
147 |
|
@@ -175,12 +200,12 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
175 |
|
176 |
max_clemscore = df['clemscore'].max()
|
177 |
# Convert 'release_date' to datetime
|
178 |
-
df['Release
|
|
|
179 |
# Filter out data before April 2023/START_DATE
|
180 |
-
df = df[df['Release
|
181 |
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
182 |
models_to_display = open_model_list + comm_model_list
|
183 |
-
print(f"open_model_list: {open_model_list}, comm_model_list: {comm_model_list}")
|
184 |
|
185 |
# Create a column to indicate if the model should be labeled
|
186 |
df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
|
@@ -189,38 +214,72 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
189 |
if mobile_view:
|
190 |
df = df[df['model'].isin(models_to_display)]
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
# Add an identifier column to each DataFrame
|
193 |
-
df['Model Type'] = df
|
|
|
|
|
|
|
194 |
|
195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
-
|
198 |
-
|
199 |
|
200 |
# Create the scatter plot
|
201 |
fig = px.scatter(df,
|
202 |
-
x="Release
|
203 |
y="clemscore",
|
204 |
-
color="Model Type", # Differentiates the datasets by color
|
|
|
205 |
hover_name="model",
|
206 |
size=marker_size,
|
207 |
size_max=40, # Max size of the circles
|
208 |
template="plotly_white",
|
209 |
hover_data={ # Customize hover information
|
210 |
-
"Release
|
211 |
"clemscore": True, # Show the clemscore
|
212 |
-
"
|
213 |
},
|
214 |
-
custom_data=["model", "Release
|
|
|
215 |
)
|
216 |
|
217 |
fig.update_traces(
|
218 |
-
hovertemplate='Model Name: %{customdata[0]}<br>Release
|
219 |
)
|
220 |
|
221 |
# Sort dataframes for line plotting
|
222 |
-
df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release
|
223 |
-
df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release
|
224 |
|
225 |
## Custom tics for x axis
|
226 |
# Define the start and end dates
|
@@ -236,43 +295,6 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
236 |
custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
|
237 |
custom_tickvals = list(custom_ticks.keys())
|
238 |
|
239 |
-
|
240 |
-
for date, version in benchmark_ticks.items():
|
241 |
-
# Find the corresponding update date from benchmark_update based on the version name
|
242 |
-
update_date = next((update_date for update_date, ver in benchmark_update.items() if version in ver), None)
|
243 |
-
|
244 |
-
if update_date:
|
245 |
-
# Add vertical black dotted line for each benchmark_tick date
|
246 |
-
fig.add_shape(
|
247 |
-
go.layout.Shape(
|
248 |
-
type='line',
|
249 |
-
x0=date,
|
250 |
-
x1=date,
|
251 |
-
y0=0,
|
252 |
-
y1=1,
|
253 |
-
yref='paper',
|
254 |
-
line=dict(color='#A9A9A9', dash='dash'), # Black dotted line
|
255 |
-
)
|
256 |
-
)
|
257 |
-
|
258 |
-
# Add hover information across the full y-axis range
|
259 |
-
fig.add_trace(
|
260 |
-
go.Scatter(
|
261 |
-
x=[date]*100,
|
262 |
-
y=list(range(0,100)), # Covers full y-axis range
|
263 |
-
mode='markers',
|
264 |
-
line=dict(color='rgba(255,255,255,0)', width=0), # Fully transparent line
|
265 |
-
hovertext=[
|
266 |
-
f"Version: {version} released on {date.strftime('%d %b %Y')}, last updated on: {update_date.strftime('%d %b %Y')}"
|
267 |
-
for _ in range(100)
|
268 |
-
], # Unique hovertext for all points
|
269 |
-
hoverinfo="text",
|
270 |
-
hoveron='points',
|
271 |
-
showlegend=False
|
272 |
-
)
|
273 |
-
)
|
274 |
-
|
275 |
-
|
276 |
if mobile_view:
|
277 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
278 |
one_month = pd.DateOffset(months=1)
|
@@ -313,20 +335,20 @@ def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '
|
|
313 |
|
314 |
|
315 |
# Add lines connecting the points for open models
|
316 |
-
fig.add_trace(go.Scatter(x=df_open['Release
|
317 |
mode=display_mode, # Include 'text' in the mode
|
318 |
name='Open Models Trendline',
|
319 |
text=df_open['label_model'], # Use label_model for text labels
|
320 |
textposition='top center', # Position of the text labels
|
321 |
-
line=dict(color=
|
322 |
|
323 |
# Add lines connecting the points for commercial models
|
324 |
-
fig.add_trace(go.Scatter(x=df_commercial['Release
|
325 |
mode=display_mode, # Include 'text' in the mode
|
326 |
name='Commercial Models Trendline',
|
327 |
text=df_commercial['label_model'], # Use label_model for text labels
|
328 |
textposition='top center', # Position of the text labels
|
329 |
-
line=dict(color=
|
330 |
|
331 |
|
332 |
# Update layout to ensure text labels are visible
|
@@ -367,7 +389,7 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
367 |
|
368 |
# Check if the JSON file request was successful
|
369 |
if response.status_code != 200:
|
370 |
-
print(f"Failed to read JSON file: Status Code: {response.status_code}")
|
371 |
|
372 |
json_data = response.json()
|
373 |
versions = json_data['versions']
|
@@ -385,8 +407,8 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
385 |
benchmark_ticks = {}
|
386 |
benchmark_update = {}
|
387 |
if benchmark == "Text":
|
388 |
-
|
389 |
-
text_result_df = get_trend_data(
|
390 |
## Get benchmark tickvalues as dates for X-axis
|
391 |
for ver in versions:
|
392 |
if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
|
@@ -398,8 +420,8 @@ def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False) ->
|
|
398 |
|
399 |
fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
400 |
else:
|
401 |
-
|
402 |
-
result_df = get_trend_data(
|
403 |
df = result_df
|
404 |
for ver in versions:
|
405 |
if 'multimodal' in ver['version']:
|
|
|
13 |
# Cut-off date from where to start the trendgraph
|
14 |
START_DATE = '2023-06-01'
|
15 |
|
16 |
+
# Graph colours
|
17 |
+
COLOUR_OPEN = 'red'
|
18 |
+
COLOUR_COMM = 'blue'
|
19 |
+
|
20 |
def get_param_size(params: str) -> float:
|
21 |
"""Convert parameter size from string to float.
|
22 |
|
|
|
113 |
comm_models = populate_list(comm_model_df, comm_dip)
|
114 |
return open_models, comm_models
|
115 |
|
116 |
+
# Function to interpolate between two colors
|
117 |
+
def interpolate_color(rank_val, start_color):
    """Build an HSV colour string whose saturation scales with ``rank_val``.

    Args:
        rank_val: Numeric weight for the colour. It is multiplied by 100 to
            obtain the saturation component; a value equal to 1 additionally
            selects the darker brightness (70 instead of 100).
        start_color (str): Base colour name. ``'red'`` maps to hue 0 and
            ``'blue'`` maps to hue 240.

    Returns:
        str: Colour formatted as ``"hsv(<hue>,<saturation>,<value>)"``.

    Raises:
        KeyError: If ``start_color`` is neither ``'red'`` nor ``'blue'``.
    """
    if start_color == 'red':
        hue = 0
    elif start_color == 'blue':
        hue = 240
    else:
        # Point the reader at this function's real name so the hint is actionable.
        raise KeyError(
            f"Invalid color selected for trend graph: {start_color}. "
            "Please set either red or blue. Alternatively, set hue value in "
            "src.trend_utils.interpolate_color"
        )

    saturation = rank_val * 100
    # Only the top-ranked entry (rank_val == 1) is drawn with reduced brightness.
    value = 70 if rank_val == 1 else 100

    return f"hsv({hue},{saturation},{value})"
|
131 |
+
|
132 |
|
133 |
+
def get_trend_data(text_data: dict, model_registry_data: list) -> pd.DataFrame:
|
134 |
"""Process text data frames to extract model information.
|
135 |
|
136 |
Args:
|
137 |
+
text_data (dict): Dict containing DataFrames and version details.
|
138 |
model_registry_data (list): List of dictionaries containing model registry data.
|
139 |
|
140 |
Returns:
|
141 |
pd.DataFrame: DataFrame containing processed model data.
|
142 |
"""
|
143 |
visited = set() # Track models that have been processed
|
144 |
+
result_df = pd.DataFrame(columns=['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag', 'version'])
|
145 |
|
146 |
+
text_dfs = text_data['dataframes']
|
147 |
+
for i in range(len(text_dfs)):
|
148 |
+
df = text_dfs[i]
|
149 |
+
version = text_data['version_data'][i]['name']
|
150 |
for i in range(len(df)):
|
151 |
model_name = df['Model'].iloc[i]
|
152 |
if model_name not in visited:
|
|
|
161 |
est_flag = False
|
162 |
|
163 |
param_size = get_param_size(params)
|
164 |
+
|
165 |
new_data = {'model': model_name, 'clemscore': df['Clemscore'].iloc[i], 'open_weight':dict_obj['open_weight'],
|
166 |
+
'release_date': dict_obj['release_date'], 'parameters': param_size, 'est_flag': est_flag, 'version': version}
|
167 |
result_df.loc[len(result_df)] = new_data
|
168 |
break
|
169 |
+
|
170 |
return result_df # Return the compiled DataFrame
|
171 |
|
172 |
|
|
|
200 |
|
201 |
max_clemscore = df['clemscore'].max()
|
202 |
# Convert 'release_date' to datetime
|
203 |
+
df['Release Date (Model and & Benchmark Version)'] = pd.to_datetime(df['release_date'], format='ISO8601')
|
204 |
+
|
205 |
# Filter out data before April 2023/START_DATE
|
206 |
+
df = df[df['Release Date (Model and & Benchmark Version)'] >= pd.to_datetime(start_date)]
|
207 |
open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
|
208 |
models_to_display = open_model_list + comm_model_list
|
|
|
209 |
|
210 |
# Create a column to indicate if the model should be labeled
|
211 |
df['label_model'] = df['model'].apply(lambda x: x if x in models_to_display else "")
|
|
|
214 |
if mobile_view:
|
215 |
df = df[df['model'].isin(models_to_display)]
|
216 |
|
217 |
+
versions = df['version'].unique()
|
218 |
+
version_names = sorted(
|
219 |
+
[ver for ver in versions],
|
220 |
+
key=lambda v: list(map(int, v[1:].split('_')[0].split('.'))),
|
221 |
+
reverse=True
|
222 |
+
)
|
223 |
+
|
224 |
+
version_names = version_names[:3] # Select 3 latest benchmark versions
|
225 |
+
df = df[df['version'].isin(tuple(version_names))]
|
226 |
+
|
227 |
+
rank = 2
|
228 |
+
max_rank = len(version_names)
|
229 |
+
rank_value = {version_names[0]: 1}
|
230 |
+
for ver in version_names:
|
231 |
+
if ver not in rank_value:
|
232 |
+
rank_value[ver] = 1 - (rank-1-(max_rank/15))/(max_rank-1)
|
233 |
+
rank += 1
|
234 |
+
|
235 |
+
df['color_value'] = df.apply(
|
236 |
+
lambda row: rank_value[row['version']],
|
237 |
+
axis=1
|
238 |
+
)
|
239 |
+
|
240 |
# Add an identifier column to each DataFrame
|
241 |
+
df['Model Type & Benchmark Version'] = df.apply(
|
242 |
+
lambda row: f"Open-Weight {row['version']}" if row['open_weight'] else f"Commercial {row['version']}",
|
243 |
+
axis=1
|
244 |
+
)
|
245 |
|
246 |
+
color_map = {}
|
247 |
+
for i in range(len(df)):
|
248 |
+
if df.iloc[i]['Model Type & Benchmark Version'] not in color_map:
|
249 |
+
if df.iloc[i]['open_weight']:
|
250 |
+
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_OPEN)
|
251 |
+
else:
|
252 |
+
color_map[df.iloc[i]['Model Type & Benchmark Version']] = interpolate_color(df.iloc[i]['color_value'], COLOUR_COMM)
|
253 |
|
254 |
+
|
255 |
+
marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float) # Arbitrary sqrt value to scale marker size based on parameter size
|
256 |
|
257 |
# Create the scatter plot
|
258 |
fig = px.scatter(df,
|
259 |
+
x="Release Date (Model and & Benchmark Version)",
|
260 |
y="clemscore",
|
261 |
+
color="Model Type & Benchmark Version", # Differentiates the datasets by color
|
262 |
+
color_discrete_map=color_map, # Map colors to the defined subclasses
|
263 |
hover_name="model",
|
264 |
size=marker_size,
|
265 |
size_max=40, # Max size of the circles
|
266 |
template="plotly_white",
|
267 |
hover_data={ # Customize hover information
|
268 |
+
"Release Date (Model and & Benchmark Version)": True, # Show the Release Date (Model and & Benchmark Version)
|
269 |
"clemscore": True, # Show the clemscore
|
270 |
+
"version": True
|
271 |
},
|
272 |
+
custom_data=["model", "Release Date (Model and & Benchmark Version)", "clemscore", "version"], # Specify custom data columns for hover
|
273 |
+
opacity=0.8
|
274 |
)
|
275 |
|
276 |
fig.update_traces(
|
277 |
+
hovertemplate='Model Name: %{customdata[0]}<br>Release Date (Model and & Benchmark Version): %{customdata[1]}<br>Clemscore: %{customdata[2]}<br>Benchmark Version: %{customdata[3]}<br>'
|
278 |
)
|
279 |
|
280 |
# Sort dataframes for line plotting
|
281 |
+
df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
282 |
+
df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release Date (Model and & Benchmark Version)')
|
283 |
|
284 |
## Custom tics for x axis
|
285 |
# Define the start and end dates
|
|
|
295 |
custom_ticks = {k:v for k,v in custom_ticks.items() if k not in benchmark_tickvals}
|
296 |
custom_tickvals = list(custom_ticks.keys())
|
297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
if mobile_view:
|
299 |
# Remove custom_tickvals within -1 month to +1 month of benchmark_tickvals for better visibility
|
300 |
one_month = pd.DateOffset(months=1)
|
|
|
335 |
|
336 |
|
337 |
# Add lines connecting the points for open models
|
338 |
+
fig.add_trace(go.Scatter(x=df_open['Release Date (Model and & Benchmark Version)'], y=df_open['clemscore'],
|
339 |
mode=display_mode, # Include 'text' in the mode
|
340 |
name='Open Models Trendline',
|
341 |
text=df_open['label_model'], # Use label_model for text labels
|
342 |
textposition='top center', # Position of the text labels
|
343 |
+
line=dict(color='red'), showlegend=False))
|
344 |
|
345 |
# Add lines connecting the points for commercial models
|
346 |
+
fig.add_trace(go.Scatter(x=df_commercial['Release Date (Model and & Benchmark Version)'], y=df_commercial['clemscore'],
|
347 |
mode=display_mode, # Include 'text' in the mode
|
348 |
name='Commercial Models Trendline',
|
349 |
text=df_commercial['label_model'], # Use label_model for text labels
|
350 |
textposition='top center', # Position of the text labels
|
351 |
+
line=dict(color='blue'), showlegend=False))
|
352 |
|
353 |
|
354 |
# Update layout to ensure text labels are visible
|
|
|
389 |
|
390 |
# Check if the JSON file request was successful
|
391 |
if response.status_code != 200:
|
392 |
+
print(f"Failed to read JSON file {json_url}: Status Code: {response.status_code}")
|
393 |
|
394 |
json_data = response.json()
|
395 |
versions = json_data['versions']
|
|
|
407 |
benchmark_ticks = {}
|
408 |
benchmark_update = {}
|
409 |
if benchmark == "Text":
|
410 |
+
text_data = get_github_data()['text']
|
411 |
+
text_result_df = get_trend_data(text_data, model_registry_data)
|
412 |
## Get benchmark tickvalues as dates for X-axis
|
413 |
for ver in versions:
|
414 |
if 'multimodal' not in ver['version']: # Skip MM specific benchmark dates
|
|
|
420 |
|
421 |
fig = get_plot(text_result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'), benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)
|
422 |
else:
|
423 |
+
mm_data = get_github_data()['multimodal']
|
424 |
+
result_df = get_trend_data(mm_data, model_registry_data)
|
425 |
df = result_df
|
426 |
for ver in versions:
|
427 |
if 'multimodal' in ver['version']:
|
src/version_utils.py
CHANGED
@@ -27,7 +27,7 @@ def get_version_data():
|
|
27 |
|
28 |
# Check if the JSON file request was successful
|
29 |
if response.status_code != 200:
|
30 |
-
print(f"Failed to read JSON file: Status Code: {response.status_code}")
|
31 |
return None, None, None, None
|
32 |
|
33 |
json_data = response.json()
|
|
|
27 |
|
28 |
# Check if the JSON file request was successful
|
29 |
if response.status_code != 200:
|
30 |
+
print(f"Failed to read JSON file {json_url}: Status Code: {response.status_code}")
|
31 |
return None, None, None, None
|
32 |
|
33 |
json_data = response.json()
|