File size: 2,546 Bytes
518d841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json

from datasets import Dataset, Features, Value
from huggingface_hub import HfApi, ModelCard, hf_hub_download


def get_model_ids(pipeline_tag: str) -> list[str]:
    """Return the ids of ungated ``transformers`` models carrying ``pipeline_tag``.

    The value is matched against the repositories' Hub *tags* (``tags=``)
    rather than through the ``pipeline_tag=`` filter. Models that carry the
    tag without declaring it as their pipeline tag therefore still match.

    Args:
        pipeline_tag: Tag to search for, e.g. ``"image-captioning"``.

    Returns:
        The matching model repository ids.
    """
    hf_api = HfApi()
    models = hf_api.list_models(
        library=["transformers"],
        tags=[pipeline_tag],
        gated=False,
        # Configs are not requested: only the ids are used below.
    )
    # list_models returns an iterable; consuming it here fetches all results.
    return [model.id for model in models]


def get_model_card(model_id: str) -> str:
    """Load the model card of ``model_id`` and return its text.

    The card is loaded with ``ignore_metadata_errors=True`` so that malformed
    metadata does not abort loading. The card must mark the model as
    image-captioning, either through its ``pipeline_tag`` metadata or its
    ``tags`` metadata.

    Args:
        model_id: Hub repository id, e.g. ``"owner/model-name"``.

    Returns:
        The card's ``text`` content.

    Raises:
        ValueError: If the card's metadata does not identify the model as
            image-captioning.
    """
    model_card = ModelCard.load(model_id, ignore_metadata_errors=True)
    card_data = model_card.data
    # Metadata fields are optional: a card without `tags` yields None, and a
    # membership test on None would raise TypeError.
    tags = card_data.get("tags") or []
    is_captioning = (
        card_data.get("pipeline_tag") == "image-captioning"
        or "image-captioning" in tags
    )
    if not is_captioning:
        # Raise explicitly instead of `assert`, which is stripped under -O.
        raise ValueError(
            f"Model card of {model_id!r} is not tagged as image-captioning"
        )
    return model_card.text


def get_model_labels(model_id: str) -> list[str]:
    """Return the lower-cased labels from a model's ``config.json`` ``id2label``.

    Labels are returned in the order of the ``id2label`` mapping, each
    converted with ``str(...).lower()``.

    Args:
        model_id: Hub repository id.

    Returns:
        The labels, or ``[""]`` in any of these cases:
        - the repository has no ``config.json``;
        - the file cannot be decoded as UTF-8 JSON;
        - the file has no ``id2label`` mapping.
    """
    hf_api = HfApi()
    if not hf_api.file_exists(model_id, filename="config.json"):
        return [""]

    config_path = hf_hub_download(model_id, filename="config.json")
    # JSON exchanged between systems is UTF-8 (RFC 8259). Don't rely on the
    # platform default encoding, which can mis-decode non-ASCII labels.
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            model_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return [""]

    # Valid JSON need not be an object, and `id2label` may be null or
    # malformed; treat all of these like a missing mapping.
    id2label = model_config.get("id2label") if isinstance(model_config, dict) else None
    if not isinstance(id2label, dict):
        return [""]
    return [str(label).lower() for label in id2label.values()]


def create_dataset(pipeline_tag: str, num_proc: int = 12):
    """Build a dataset of model ids and model-card texts for ``pipeline_tag``.

    Args:
        pipeline_tag: Tag used to select models (forwarded to
            ``get_model_ids``).
        num_proc: Worker process count passed to
            ``Dataset.from_generator``. That function splits list-valued
            ``gen_kwargs`` (here ``model_ids``) across workers.

    Returns:
        A ``Dataset`` with string columns ``model_id`` and ``model_card``.
    """

    def dataset_gen(model_ids: list[str]):
        # One row per model. get_model_card raises for cards it rejects, and
        # that exception is not caught here.
        for model_id in model_ids:
            yield {
                "model_id": model_id,
                "model_card": get_model_card(model_id),
            }

    model_ids = get_model_ids(pipeline_tag)
    print(f"Found {len(model_ids)} models")

    return Dataset.from_generator(
        dataset_gen,
        gen_kwargs={"model_ids": model_ids},
        # The schema must match the keys yielded by dataset_gen. Label
        # extraction is not performed, so no `model_labels` column is declared
        # (declaring it would only produce an all-null column).
        features=Features(
            {
                "model_id": Value("string"),
                "model_card": Value("string"),
            }
        ),
        num_proc=num_proc,
    )


if __name__ == "__main__":
    # Script entry point: collect image-captioning model cards, show a summary
    # of the resulting dataset, then upload it to the Hub.
    captioning_dataset = create_dataset("image-captioning")
    print(captioning_dataset)
    captioning_dataset.push_to_hub("stevenbucaille/image-captioning-models-dataset")