carryover from evijit

Files changed: app.py (+468 / -788), models_processed.parquet (+3 / -0), preprocess.py (+371 / -0)

app.py
@@ -1,846 +1,526 @@
 import json
 import gradio as gr
 import pandas as pd
 import plotly.express as px
 import os
 import numpy as np
-import io
 import duckdb

-# … (truncated comment line)
-PIPELINE_TAGS = [
-    'text-generation',
-    'text-to-image',
-    'text-classification',
-    'text2text-generation',
-    'audio-to-audio',
-    'feature-extraction',
-    'image-classification',
-    'translation',
-    'reinforcement-learning',
-    'fill-mask',
-    'text-to-speech',
-    'automatic-speech-recognition',
-    'image-text-to-text',
-    'token-classification',
-    'sentence-similarity',
-    'question-answering',
-    'image-feature-extraction',
-    'summarization',
-    'zero-shot-image-classification',
-    'object-detection',
-    'image-segmentation',
-    'image-to-image',
-    'image-to-text',
-    'audio-classification',
-    'visual-question-answering',
-    'text-to-video',
-    'zero-shot-classification',
-    'depth-estimation',
-    'text-ranking',
-    'image-to-video',
-    'multiple-choice',
-    'unconditional-image-generation',
-    'video-classification',
-    'text-to-audio',
-    'time-series-forecasting',
-    'any-to-any',
-    'video-text-to-text',
-    'table-question-answering',
-]
-
-# Model size categories in GB
 MODEL_SIZE_RANGES = {
-    "Small (<1GB)": (0, 1),
-    "Medium (1-5GB)": (1, 5),
-    "Large (5-20GB)": (5, 20),
-    "X-Large (20-50GB)": (20, 50),
-    "XX-Large (>50GB)": (50, float('inf'))
 }

-def is_music(row):
-    # Use cached column instead of recalculating
-    return row['has_music']
-
-def is_robotics(row):
-    # Use cached column instead of recalculating
-    return row['has_robot']
-
-# … (old lines 78-96 collapsed in the diff view: further is_* helpers in the
-#    same style, one of which ends `return row['has_video']`)
-
-# … (old lines 97-115 collapsed in the diff view; surviving fragments show an
-#    `if …` / `return "…` / `return False` body)
-
-def is_text(row):
-    tags = row.get("tags", [])
-    # … (old lines 119-170 collapsed in the diff view)
-    return True
-
-# … (old lines 173-186, a commented block, collapsed in the diff view)
-
-def extract_org_from_id(model_id):
-    """Extract organization name from model ID"""
-    if "/" in model_id:
-        return model_id.split("/")[0]
-    return "unaffiliated"
 def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
-    # Create a copy to avoid modifying the original
     filtered_df = df.copy()

-    # … (old lines 199-219 collapsed in the diff view: filter_stats setup and the
-    #    first tag-filter branches; a surviving fragment reads
-    #    `filtered_df = filtered_df[filtered_df[`)
-    elif tag_filter == "Video":
-        filtered_df = filtered_df[filtered_df['has_video']]
-    elif tag_filter == "Images":
-        filtered_df = filtered_df[filtered_df['has_image']]
-    elif tag_filter == "Text":
-        filtered_df = filtered_df[filtered_df['has_text']]
-
-    filter_stats["after_tag_filter"] = len(filtered_df)
-    print(f"Tag filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
-    start_time = pd.Timestamp.now()
-
-    # Apply pipeline filter
     if pipeline_filter:
-        # … (old lines 233-242 collapsed in the diff view)
-
-    # Use the cached size_category column directly
-    filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
-
-    # Debug info
-    print(f"Size filter '{size_filter}' applied.")
-    print(f"Models after size filter: {len(filtered_df)}")
-
-    filter_stats["after_size_filter"] = len(filtered_df)
-    print(f"Size filter applied in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
-    start_time = pd.Timestamp.now()
-
-    # Add organization column
-    filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
-
-    # Skip organizations if specified
     if skip_orgs and len(skip_orgs) > 0:
-        # … (old lines 259-268 collapsed in the diff view)
-        print("Warning: No data left after applying filters!")
-        return pd.DataFrame()  # Return empty DataFrame
-
-    # Aggregate by organization
-    org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
-    org_totals = org_totals.sort_values(by=count_by, ascending=False)
-
-    # Get top organizations
-    top_orgs = org_totals.head(top_k)["organization"].tolist()
-
-    # Filter to only include models from top organizations
-    filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
-
-    # Prepare data for treemap
-    treemap_data = filtered_df[["id", "organization", count_by]].copy()
-
-    # Add a root node
     treemap_data["root"] = "models"
-
-    # Ensure numeric values
-    treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
-
-    print(f"Treemap data prepared in {(pd.Timestamp.now() - start_time).total_seconds():.3f} seconds")
     return treemap_data

 def create_treemap(treemap_data, count_by, title=None):
-    """Create a Plotly treemap from the prepared data"""
     if treemap_data.empty:
-        fig = …  # (old line 298 truncated in the diff view; presumably a px.treemap call)
-            names=["No data matches the selected filters"],
-            values=[1]
-        )
-        fig.update_layout(
-            title="No data matches the selected filters",
-            margin=dict(t=50, l=25, r=25, b=25)
-        )
         return fig
-
-    # Create the treemap
     fig = px.treemap(
-        treemap_data,
-        path=["root", "organization", "id"],
-        values=count_by,
         title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
         color_discrete_sequence=px.colors.qualitative.Plotly
     )
-
-    fig.update_layout(
-        margin=dict(t=50, l=25, r=25, b=25)
-    )
-
-    # Update traces for better readability
-    fig.update_traces(
-        textinfo="label+value+percent root",
-        hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
-    )
-
     return fig
-# … (collapsed in the diff view: presumably the old load_models_data signature)
     try:
-        # … (old lines 341-351 collapsed in the diff view: an inner `try:` with a
-        #    column-specific DuckDB query)
-            df = duckdb.sql(query).df()
-        except Exception as sql_error:
-            print(f"Error with specific column selection: {sql_error}")
-            # Fallback to just selecting everything and then filtering
-            print("Falling back to select * query...")
-            query = "SELECT * FROM read_parquet('https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet')"
-            raw_df = duckdb.sql(query).df()
-
-            # Now extract only the columns we need
-            needed_columns = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']
-            available_columns = set(raw_df.columns)
-            df = pd.DataFrame()
-
-            # Copy over columns that exist
-            for col in needed_columns:
-                if col in available_columns:
-                    df[col] = raw_df[col]
                 else:
-                    if col in ['downloads', 'downloadsAllTime', 'likes']:
-                        df[col] = 0
-                    elif col == 'pipeline_tag':
-                        df[col] = ''
-                    elif col == 'tags':
-                        df[col] = [[] for _ in range(len(raw_df))]
-                    elif col == 'safetensors':
-                        df[col] = None
-                    elif col == 'id':
-                        # Create IDs based on index if missing
-                        df[col] = [f"model_{i}" for i in range(len(raw_df))]
-
-        print(f"Data fetched successfully. Shape: {df.shape}")
-
-        # Check if safetensors column exists before trying to process it
-        if 'safetensors' in df.columns:
-            # Add params column derived from safetensors.total (model size in GB)
-            df['params'] = df['safetensors'].apply(extract_model_size)
-
-            # Debug model sizes
-            size_ranges = {
-                "Small (<1GB)": 0,
-                "Medium (1-5GB)": 0,
-                "Large (5-20GB)": 0,
-                "X-Large (20-50GB)": 0,
-                "XX-Large (>50GB)": 0
-            }
-
-            # Count models in each size range
-            for idx, row in df.iterrows():
-                size_gb = row['params']
-                if 0 <= size_gb < 1:
-                    size_ranges["Small (<1GB)"] += 1
-                elif 1 <= size_gb < 5:
-                    size_ranges["Medium (1-5GB)"] += 1
-                elif 5 <= size_gb < 20:
-                    size_ranges["Large (5-20GB)"] += 1
-                elif 20 <= size_gb < 50:
-                    size_ranges["X-Large (20-50GB)"] += 1
-                elif size_gb >= 50:
-                    size_ranges["XX-Large (>50GB)"] += 1
-
-            print("Model size distribution:")
-            for size_range, count in size_ranges.items():
-                print(f"  {size_range}: {count} models")
-
-            # CACHE SIZE CATEGORY: Add a size_category column for faster filtering
-            def get_size_category(size_gb):
-                if 0 <= size_gb < 1:
-                    return "Small (<1GB)"
-                elif 1 <= size_gb < 5:
-                    return "Medium (1-5GB)"
-                elif 5 <= size_gb < 20:
-                    return "Large (5-20GB)"
-                elif 20 <= size_gb < 50:
-                    return "X-Large (20-50GB)"
-                elif size_gb >= 50:
-                    return "XX-Large (>50GB)"
-                return None
-
-            # Add cached size category column
-            df['size_category'] = df['params'].apply(get_size_category)
-
-            # Remove the safetensors column as we don't need it anymore
-            df = df.drop(columns=['safetensors'])
-        else:
-            # If no safetensors column, add empty params column
-            df['params'] = 0
-            df['size_category'] = None
-
-        # Process tags to ensure it's in the right format - FIXED
-        def process_tags(tags_value):
-            try:
-                if pd.isna(tags_value) or tags_value is None:
-                    return []
-
-                # If it's a numpy array, convert to a list of strings
-                if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
-                    # Note: This is the fix for the error
-                    return [str(tag) for tag in tags_value.tolist()]
-
-                # If already a list, ensure all elements are strings
-                if isinstance(tags_value, list):
-                    return [str(tag) for tag in tags_value]

-                # … (old lines 457-462 collapsed in the diff view; a surviving
-                #    fragment reads `if …`, presumably JSON parsing of string tags)
-                # Split by comma if JSON parsing fails
-                return [tag.strip() for tag in tags_value.split(',') if tag.strip()]

-        # … (old lines 466-477 collapsed in the diff view; a surviving fragment
-        #    reads `if '`, likely a `'tags' in df.columns` check given the
-        #    matching `else` below)
-        # CACHE TAG CATEGORIES: Pre-calculate tag categories for faster filtering
-            print("Pre-calculating cached tag categories...")
-
-            # Helper functions to check for specific tags (simplified for caching)
-            def has_audio_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("audio" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_speech_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("speech" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_music_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("music" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_robot_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("robot" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_bio_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("bio" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_med_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("medic" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_series_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("series" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_science_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("science" in str(tag).lower() and "bigscience" not in str(tag).lower() for tag in tags)
-                return False
-
-            def has_video_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("video" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_image_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("image" in str(tag).lower() for tag in tags)
-                return False
-
-            def has_text_tag(tags):
-                if tags and isinstance(tags, list):
-                    return any("text" in str(tag).lower() for tag in tags)
-                return False
-
-            # Add cached columns for tag categories
-            print("Creating cached tag columns...")
-            df['has_audio'] = df['tags'].apply(has_audio_tag)
-            df['has_speech'] = df['tags'].apply(has_speech_tag)
-            df['has_music'] = df['tags'].apply(has_music_tag)
-            df['has_robot'] = df['tags'].apply(has_robot_tag)
-            df['has_bio'] = df['tags'].apply(has_bio_tag)
-            df['has_med'] = df['tags'].apply(has_med_tag)
-            df['has_series'] = df['tags'].apply(has_series_tag)
-            df['has_science'] = df['tags'].apply(has_science_tag)
-            df['has_video'] = df['tags'].apply(has_video_tag)
-            df['has_image'] = df['tags'].apply(has_image_tag)
-            df['has_text'] = df['tags'].apply(has_text_tag)
-
-            # Create combined category flags for faster filtering
-            df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
-                                     df['pipeline_tag'].str.contains('audio', case=False, na=False) |
-                                     df['pipeline_tag'].str.contains('speech', case=False, na=False))
-            df['is_biomed'] = df['has_bio'] | df['has_med']
-
-            print("Cached tag columns created successfully!")
-        else:
-            # If no tags column, add empty tags and set all category flags to False
-            df['tags'] = [[] for _ in range(len(df))]
-            for col in ['has_audio', 'has_speech', 'has_music', 'has_robot',
-                        'has_bio', 'has_med', 'has_series', 'has_science',
-                        'has_video', 'has_image', 'has_text',
-                        'is_audio_speech', 'is_biomed']:
-                df[col] = False
-
-        # Fill NaN values
-        df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
-
-        # Ensure pipeline_tag is a string
-        if 'pipeline_tag' in df.columns:
-            df['pipeline_tag'] = df['pipeline_tag'].fillna('')
-        else:
-            df['pipeline_tag'] = ''
-
-        # Make sure all required columns exist
-        for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
-            if col not in df.columns:
-                if col in ['downloads', 'downloadsAllTime', 'likes', 'params']:
-                    df[col] = 0
-                elif col == 'pipeline_tag':
-                    df[col] = ''
-                elif col == 'tags':
-                    df[col] = [[] for _ in range(len(df))]
-                elif col == 'id':
-                    df[col] = [f"model_{i}" for i in range(len(df))]
-
-        print(f"Successfully processed {len(df)} models with cached tag and size information")
-        return df, True
-
-    except Exception as e:
-        print(f"Error loading data: {e}")
-        # Return an empty DataFrame and False to indicate loading failure
-        return pd.DataFrame(), False
-# Create Gradio interface
-with gr.Blocks() as demo:
-    models_data = gr.State()
-    loading_complete = gr.State(False)  # Flag to indicate data load completion
-
-    with gr.Row():
-        gr.Markdown("""
-        # HuggingFace Models TreeMap Visualization
-
-        This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
-        Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
-        """)
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            count_by_dropdown = gr.Dropdown(
-                label="Metric",
-                choices=[
-                    ("Downloads (last 30 days)", "downloads"),
-                    ("Downloads (All Time)", "downloadsAllTime"),
-                    ("Likes", "likes")
-                ],
-                value="downloads",
-                info="Select the metric to determine box sizes"
-            )
-
-            filter_choice_radio = gr.Radio(
-                label="Filter Type",
-                choices=["None", "Tag Filter", "Pipeline Filter"],
-                value="None",
-                info="Choose how to filter the models"
-            )
-
-            tag_filter_dropdown = gr.Dropdown(
-                label="Select Tag",
-                choices=list(TAG_FILTER_FUNCS.keys()),
-                value=None,
-                visible=False,
-                info="Filter models by domain/category"
-            )
-
-            pipeline_filter_dropdown = gr.Dropdown(
-                label="Select Pipeline Tag",
-                choices=PIPELINE_TAGS,
-                value=None,
-                visible=False,
-                info="Filter models by specific pipeline"
-            )
-
-            size_filter_dropdown = gr.Dropdown(
-                label="Model Size Filter",
-                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
-                value="None",
-                info="Filter models by their size (using params column)"
-            )
-
-            top_k_slider = gr.Slider(
-                label="Number of Top Organizations",
-                minimum=5,
-                maximum=50,
-                value=25,
-                step=5,
-                info="Number of top organizations to include"
-            )
-
-            skip_orgs_textbox = gr.Textbox(
-                label="Organizations to Skip (comma-separated)",
-                placeholder="e.g., OpenAI, Google",
-                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski"
-            )
-
-            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
-            refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")
-
-        with gr.Column(scale=3):
-            plot_output = gr.Plot()
-            stats_output = gr.Markdown("*Loading data from Hugging Face...*")
-            data_info = gr.Markdown("")
-
-    # Button enablement after data load
-    def enable_plot_button(loaded):
-        return gr.update(interactive=loaded)
-
-    loading_complete.change(
-        fn=enable_plot_button,
-        inputs=[loading_complete],
-        outputs=[generate_plot_button]
-    )
-
-    # Show/hide tag/pipeline dropdown
-    def update_filter_visibility(filter_choice):
-        if filter_choice == "Tag Filter":
-            return gr.update(visible=True), gr.update(visible=False)
-        elif filter_choice == "Pipeline Filter":
-            return gr.update(visible=False), gr.update(visible=True)
-        else:
-            return gr.update(visible=False), gr.update(visible=False)
-
-    filter_choice_radio.change(
-        fn=update_filter_visibility,
-        inputs=[filter_choice_radio],
-        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
-    )
-
-    # Function to handle data load and provide data info
-    def load_and_provide_info():
-        df, success = load_models_data()

-        # … (old lines 707-711 collapsed in the diff view: presumably an
-        #    `if success:` branch opening the info_text f-string)
-        - **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
-        - **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)
-        """
-
-            # Return the data, loading status, and info text
-            return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"
-        else:
-            # Return empty data, failed loading status, and error message
-            return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"
-
-    # Main generate function
-    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
-        if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
-            return None, "Error: Data is still loading. Please wait a moment and try again."
-
-        selected_tag_filter = None
-        selected_pipeline_filter = None
-        selected_size_filter = None
-
-        if filter_choice == "Tag Filter":
-            selected_tag_filter = tag_filter
-        elif filter_choice == "Pipeline Filter":
-            selected_pipeline_filter = pipeline_filter
-
-        if size_filter != "None":
-            selected_size_filter = size_filter
-
-        skip_orgs = []
-        if skip_orgs_text and skip_orgs_text.strip():
-            skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
-
-        treemap_data = make_treemap_data(
-            df=data_df,
-            count_by=count_by,
-            top_k=top_k,
-            tag_filter=selected_tag_filter,
-            pipeline_filter=selected_pipeline_filter,
-            size_filter=selected_size_filter,
-            skip_orgs=skip_orgs
-        )
-
-        title_labels = {
-            "downloads": "Downloads (last 30 days)",
-            "downloadsAllTime": "Downloads (All Time)",
-            "likes": "Likes"
-        }
-        title_text = f"HuggingFace Models - {title_labels.get(count_by, count_by)} by Organization"
-
-        fig = create_treemap(
-            treemap_data=treemap_data,
-            count_by=count_by,
-            title=title_text
-        )
-
-        if treemap_data.empty:
-            stats_md = "No data matches the selected filters."
         else:
-            # … (old lines 769-772 collapsed in the diff view: presumably
-            #    computing total_models and total_value)
-            top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
-
-            # Get top 5 individual models
-            top_5_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)
-
-            # Create statistics section
-            stats_md = f"""
-            ## Statistics
-            - **Total models shown**: {total_models:,}
-            - **Total {count_by}**: {int(total_value):,}
-
-            ## Top Organizations by {count_by.capitalize()}
-
-            | Organization | {count_by.capitalize()} | % of Total |
-            |--------------|-------------:|----------:|
-            """
-
-            # Add top organizations to the table
-            for org, value in top_5_orgs.items():
-                percentage = (value / total_value) * 100
-                stats_md += f"| {org} | {int(value):,} | {percentage:.2f}% |\n"
-
-            # Add the top models table
-            stats_md += f"""
-            ## Top Models by {count_by.capitalize()}
-
-            | Model | {count_by.capitalize()} | % of Total |
-            |-------|-------------:|----------:|
-            """
-
-            # Add top models to the table
-            for _, row in top_5_models.iterrows():
-                model_id = row["id"]
-                value = row[count_by]
-                percentage = (value / total_value) * 100
-                stats_md += f"| {model_id} | {int(value):,} | {percentage:.2f}% |\n"
-
-            # Add note about skipped organizations if any
-            if skip_orgs:
-                stats_md += f"\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
-
-        return fig, stats_md
-
-    # Load data at startup
     demo.load(
-        fn=…,          # truncated in the diff view
-        inputs=[],
-        outputs=[…]    # truncated in the diff view
     )
-
-    # Refresh data when button is clicked
     refresh_data_button.click(
-        fn=…,          # truncated in the diff view
-        inputs=[],
-        outputs=[…]    # truncated in the diff view
     )

     generate_plot_button.click(
-        fn=…,          # truncated in the diff view
-        inputs=[…,     # truncated in the diff view
-            tag_filter_dropdown,
-            pipeline_filter_dropdown,
-            size_filter_dropdown,
-            top_k_slider,
-            skip_orgs_textbox,
-            models_data
-        ],
-        outputs=[plot_output, stats_output]
     )

 if __name__ == "__main__":
-    …  # old line 846 truncated in the diff view (presumably demo.launch())
+# --- START OF FILE app.py ---
+
 import json
 import gradio as gr
 import pandas as pd
 import plotly.express as px
 import os
 import numpy as np
 import duckdb
+from tqdm.auto import tqdm  # Standard tqdm for console; gr.Progress will track it
+import time
+import ast  # For safely evaluating string representations of lists/dicts

+# --- Constants ---
 MODEL_SIZE_RANGES = {
+    "Small (<1GB)": (0, 1), "Medium (1-5GB)": (1, 5), "Large (5-20GB)": (5, 20),
+    "X-Large (20-50GB)": (20, 50), "XX-Large (>50GB)": (50, float('inf'))
 }
+PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
+HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'  # Added for completeness within app.py context

+TAG_FILTER_CHOICES = [
+    "Audio & Speech", "Time series", "Robotics", "Music", "Video", "Images",
+    "Text", "Biomedical", "Sciences"
+]
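
Worth noting how HF_PARQUET_URL is used further down: the refresh path reads the remote parquet straight through DuckDB. A minimal standalone sketch of that fetch follows; the column list and LIMIT are illustrative additions to keep the pull small, whereas the app itself does SELECT *:

```python
import duckdb

HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'

# DuckDB streams the remote parquet over HTTP; .df() converts the result to pandas.
sample = duckdb.sql(
    f"SELECT id, downloads, likes FROM read_parquet('{HF_PARQUET_URL}') LIMIT 5"
).df()
print(sample)
```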
|
+PIPELINE_TAGS = [
+    'text-generation', 'text-to-image', 'text-classification', 'text2text-generation',
+    'audio-to-audio', 'feature-extraction', 'image-classification', 'translation',
+    'reinforcement-learning', 'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
+    'image-text-to-text', 'token-classification', 'sentence-similarity', 'question-answering',
+    'image-feature-extraction', 'summarization', 'zero-shot-image-classification',
+    'object-detection', 'image-segmentation', 'image-to-image', 'image-to-text',
+    'audio-classification', 'visual-question-answering', 'text-to-video',
+    'zero-shot-classification', 'depth-estimation', 'text-ranking', 'image-to-video',
+    'multiple-choice', 'unconditional-image-generation', 'video-classification',
+    'text-to-audio', 'time-series-forecasting', 'any-to-any', 'video-text-to-text',
+    'table-question-answering',
+]
|
+def extract_model_size(safetensors_data):
+    try:
+        if pd.isna(safetensors_data): return 0.0
+        data_to_parse = safetensors_data
+        if isinstance(safetensors_data, str):
+            try:
+                if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
+                   (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
+                    data_to_parse = ast.literal_eval(safetensors_data)
+                else: data_to_parse = json.loads(safetensors_data)
+            except: return 0.0
+        if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
+            try:
+                total_bytes_val = data_to_parse['total']
+                size_bytes = float(total_bytes_val)
+                return size_bytes / (1024 * 1024 * 1024)
+            except (ValueError, TypeError): pass
+        return 0.0
+    except: return 0.0
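
A quick sanity check of extract_model_size, assuming the function above is in scope. The payloads below are illustrative but mirror the shapes it handles (dict, stringified dict, missing value), with 'total' holding a byte count:

```python
print(extract_model_size({'total': 7_000_000_000}))   # ≈ 6.52 (GB)
print(extract_model_size("{'total': 500000000}"))     # string parsed via ast.literal_eval, ≈ 0.47
print(extract_model_size(None))                       # 0.0 — NaN/None-safe
```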
|
+def extract_org_from_id(model_id):
+    if pd.isna(model_id): return "unaffiliated"
+    model_id_str = str(model_id)
+    return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
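
And the org extractor in action (illustrative IDs):

```python
print(extract_org_from_id("google/flan-t5-base"))  # 'google'
print(extract_org_from_id("gpt2"))                 # 'unaffiliated' — no org prefix
```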
|
+def process_tags_for_series(series_of_tags_values):
+    processed_tags_accumulator = []

+    for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
+        temp_processed_list_for_row = []
+        current_value_for_error_msg = str(tags_value_from_series)[:200]  # Truncate for long error messages

+        try:
+            # Order of checks is important!
+            # 1. Handle explicit Python lists first
+            if isinstance(tags_value_from_series, list):
+                current_tags_in_list = []
+                for idx_tag, tag_item in enumerate(tags_value_from_series):
+                    try:
+                        # Ensure item is not NaN before string conversion if it might be a float NaN in a list
+                        if pd.isna(tag_item): continue
+                        str_tag = str(tag_item)
+                        stripped_tag = str_tag.strip()
+                        if stripped_tag:
+                            current_tags_in_list.append(stripped_tag)
+                    except Exception as e_inner_list_proc:
+                        print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
+                temp_processed_list_for_row = current_tags_in_list
+
+            # 2. Handle NumPy arrays
+            elif isinstance(tags_value_from_series, np.ndarray):
+                # Convert to list, then process elements, handling potential NaNs within the array
+                current_tags_in_list = []
+                for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()):  # .tolist() is crucial
+                    try:
+                        if pd.isna(tag_item): continue  # Check for NaN after converting to Python type
+                        str_tag = str(tag_item)
+                        stripped_tag = str_tag.strip()
+                        if stripped_tag:
+                            current_tags_in_list.append(stripped_tag)
+                    except Exception as e_inner_array_proc:
+                        print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
+                temp_processed_list_for_row = current_tags_in_list
+
+            # 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
+            elif tags_value_from_series is None or pd.isna(tags_value_from_series):  # Now pd.isna is safe for scalars
+                temp_processed_list_for_row = []
+
+            # 4. Handle strings (could be JSON-like, list-like, or comma-separated)
+            elif isinstance(tags_value_from_series, str):
+                processed_str_tags = []
+                # Attempt ast.literal_eval for strings that look like lists/tuples
+                if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
+                   (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
+                    try:
+                        evaluated_tags = ast.literal_eval(tags_value_from_series)
+                        if isinstance(evaluated_tags, (list, tuple)):  # Check if eval result is a list/tuple
+                            # Recursively process this evaluated list/tuple, as its elements could be complex
+                            # For simplicity here, assume elements are simple strings after eval
+                            current_eval_list = []
+                            for tag_item in evaluated_tags:
+                                if pd.isna(tag_item): continue
+                                str_tag = str(tag_item).strip()
+                                if str_tag: current_eval_list.append(str_tag)
+                            processed_str_tags = current_eval_list
+                    except (ValueError, SyntaxError):
+                        pass  # If ast.literal_eval fails, let it fall to JSON or comma split
+
+                # If ast.literal_eval didn't populate, try JSON
+                if not processed_str_tags:
+                    try:
+                        json_tags = json.loads(tags_value_from_series)
+                        if isinstance(json_tags, list):
+                            # Similar to above, assume elements are simple strings after JSON parsing
+                            current_json_list = []
+                            for tag_item in json_tags:
+                                if pd.isna(tag_item): continue
+                                str_tag = str(tag_item).strip()
+                                if str_tag: current_json_list.append(str_tag)
+                            processed_str_tags = current_json_list
+                    except json.JSONDecodeError:
+                        # If not a valid JSON list, fall back to comma splitting as the final string strategy
+                        processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
+                    except Exception as e_json_other:
+                        print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
+                        processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]  # Fallback
+
+                temp_processed_list_for_row = processed_str_tags
+
+            # 5. Fallback for other scalar types (e.g., int, float that are not NaN)
+            else:
+                # This path is for non-list, non-ndarray, non-None/NaN, non-string types.
+                # Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
+                if pd.isna(tags_value_from_series):  # Catch any remaining NaNs like float('nan')
+                    temp_processed_list_for_row = []
+                else:
+                    str_val = str(tags_value_from_series).strip()
+                    temp_processed_list_for_row = [str_val] if str_val else []
+
+            processed_tags_accumulator.append(temp_processed_list_for_row)

+        except Exception as e_outer_tag_proc:
+            print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
+            processed_tags_accumulator.append([])
+
+    return processed_tags_accumulator
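
A quick check of the standardizer across the input shapes it handles (list, NumPy array, stringified list, comma-separated string, missing value); the sample values are made up:

```python
import numpy as np
import pandas as pd

mixed = pd.Series([
    ['nlp', ' llm '],
    np.array(['vision', None], dtype=object),
    "['audio', 'speech']",
    "robotics, time-series",
    None,
])
print(process_tags_for_series(mixed))
# [['nlp', 'llm'], ['vision'], ['audio', 'speech'], ['robotics', 'time-series'], []]
```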
+
+def load_models_data(force_refresh=False, tqdm_cls=None):
+    if tqdm_cls is None: tqdm_cls = tqdm
+    overall_start_time = time.time()
+    print(f"Gradio load_models_data called with force_refresh={force_refresh}")
+
+    expected_cols_in_processed_parquet = [
+        'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
+        'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
+        'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
+        'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
+        'data_download_timestamp'
+    ]
+
+    if not force_refresh and os.path.exists(PROCESSED_PARQUET_FILE_PATH):
+        print(f"Attempting to load pre-processed data from: {PROCESSED_PARQUET_FILE_PATH}")
+        try:
+            df = pd.read_parquet(PROCESSED_PARQUET_FILE_PATH)
+            elapsed = time.time() - overall_start_time
+            missing_cols = [col for col in expected_cols_in_processed_parquet if col not in df.columns]
+            if missing_cols:
+                raise ValueError(f"Pre-processed Parquet is missing columns: {missing_cols}. Please run preprocessor or refresh data in app.")
+
+            # --- Diagnostic for 'has_robot' after loading parquet ---
+            if 'has_robot' in df.columns:
+                robot_count_parquet = df['has_robot'].sum()
+                print(f"DIAGNOSTIC (App - Parquet Load): 'has_robot' column found. Number of True values: {robot_count_parquet}")
+                if 0 < robot_count_parquet < 10:
+                    print(f"Sample 'has_robot' models (from parquet): {df[df['has_robot']]['id'].head().tolist()}")
+            else:
+                print("DIAGNOSTIC (App - Parquet Load): 'has_robot' column NOT FOUND.")
+            # --- End Diagnostic ---
+
+            msg = f"Successfully loaded pre-processed data in {elapsed:.2f}s. Shape: {df.shape}"
+            print(msg)
+            return df, True, msg
+        except Exception as e:
+            print(f"Could not load pre-processed Parquet: {e}. ")
+            if force_refresh: print("Proceeding to fetch fresh data as force_refresh=True.")
+            else:
+                err_msg = (f"Pre-processed data could not be loaded: {e}. "
+                           "Please use 'Refresh Data from Hugging Face' button.")
+                return pd.DataFrame(), False, err_msg
+
+    df_raw = None
+    raw_data_source_msg = ""
+    if force_refresh:
+        print("force_refresh=True (Gradio). Fetching fresh data...")
+        fetch_start = time.time()
+        try:
+            query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"  # Ensure HF_PARQUET_URL is defined
+            df_raw = duckdb.sql(query).df()
+            if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
+            raw_data_source_msg = f"Fetched by Gradio in {time.time() - fetch_start:.2f}s. Rows: {len(df_raw)}"
+            print(raw_data_source_msg)
+        except Exception as e_hf:
+            return pd.DataFrame(), False, f"Fatal error fetching from Hugging Face (Gradio): {e_hf}"
+    else:
+        err_msg = (f"Pre-processed data '{PROCESSED_PARQUET_FILE_PATH}' not found/invalid. "
+                   "Run preprocessor or use 'Refresh Data' button.")
+        return pd.DataFrame(), False, err_msg
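
The fast path above is just pd.read_parquet on a cache file that preprocess.py (added in this commit) writes ahead of time. A tiny round-trip sketch of that mechanism; `demo_cache.parquet` is a hypothetical file name used here so the app's real models_processed.parquet isn't touched, and a parquet engine (pyarrow or fastparquet) must be installed:

```python
import pandas as pd

demo_path = "demo_cache.parquet"  # hypothetical file, not the app's real cache
pd.DataFrame({"id": ["orgA/m1"], "downloads": [1.0]}).to_parquet(demo_path, index=False)
print(pd.read_parquet(demo_path).shape)  # (1, 2)
```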
+
+    print(f"Initiating processing for data newly fetched by Gradio. {raw_data_source_msg}")
+    df = pd.DataFrame()
+    proc_start = time.time()

+    core_cols = {'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
+                 'pipeline_tag': str, 'tags': object, 'safetensors': object}
+    for col, dtype in core_cols.items():
+        if col in df_raw.columns:
+            df[col] = df_raw[col]
+            if dtype == float: df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
+            elif dtype == str: df[col] = df[col].astype(str).fillna('')
+        else:
+            if col in ['downloads', 'downloadsAllTime', 'likes']: df[col] = 0.0
+            elif col == 'pipeline_tag': df[col] = ''
+            elif col == 'tags': df[col] = pd.Series([[] for _ in range(len(df_raw))])
+            elif col == 'safetensors': df[col] = None
+            elif col == 'id': return pd.DataFrame(), False, "Critical: 'id' column missing."

+    output_filesize_col_name = 'params'
+    if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
+        df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
+    elif 'safetensors' in df.columns:
+        safetensors_iter = df['safetensors']
+        if tqdm_cls != tqdm:
+            safetensors_iter = tqdm_cls(df['safetensors'], desc="Extracting model sizes (GB)")
+        df[output_filesize_col_name] = [extract_model_size(s) for s in safetensors_iter]
+        df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
+    else:
+        df[output_filesize_col_name] = 0.0
+
+    def get_size_category_gradio(size_gb_val):
+        try: numeric_size_gb = float(size_gb_val)
+        except (ValueError, TypeError): numeric_size_gb = 0.0
+        if pd.isna(numeric_size_gb): numeric_size_gb = 0.0
+        if 0 <= numeric_size_gb < 1: return "Small (<1GB)"
+        elif 1 <= numeric_size_gb < 5: return "Medium (1-5GB)"
+        elif 5 <= numeric_size_gb < 20: return "Large (5-20GB)"
+        elif 20 <= numeric_size_gb < 50: return "X-Large (20-50GB)"
+        elif numeric_size_gb >= 50: return "XX-Large (>50GB)"
+        else: return "Small (<1GB)"
+    df['size_category'] = df[output_filesize_col_name].apply(get_size_category_gradio)
+
+    df['tags'] = process_tags_for_series(df['tags'])
+    df['temp_tags_joined'] = df['tags'].apply(
+        lambda tl: '~~~'.join(str(t).lower() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
+    )
+    tag_map = {
+        'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
+        'has_robot': ['robot', 'robotics'],
+        'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
+        'has_series': ['series', 'time-series', 'timeseries'],
+        'has_video': ['video'], 'has_image': ['image', 'vision'],
+        'has_text': ['text', 'nlp', 'llm']
+    }
+    for col, kws in tag_map.items():
+        pattern = '|'.join(kws)
+        df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
+    df['has_science'] = (
+        df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
+        ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
+    )
+    del df['temp_tags_joined']
+    df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
+                             df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
+    df['is_biomed'] = df['has_bio'] | df['has_med']
+    df['organization'] = df['id'].apply(extract_org_from_id)
+
+    if 'safetensors' in df.columns and \
+       not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
+        df = df.drop(columns=['safetensors'], errors='ignore')

+    # --- Diagnostic for 'has_robot' after app-side processing (force_refresh path) ---
+    if force_refresh and 'has_robot' in df.columns:
+        robot_count_app_proc = df['has_robot'].sum()
+        print(f"DIAGNOSTIC (App - Force Refresh Processing): 'has_robot' column processed. Number of True values: {robot_count_app_proc}")
+        if 0 < robot_count_app_proc < 10:
+            print(f"Sample 'has_robot' models (App processed): {df[df['has_robot']]['id'].head().tolist()}")
+    # --- End Diagnostic ---
+
+    print(f"Data processing by Gradio completed in {time.time() - proc_start:.2f}s.")
+
+    total_elapsed = time.time() - overall_start_time
+    final_msg = f"{raw_data_source_msg}. Processing by Gradio took {time.time() - proc_start:.2f}s. Total: {total_elapsed:.2f}s. Shape: {df.shape}"
+    print(final_msg)
+    return df, True, final_msg
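
As an aside, get_size_category_gradio bins sizes row by row via .apply; an equivalent vectorized formulation with pd.cut is sketched below. The bin edges mirror MODEL_SIZE_RANGES, and this is only an alternative, not what the app does:

```python
import pandas as pd

bins = [-float('inf'), 1, 5, 20, 50, float('inf')]
labels = ["Small (<1GB)", "Medium (1-5GB)", "Large (5-20GB)",
          "X-Large (20-50GB)", "XX-Large (>50GB)"]
sizes = pd.Series([0.3, 2.0, 7.5, 42.0, 80.0], name="params")
# right=False gives half-open [a, b) intervals, matching the if/elif chain;
# negative values land in the first bin, like the .apply fallback.
print(pd.cut(sizes, bins=bins, labels=labels, right=False).tolist())
# ['Small (<1GB)', 'Medium (1-5GB)', 'Large (5-20GB)', 'X-Large (20-50GB)', 'XX-Large (>50GB)']
```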
|
 def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
+    if df is None or df.empty: return pd.DataFrame()
     filtered_df = df.copy()
+    col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
+                "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
+                "Video": "has_video", "Images": "has_image", "Text": "has_text"}

+    # --- Diagnostic within make_treemap_data ---
+    if 'has_robot' in filtered_df.columns:
+        initial_robot_count = filtered_df['has_robot'].sum()
+        print(f"DIAGNOSTIC (make_treemap_data entry): Input df has {initial_robot_count} 'has_robot' models.")
+    else:
+        print("DIAGNOSTIC (make_treemap_data entry): 'has_robot' column NOT in input df.")
+    # --- End Diagnostic ---
+
+    if tag_filter and tag_filter in col_map:
+        target_col = col_map[tag_filter]
+        if target_col in filtered_df.columns:
+            # --- Diagnostic for specific 'Robotics' filter application ---
+            if tag_filter == "Robotics":
+                count_before_robot_filter = filtered_df[target_col].sum()
+                print(f"DIAGNOSTIC (make_treemap_data): Applying 'Robotics' filter. Models with '{target_col}'=True before this filter step: {count_before_robot_filter}")
+            # --- End Diagnostic ---
+            filtered_df = filtered_df[filtered_df[target_col]]
+            if tag_filter == "Robotics":
+                print(f"DIAGNOSTIC (make_treemap_data): After 'Robotics' filter ({target_col}), df rows: {len(filtered_df)}")
+        else:
+            print(f"Warning: Tag filter column '{col_map[tag_filter]}' not found in DataFrame.")
     if pipeline_filter:
+        if "pipeline_tag" in filtered_df.columns:
+            filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
+        else:
+            print(f"Warning: 'pipeline_tag' column not found for filtering.")
+    if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES.keys():
+        if 'size_category' in filtered_df.columns:
+            filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
+        else:
+            print("Warning: 'size_category' column not found for filtering.")
     if skip_orgs and len(skip_orgs) > 0:
+        if "organization" in filtered_df.columns:
+            filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
+        else:
+            print("Warning: 'organization' column not found for filtering.")
+    if filtered_df.empty: return pd.DataFrame()
+    if count_by not in filtered_df.columns or not pd.api.types.is_numeric_dtype(filtered_df[count_by]):
+        filtered_df[count_by] = pd.to_numeric(filtered_df.get(count_by), errors="coerce").fillna(0.0)
+    org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(top_k, keep='first')
+    top_orgs_list = org_totals.index.tolist()
+    treemap_data = filtered_df[filtered_df["organization"].isin(top_orgs_list)][["id", "organization", count_by]].copy()
     treemap_data["root"] = "models"
+    treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0.0)
     return treemap_data
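
A toy invocation of make_treemap_data for orientation; the frame is hand-built with the column names load_models_data produces, and the numbers are made up:

```python
import pandas as pd

toy = pd.DataFrame({
    "id": ["orgA/m1", "orgA/m2", "orgB/m3"],
    "organization": ["orgA", "orgA", "orgB"],
    "downloads": [100.0, 50.0, 10.0],
    "has_text": [True, True, False],
})
out = make_treemap_data(toy, count_by="downloads", top_k=2, tag_filter="Text")
print(out[["organization", "id", "downloads"]])
# Only orgA's models survive the "Text" filter; a "root" column is added for the treemap path.
```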
|
 def create_treemap(treemap_data, count_by, title=None):
     if treemap_data.empty:
+        fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1])
+        fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
         return fig
     fig = px.treemap(
+        treemap_data, path=["root", "organization", "id"], values=count_by,
         title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
         color_discrete_sequence=px.colors.qualitative.Plotly
     )
+    fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
+    fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
     return fig
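
And feeding that toy frame through create_treemap end to end:

```python
fig = create_treemap(out, count_by="downloads", title="Toy treemap")
fig.write_html("treemap.html")  # or fig.show() in an interactive session
```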
|
+with gr.Blocks(title="HuggingFace Model Explorer", fill_width=True) as demo:
+    models_data_state = gr.State(pd.DataFrame())
+    loading_complete_state = gr.State(False)
+
+    with gr.Row(): gr.Markdown("# HuggingFace Models TreeMap Visualization")
+    with gr.Row():
+        with gr.Column(scale=1):
+            count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
+            filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
+            tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
+            pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
+            size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
+            top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
+            skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski")
+            generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False)
+            refresh_data_button = gr.Button(value="Refresh Data from Hugging Face", variant="secondary")
+        with gr.Column(scale=3):
+            plot_output = gr.Plot()
+            status_message_md = gr.Markdown("Initializing...")
+            data_info_md = gr.Markdown("")
+
+    def _update_button_interactivity(is_loaded_flag):
+        return gr.update(interactive=is_loaded_flag)
+    loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)
+
+    def _toggle_filters_visibility(choice):
+        return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
+    filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])
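
The controller that follows leans on Gradio's progress tracking: ui_load_data_controller takes a gr.Progress(track_tqdm=True) argument, which mirrors any tqdm loop (like the one in process_tags_for_series) in the web UI. The standalone pattern, reduced to a sketch (slow_task is a made-up example handler):

```python
import time
import gradio as gr

def slow_task(progress=gr.Progress(track_tqdm=True)):
    # Any tqdm-style iteration inside this handler is reflected in the Gradio UI.
    for _ in progress.tqdm(range(3), desc="Working"):
        time.sleep(0.1)
    return "done"
```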
+
|
| 412 |
+
def ui_load_data_controller(force_refresh_ui_trigger=False, progress=gr.Progress(track_tqdm=True)):
|
| 413 |
+
print(f"ui_load_data_controller called with force_refresh_ui_trigger={force_refresh_ui_trigger}")
|
| 414 |
+
status_msg_ui = "Loading data..."
|
| 415 |
+
data_info_text = ""
|
| 416 |
+
current_df = pd.DataFrame()
|
| 417 |
+
load_success_flag = False
|
| 418 |
+
data_as_of_date_display = "N/A"
|
| 419 |
try:
|
| 420 |
+
current_df, load_success_flag, status_msg_from_load = load_models_data(
|
| 421 |
+
force_refresh=force_refresh_ui_trigger, tqdm_cls=progress.tqdm
|
| 422 |
+
)
|
| 423 |
+
if load_success_flag:
|
| 424 |
+
if force_refresh_ui_trigger:
|
| 425 |
+
data_as_of_date_display = pd.Timestamp.now(tz='UTC').strftime('%B %d, %Y, %H:%M:%S %Z')
|
| 426 |
+
elif 'data_download_timestamp' in current_df.columns and not current_df.empty and pd.notna(current_df['data_download_timestamp'].iloc[0]):
|
| 427 |
+
timestamp_from_parquet = pd.to_datetime(current_df['data_download_timestamp'].iloc[0])
|
| 428 |
+
if timestamp_from_parquet.tzinfo is None:
|
| 429 |
+
timestamp_from_parquet = timestamp_from_parquet.tz_localize('UTC')
|
| 430 |
+
data_as_of_date_display = timestamp_from_parquet.strftime('%B %d, %Y, %H:%M:%S %Z')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
else:
|
| 432 |
+
data_as_of_date_display = "Pre-processed (date unavailable)"
|

                size_dist_lines = []
                if 'size_category' in current_df.columns:
                    for cat in MODEL_SIZE_RANGES.keys():
                        count = (current_df['size_category'] == cat).sum()
                        size_dist_lines.append(f" - {cat}: {count:,} models")
                else:
                    size_dist_lines.append(" - Size category information not available.")
                size_dist = "\n".join(size_dist_lines)
                data_info_text = (f"### Data Information\n"
                                  f"- Overall Status: {status_msg_from_load}\n"
                                  f"- Total models loaded: {len(current_df):,}\n"
                                  f"- Data as of: {data_as_of_date_display}\n"
                                  f"- Size categories:\n{size_dist}")

                # --- MODIFICATION: Add 'has_robot' count to UI data_info_text ---
                # if not current_df.empty and 'has_robot' in current_df.columns:
                #     robot_true_count = current_df['has_robot'].sum()
                #     data_info_text += f"\n- **Models flagged 'has_robot'**: {robot_true_count}"
                #     if 0 < robot_true_count <= 10:  # If a few are found, list some IDs
                #         sample_robot_ids = current_df[current_df['has_robot']]['id'].head(5).tolist()
                #         data_info_text += f"\n  - Sample 'has_robot' model IDs: `{', '.join(sample_robot_ids)}`"
                # elif not current_df.empty:
                #     data_info_text += "\n- **Models flagged 'has_robot'**: 'has_robot' column not found in loaded data."
                # --- END MODIFICATION ---

                status_msg_ui = "Data loaded successfully. Ready to generate plot."
            else:
                data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
                status_msg_ui = status_msg_from_load
        except Exception as e:
            status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
            data_info_text = f"### Critical Error\n- {status_msg_ui}"
            print(f"Critical error in ui_load_data_controller: {e}")
            load_success_flag = False
        return current_df, load_success_flag, data_info_text, status_msg_ui

    def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
                                    size_choice, k_orgs, skip_orgs_input, df_current_models):
        if df_current_models is None or df_current_models.empty:
            empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
            error_msg = "Model data is not loaded or is empty. Please load or refresh data first."
            gr.Warning(error_msg)
            return empty_fig, error_msg
        tag_to_use = tag_choice if filter_type == "Tag Filter" else None
        pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
        size_to_use = size_choice if size_choice != "None" else None
        orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []

        # --- Diagnostic before calling make_treemap_data ---
        if 'has_robot' in df_current_models.columns:
            robot_count_before_treemap = df_current_models['has_robot'].sum()
            print(f"DIAGNOSTIC (ui_generate_plot_controller): df_current_models entering make_treemap_data has {robot_count_before_treemap} 'has_robot' models.")
        # --- End Diagnostic ---

        treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)

        title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
        chart_title = f"HuggingFace Models - {title_labels.get(metric_choice, metric_choice)} by Organization"
        plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
        if treemap_df.empty:
            plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
        else:
            total_items_in_plot = len(treemap_df['id'].unique())
            total_value_in_plot = treemap_df[metric_choice].sum()
            plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_choice}**: {int(total_value_in_plot):,}")
        return plotly_fig, plot_stats_md

    demo.load(
        fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress),
        inputs=[],
        outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
    )

    refresh_data_button.click(
        fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress),
        inputs=[],
        outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
    )
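    # Editor's note: demo.load and refresh_data_button.click reuse the same
    # controller and differ only in force_refresh_ui_trigger; the controller's
    # 4-tuple return maps positionally onto
    # [models_data_state, loading_complete_state, data_info_md, status_message_md].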

    generate_plot_button.click(
        fn=ui_generate_plot_controller,
        inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
                size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
        outputs=[plot_output, status_message_md]
    )
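    # Editor's note: the inputs list is passed positionally, so its order must
    # match ui_generate_plot_controller's signature; models_data_state supplies
    # the cached DataFrame as the final df_current_models argument.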

if __name__ == "__main__":
    if not os.path.exists(PROCESSED_PARQUET_FILE_PATH):
        print(f"WARNING: Pre-processed data file '{PROCESSED_PARQUET_FILE_PATH}' not found.")
        print("It is highly recommended to run the preprocessing script (e.g., preprocess.py) first.")
    else:
        print(f"Found pre-processed data file: '{PROCESSED_PARQUET_FILE_PATH}'.")
    demo.launch()

# --- END OF FILE app.py ---

models_processed.parquet
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:998afad6c0c4c64f9e98efd8609d1cbab1dd2ac281b9c2e023878ad436c2fbde
size 96033487

preprocess.py
ADDED
@@ -0,0 +1,371 @@
# --- START OF FILE preprocess.py ---

import pandas as pd
import numpy as np
import json
import ast
from tqdm.auto import tqdm
import time
import os
import duckdb
import re  # used for the manual regex check in the debug block

# --- Constants ---
PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'

MODEL_SIZE_RANGES = {
    "Small (<1GB)": (0, 1),
    "Medium (1-5GB)": (1, 5),
    "Large (5-20GB)": (5, 20),
    "X-Large (20-50GB)": (20, 50),
    "XX-Large (>50GB)": (50, float('inf'))
}

# --- Debugging Constant ---
# <<<<<<< SET THE MODEL ID YOU WANT TO DEBUG HERE >>>>>>>
MODEL_ID_TO_DEBUG = "openvla/openvla-7b"
# Example: MODEL_ID_TO_DEBUG = "openai-community/gpt2"
# If you don't have a specific ID, the debug block will simply report that it was not found.

# --- Utility Functions (extract_model_file_size_gb, extract_org_from_id, process_tags_for_series, get_file_size_category) ---
def extract_model_file_size_gb(safetensors_data):
    try:
        if pd.isna(safetensors_data): return 0.0
        data_to_parse = safetensors_data
        if isinstance(safetensors_data, str):
            try:
                if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
                   (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
                    data_to_parse = ast.literal_eval(safetensors_data)
                else: data_to_parse = json.loads(safetensors_data)
            except Exception: return 0.0
        if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
            total_bytes_val = data_to_parse['total']
            try:
                size_bytes = float(total_bytes_val)
                return size_bytes / (1024 * 1024 * 1024)
            except (ValueError, TypeError): return 0.0
        return 0.0
    except Exception: return 0.0
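# Editor's sketch (assumed payload shape, mirroring the parser above): the hub's
# 'safetensors' field carries a total byte count under 'total', e.g.
#   extract_model_file_size_gb({"total": 2147483648})    # dict        -> 2.0 (GB)
#   extract_model_file_size_gb('{"total": 1073741824}')  # JSON string -> 1.0
#   extract_model_file_size_gb("not json")               # unparseable -> 0.0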

def extract_org_from_id(model_id):
    if pd.isna(model_id): return "unaffiliated"
    model_id_str = str(model_id)
    return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
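# e.g. extract_org_from_id("openvla/openvla-7b") -> "openvla";
# ids without a "/" (and NaN ids) map to "unaffiliated".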

def process_tags_for_series(series_of_tags_values):
    processed_tags_accumulator = []

    for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
        temp_processed_list_for_row = []
        current_value_for_error_msg = str(tags_value_from_series)[:200]  # truncate long values for error messages

        try:
            # Order of checks is important!
            # 1. Handle explicit Python lists first
            if isinstance(tags_value_from_series, list):
                current_tags_in_list = []
                for tag_item in tags_value_from_series:
                    try:
                        # Skip NaN items before string conversion (a list may contain float NaN)
                        if pd.isna(tag_item): continue
                        str_tag = str(tag_item)
                        stripped_tag = str_tag.strip()
                        if stripped_tag:
                            current_tags_in_list.append(stripped_tag)
                    except Exception as e_inner_list_proc:
                        print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
                temp_processed_list_for_row = current_tags_in_list

            # 2. Handle NumPy arrays
            elif isinstance(tags_value_from_series, np.ndarray):
                # Convert to a list, then process elements, handling potential NaNs within the array
                current_tags_in_list = []
                for tag_item in tags_value_from_series.tolist():  # .tolist() is crucial
                    try:
                        if pd.isna(tag_item): continue  # check for NaN after converting to a Python type
                        str_tag = str(tag_item)
                        stripped_tag = str_tag.strip()
                        if stripped_tag:
                            current_tags_in_list.append(stripped_tag)
                    except Exception as e_inner_array_proc:
                        print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
                temp_processed_list_for_row = current_tags_in_list

            # 3. Handle plain None or pd.NA after lists and arrays (pd.isna is only safe on scalars here)
            elif tags_value_from_series is None or pd.isna(tags_value_from_series):
                temp_processed_list_for_row = []

            # 4. Handle strings (could be JSON-like, list-like, or comma-separated)
            elif isinstance(tags_value_from_series, str):
                processed_str_tags = []
                # Attempt ast.literal_eval for strings that look like lists/tuples
                if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
                   (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
                    try:
                        evaluated_tags = ast.literal_eval(tags_value_from_series)
                        if isinstance(evaluated_tags, (list, tuple)):
                            # Elements are assumed to be simple strings after eval
                            current_eval_list = []
                            for tag_item in evaluated_tags:
                                if pd.isna(tag_item): continue
                                str_tag = str(tag_item).strip()
                                if str_tag: current_eval_list.append(str_tag)
                            processed_str_tags = current_eval_list
                    except (ValueError, SyntaxError):
                        pass  # if ast.literal_eval fails, fall through to JSON or comma splitting

                # If ast.literal_eval didn't populate anything, try JSON
                if not processed_str_tags:
                    try:
                        json_tags = json.loads(tags_value_from_series)
                        if isinstance(json_tags, list):
                            # Similarly, assume elements are simple strings after JSON parsing
                            current_json_list = []
                            for tag_item in json_tags:
                                if pd.isna(tag_item): continue
                                str_tag = str(tag_item).strip()
                                if str_tag: current_json_list.append(str_tag)
                            processed_str_tags = current_json_list
                    except json.JSONDecodeError:
                        # Not a valid JSON list: fall back to comma splitting as the final string strategy
                        processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
                    except Exception as e_json_other:
                        print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
                        processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]  # fallback

                temp_processed_list_for_row = processed_str_tags

            # 5. Fallback for other scalar types (e.g., int or float values that are not NaN)
            else:
                if pd.isna(tags_value_from_series):  # catch any remaining NaNs such as float('nan')
                    temp_processed_list_for_row = []
                else:
                    str_val = str(tags_value_from_series).strip()
                    temp_processed_list_for_row = [str_val] if str_val else []

            processed_tags_accumulator.append(temp_processed_list_for_row)

        except Exception as e_outer_tag_proc:
            print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
            processed_tags_accumulator.append([])

    return processed_tags_accumulator
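# Editor's sketch of the standardizer's behavior (derived from the branches above):
#   process_tags_for_series(pd.Series([
#       ["robotics", " vla "],      # Python list    -> ['robotics', 'vla']
#       '["audio", "speech"]',      # list-like str  -> ['audio', 'speech']
#       "text, llm",                # comma string   -> ['text', 'llm']
#       None,                       # missing value  -> []
#   ]))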

def get_file_size_category(file_size_gb_val):
    try:
        numeric_file_size_gb = float(file_size_gb_val)
        if pd.isna(numeric_file_size_gb): numeric_file_size_gb = 0.0
    except (ValueError, TypeError): numeric_file_size_gb = 0.0
    if 0 <= numeric_file_size_gb < 1: return "Small (<1GB)"
    elif 1 <= numeric_file_size_gb < 5: return "Medium (1-5GB)"
    elif 5 <= numeric_file_size_gb < 20: return "Large (5-20GB)"
    elif 20 <= numeric_file_size_gb < 50: return "X-Large (20-50GB)"
    elif numeric_file_size_gb >= 50: return "XX-Large (>50GB)"
    else: return "Small (<1GB)"
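# Boundaries are half-open and mirror MODEL_SIZE_RANGES, e.g.
#   get_file_size_category(0.9)  -> "Small (<1GB)"
#   get_file_size_category(5.0)  -> "Large (5-20GB)"
#   get_file_size_category(-2)   -> "Small (<1GB)"  (catch-all else)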


def main_preprocessor():
    print(f"Starting pre-processing script. Output: '{PROCESSED_PARQUET_FILE_PATH}'.")
    overall_start_time = time.time()

    print(f"Fetching fresh data from Hugging Face: {HF_PARQUET_URL}")
    try:
        fetch_start_time = time.time()
        query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
        df_raw = duckdb.sql(query).df()
        data_download_timestamp = pd.Timestamp.now(tz='UTC')

        if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
        if 'id' not in df_raw.columns: raise ValueError("Fetched data must contain 'id' column.")

        print(f"Fetched data in {time.time() - fetch_start_time:.2f}s. Rows: {len(df_raw)}. Downloaded at: {data_download_timestamp.strftime('%Y-%m-%d %H:%M:%S %Z')}")
    except Exception as e_fetch:
        print(f"ERROR: Could not fetch data from Hugging Face: {e_fetch}.")
        return
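    # Editor's note: the query streams the hub-stats parquet through DuckDB
    # (read_parquet over an https:// URL relies on DuckDB's httpfs support) and
    # materializes the result as a pandas DataFrame via .df().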

    df = pd.DataFrame()
    print("Processing raw data...")
    proc_start = time.time()

    expected_cols_setup = {
        'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
        'pipeline_tag': str, 'tags': object, 'safetensors': object
    }
    for col_name, target_dtype in expected_cols_setup.items():
        if col_name in df_raw.columns:
            df[col_name] = df_raw[col_name]
            if target_dtype == float: df[col_name] = pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
            elif target_dtype == str: df[col_name] = df[col_name].astype(str).fillna('')
        else:
            if col_name in ['downloads', 'downloadsAllTime', 'likes']: df[col_name] = 0.0
            elif col_name == 'pipeline_tag': df[col_name] = ''
            elif col_name == 'tags': df[col_name] = pd.Series([[] for _ in range(len(df_raw))])  # initialize with empty lists
            elif col_name == 'safetensors': df[col_name] = None  # initialize with None
            elif col_name == 'id': print("CRITICAL ERROR: 'id' column missing."); return

    output_filesize_col_name = 'params'
    if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
        print(f"Using pre-existing '{output_filesize_col_name}' column as file size in GB.")
        df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
    elif 'safetensors' in df.columns:
        print(f"Calculating '{output_filesize_col_name}' (file size in GB) from 'safetensors' data...")
        df[output_filesize_col_name] = df['safetensors'].apply(extract_model_file_size_gb)
        df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
    else:
        print(f"Cannot determine file size. Setting '{output_filesize_col_name}' to 0.0.")
        df[output_filesize_col_name] = 0.0

    df['data_download_timestamp'] = data_download_timestamp
    print("Added 'data_download_timestamp' column.")

    print("Categorizing models by file size...")
    df['size_category'] = df[output_filesize_col_name].apply(get_file_size_category)

    print("Standardizing 'tags' column...")
    df['tags'] = process_tags_for_series(df['tags'])  # uses tqdm internally

    # --- START DEBUGGING BLOCK ---
    # This block executes before the main tag processing loop
    if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:  # check the ID exists
        print(f"\n--- Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---")

        # 1. Check the 'tags' column content after process_tags_for_series
        model_specific_tags_list = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]
        print(f"1. Tags from df['tags'] (after process_tags_for_series): {model_specific_tags_list}")
        print(f"   Type of tags: {type(model_specific_tags_list)}")
        if isinstance(model_specific_tags_list, list):
            for i, tag_item in enumerate(model_specific_tags_list):
                print(f"   Tag item {i}: '{tag_item}' (type: {type(tag_item)}, len: {len(str(tag_item))})")
                # Detailed check for 'robotics' specifically
                if 'robotics' in str(tag_item).lower():
                    print(f"     DEBUG: Found 'robotics' substring in '{tag_item}'")
                    print(f"       - str(tag_item).lower().strip(): '{str(tag_item).lower().strip()}'")
                    print(f"       - Is it exactly 'robotics'?: {str(tag_item).lower().strip() == 'robotics'}")
                    print(f"       - Ordinals: {[ord(c) for c in str(tag_item)]}")

        # 2. Simulate temp_tags_joined for this specific model
        if isinstance(model_specific_tags_list, list):
            simulated_temp_tags_joined = '~~~'.join(str(t).lower().strip() for t in model_specific_tags_list if pd.notna(t) and str(t).strip())
        else:
            simulated_temp_tags_joined = ''
        print(f"2. Simulated 'temp_tags_joined' for this model: '{simulated_temp_tags_joined}'")

        # 3. Simulate the 'has_robot' check for this model
        robot_keywords = ['robot', 'robotics']
        robot_pattern = '|'.join(robot_keywords)
        manual_robot_check = bool(re.search(robot_pattern, simulated_temp_tags_joined, flags=re.IGNORECASE))
        print(f"3. Manual regex check for 'has_robot' ('{robot_pattern}' in '{simulated_temp_tags_joined}'): {manual_robot_check}")
        print(f"--- End Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---\n")
    elif MODEL_ID_TO_DEBUG:
        print(f"DEBUG: Model ID '{MODEL_ID_TO_DEBUG}' not found in DataFrame for pre-loop debugging.")
    # --- END DEBUGGING BLOCK ---

    print("Vectorized creation of cached tag columns...")
    tag_time = time.time()
    # Original temp_tags_joined creation:
    df['temp_tags_joined'] = df['tags'].apply(
        lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
    )

    tag_map = {
        'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
        'has_robot': ['robot', 'robotics', 'openvla', 'vla'],
        'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
        'has_series': ['series', 'time-series', 'timeseries'],
        'has_video': ['video'], 'has_image': ['image', 'vision'],
        'has_text': ['text', 'nlp', 'llm']
    }
    for col, kws in tag_map.items():
        pattern = '|'.join(kws)
        df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)

    df['has_science'] = (
        df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
        ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
    )
    del df['temp_tags_joined']  # clean up temporary column
    df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
                             df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
    df['is_biomed'] = df['has_bio'] | df['has_med']
    print(f"Vectorized tag columns created in {time.time() - tag_time:.2f}s.")
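    # Editor's note: str.contains matches substrings, so the 'has_robot' pattern
    # 'robot|robotics|openvla|vla' also fires on any tag that merely contains
    # 'vla'. A hypothetical stricter check (run before temp_tags_joined is dropped)
    # would anchor whole tags on the '~~~' separators:
    #   strict = r'(?:^|~~~)(?:robot|robotics|openvla|vla)(?:~~~|$)'
    #   df['has_robot'] = df['temp_tags_joined'].str.contains(strict, na=False, case=False, regex=True)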

    # --- POST-LOOP DIAGNOSTIC for has_robot & a specific model ---
    if 'has_robot' in df.columns:
        print("\n--- 'has_robot' Diagnostics (Preprocessor - Post-Loop) ---")
        print(df['has_robot'].value_counts(dropna=False))

        if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:
            model_has_robot_val = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'has_robot'].iloc[0]
            print(f"Value of 'has_robot' for model '{MODEL_ID_TO_DEBUG}': {model_has_robot_val}")
            if model_has_robot_val:
                print(f"  Original tags for '{MODEL_ID_TO_DEBUG}': {df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]}")

        if df['has_robot'].any():
            print("Sample models flagged as 'has_robot':")
            print(df[df['has_robot']][['id', 'tags', 'has_robot']].head(5))
        else:
            print("No models were flagged as 'has_robot' after processing.")
        print("--------------------------------------------------------\n")
    # --- END POST-LOOP DIAGNOSTIC ---

    print("Adding organization column...")
    df['organization'] = df['id'].apply(extract_org_from_id)

    # Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
    if 'safetensors' in df.columns and \
       not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
        df = df.drop(columns=['safetensors'], errors='ignore')

    final_expected_cols = [
        'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags',
        'params', 'size_category', 'organization',
        'has_audio', 'has_speech', 'has_music', 'has_robot', 'has_bio', 'has_med',
        'has_series', 'has_video', 'has_image', 'has_text', 'has_science',
        'is_audio_speech', 'is_biomed',
        'data_download_timestamp'
    ]
    # Ensure all final columns exist, adding defaults if necessary
    for col in final_expected_cols:
        if col not in df.columns:
            print(f"Warning: Final expected column '{col}' is missing! Defaulting appropriately.")
            if col == 'params': df[col] = 0.0
            elif col == 'size_category': df[col] = "Small (<1GB)"  # default size category
            elif 'has_' in col or 'is_' in col: df[col] = False  # default boolean flags to False
            elif col == 'data_download_timestamp': df[col] = pd.NaT  # default timestamp to NaT

    print(f"Data processing completed in {time.time() - proc_start:.2f}s.")
    try:
        print(f"Saving processed data to: {PROCESSED_PARQUET_FILE_PATH}")
        df_to_save = df[final_expected_cols].copy()  # ensure only expected columns are saved
        df_to_save.to_parquet(PROCESSED_PARQUET_FILE_PATH, index=False, engine='pyarrow')
        print("Successfully saved processed data.")
    except Exception as e_save:
        print(f"ERROR: Could not save processed data: {e_save}")
        return

    total_elapsed_script = time.time() - overall_start_time
    print(f"Pre-processing finished. Total time: {total_elapsed_script:.2f}s. Final Parquet shape: {df_to_save.shape}")

if __name__ == "__main__":
    if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
        print(f"Deleting existing '{PROCESSED_PARQUET_FILE_PATH}' to ensure fresh processing...")
        try: os.remove(PROCESSED_PARQUET_FILE_PATH)
        except OSError as e: print(f"Error deleting file: {e}. Please delete manually and rerun."); exit()

    main_preprocessor()

    if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
        print("\nTo verify, load the parquet and check 'has_robot' and its 'tags':")
        print(f"import pandas as pd; df_chk = pd.read_parquet('{PROCESSED_PARQUET_FILE_PATH}')")
        print("print(df_chk['has_robot'].value_counts())")
        if MODEL_ID_TO_DEBUG:
            print(f"print(df_chk[df_chk['id'] == '{MODEL_ID_TO_DEBUG}'][['id', 'tags', 'has_robot']])")
        else:
            print("print(df_chk[df_chk['has_robot']][['id', 'tags', 'has_robot']].head())")