Spaces:
Running
Running
File size: 2,599 Bytes
518d841 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import json
from datasets import Dataset, Features, Value
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub.utils import disable_progress_bars
import tqdm
def get_model_ids(pipeline_tag: str) -> list[str]:
    """Return the ids of public, non-gated ``transformers`` models for a pipeline tag.

    Args:
        pipeline_tag: Hub pipeline tag to filter on, e.g. ``"image-classification"``.

    Returns:
        The repo ids, in the order ``HfApi.list_models`` yields them.
    """
    hf_api = HfApi()
    # Only the ids are used by this script, so do not request each model's
    # config (``fetch_config=True``); that only inflates every listing page.
    models = hf_api.list_models(
        library=["transformers"],
        pipeline_tag=pipeline_tag,
        gated=False,
    )
    # Consume the listing once, keeping just the ids.
    return [model.id for model in models]
def get_model_card(model_id: str) -> str:
    """Load a model's card from the Hub and return its ``text`` attribute.

    Best-effort: if loading the card or reading its text raises any
    ``Exception``, an empty string is returned instead.

    Args:
        model_id: Hub repo id of the model.

    Returns:
        The card text, or ``""`` on any failure.
    """
    try:
        card_text = ModelCard.load(model_id).text
    except Exception:
        # Deliberately broad: an empty string is treated downstream as
        # "no usable card" and the model is skipped.
        card_text = ""
    return card_text
def get_model_labels(model_id: str) -> list[str]:
    """Return the lower-cased class labels from a model's ``config.json``.

    Labels are taken from the config's ``id2label`` mapping. When every key
    parses as an integer, labels are ordered by that integer so that list
    position ``i`` corresponds to class id ``i``; otherwise the file's order
    is kept.

    Args:
        model_id: Hub repo id of the model.

    Returns:
        The labels converted with ``str(...).lower()``. Returns ``[]`` when the
        repo has no ``config.json``, or the file is not valid UTF-8 JSON, or it
        has no dict-valued ``id2label`` entry.
    """
    hf_api = HfApi()
    if not hf_api.file_exists(model_id, filename="config.json"):
        return []
    config_path = hf_hub_download(model_id, filename="config.json")
    # Pin UTF-8 so non-ASCII labels decode identically on every platform
    # instead of depending on the locale's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            model_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return []
    id2label = model_config.get("id2label") if isinstance(model_config, dict) else None
    if not isinstance(id2label, dict):
        return []
    try:
        # JSON object keys are strings, so a key-sorted file lists them as
        # "0", "1", "10", ..., "2"; order numerically by class id instead.
        items = sorted(id2label.items(), key=lambda item: int(item[0]))
    except (TypeError, ValueError):
        # Non-integer ids: fall back to the order found in the file.
        items = list(id2label.items())
    return [str(label).lower() for _, label in items]
def create_dataset(pipeline_tag: str, *, num_proc: int = 24):
    """Build a dataset of model cards and class labels for one pipeline tag.

    Every id returned by ``get_model_ids(pipeline_tag)`` is visited. A row is
    kept only when the model has a non-empty card and more than one label.

    Args:
        pipeline_tag: Hub pipeline tag to collect models for.
        num_proc: Number of worker processes passed to
            ``Dataset.from_generator`` (default 24, the previously
            hard-coded value).

    Returns:
        A ``datasets.Dataset`` with string columns ``model_id`` and
        ``model_card`` and a list-of-strings column ``model_labels``.
    """

    def dataset_gen(model_ids: list[str]):
        # With num_proc > 1, list-valued gen_kwargs are split across workers,
        # so each call may see only a slice of the full id list.
        for model_id in model_ids:
            try:
                model_card = get_model_card(model_id)
                model_labels = get_model_labels(model_id)
            except Exception as e:
                # Best-effort crawl: one failing repo must not abort the build.
                print(f"Error getting model card or labels for {model_id}: {e}")
                continue
            # Keep only models with at least two labels and a non-empty card.
            if len(model_labels) > 1 and model_card:
                yield {
                    "model_id": model_id,
                    "model_card": model_card,
                    "model_labels": model_labels,
                }

    model_ids = get_model_ids(pipeline_tag)
    print(f"Found {len(model_ids)} models")
    dataset = Dataset.from_generator(
        dataset_gen,
        gen_kwargs={"model_ids": model_ids},
        features=Features(
            {
                "model_id": Value("string"),
                "model_card": Value("string"),
                "model_labels": [Value("string")],
            }
        ),
        num_proc=num_proc,
    )
    return dataset
if __name__ == "__main__":
    # Silence huggingface_hub progress bars for the many per-model downloads.
    disable_progress_bars()
    image_models = create_dataset("image-classification")
    print(image_models)
    # Publish the result to the Hub under the author's namespace.
    image_models.push_to_hub("stevenbucaille/image-classification-models-dataset")
|