dataset_generator / load_base_model_data.py
davanstrien's picture
davanstrien HF Staff
Update dataset push destination
42f30e0
raw
history blame
3.05 kB
import os
from datetime import datetime
from functools import lru_cache
from typing import List, Optional
from datasets import Dataset
from huggingface_hub import list_models
from pydantic import BaseModel, field_validator
from toolz import groupby
from tqdm.auto import tqdm
# Hugging Face API token read from the environment; used by load_data() to
# push the generated dataset to the Hub.  May be None if not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
@lru_cache()
def get_all_models():
    """Fetch metadata for every model on the Hugging Face Hub.

    Models are requested with full card data, sorted by downloads
    (descending, ``limit=None`` = no cap).  The listing is expensive, so
    the materialized result is cached for the process lifetime.

    Returns:
        list: model-info objects; any ``None`` entries are dropped.
    """
    # list_models already returns an iterator, so wrapping it in iter()
    # (as the original did) is redundant; tqdm shows progress while the
    # full listing is materialized into a list.
    models = list(
        tqdm(
            list_models(
                cardData=True, full=True, limit=None, sort="downloads", direction=-1
            )
        )
    )
    return [model for model in models if model is not None]
def has_base_model_info(model):
    """Return True if *model*'s card data exposes a ``base_model`` attribute.

    ``hasattr`` already returns a bool, so the original ``bool(...)``
    wrapper was redundant.  When ``model.card_data`` is None the check is
    safely False (None has no ``base_model`` attribute).
    """
    return hasattr(model.card_data, "base_model")
class HubModel(BaseModel):
    """Normalized view of a Hub model's metadata plus selected card-data fields.

    Field names mirror the attribute names used by the Hub API payloads
    (hence the mixed ``createdAt`` / ``last_modified`` casing — kept for
    backward compatibility with existing consumers of the dataset).
    """

    author: Optional[str] = None
    last_modified: Optional[datetime] = None
    createdAt: Optional[datetime] = None
    downloads: Optional[int] = None
    likes: Optional[int] = None
    library_name: Optional[str] = None
    modelId: Optional[str] = None
    datasets: Optional[List[str]] = None
    language: Optional[List[str]] = None
    base_model: Optional[str] = None

    @field_validator("language", "datasets", mode="before")
    @classmethod  # pydantic v2 validators should be explicit classmethods
    def ensure_list(cls, v):
        """Coerce a bare string into a one-item list; pass other values through."""
        return [v] if isinstance(v, str) else v

    @classmethod
    def from_original(cls, original_data: dict) -> "HubModel":
        """Build a HubModel from a ``ModelInfo.__dict__``-style mapping.

        ``card_data`` may be missing, None, or an object (e.g. ModelCardData);
        it is normalized to a plain dict before card-level fields are read.
        """
        card_data = original_data.get("card_data", {})
        if card_data is None:
            card_data = {}
        if not isinstance(card_data, dict):
            # Non-dict card data (e.g. ModelCardData) exposes fields via __dict__.
            card_data = card_data.__dict__
        return cls(
            author=original_data.get("author"),
            last_modified=original_data.get("last_modified"),
            createdAt=original_data.get("createdAt"),
            downloads=original_data.get("downloads"),
            likes=original_data.get("likes"),
            library_name=original_data.get("library_name"),
            modelId=original_data.get("modelId"),
            datasets=card_data.get("datasets"),
            language=card_data.get("language"),
            base_model=card_data.get("base_model"),
        )
def load_data():
    """Build and push a dataset of Hub models carrying base-model info.

    Partitions all Hub models by whether their card data declares a
    ``base_model``, parses the ones that do, then also includes the models
    that are themselves referenced as a base model by at least one other
    model.  The combined rows are pushed to
    ``librarian-bots/hub_models_with_base_model_info``.

    Returns:
        Dataset: the dataset that was pushed to the Hub.
    """
    grouped = groupby(has_base_model_info, get_all_models())
    # Default to [] so an empty partition cannot yield None —
    # dict.get(True)/get(False) without a default would return None and
    # crash the comprehensions below.
    with_base_info = grouped.get(True, [])
    without_base_info = grouped.get(False, [])
    parsed_models = [
        HubModel.from_original(model.__dict__).model_dump()
        for model in with_base_info
    ]
    # IDs that appear as some other model's base_model.
    base_model_ids = {model["base_model"] for model in parsed_models}
    # Keep only models that are actually referenced as a base model;
    # the None guard folds in the original's separate filtering pass.
    referenced_base_models = [
        model
        for model in tqdm(without_base_info)
        if model is not None and model.id in base_model_ids
    ]
    parsed_base_models = [
        HubModel.from_original(model.__dict__).model_dump()
        for model in referenced_base_models
    ]
    ds = Dataset.from_list(parsed_models + parsed_base_models)
    ds.push_to_hub("librarian-bots/hub_models_with_base_model_info", token=HF_TOKEN)
    print("Pushed to hub")
    return ds