# ScouterAI / rag / create_image_captioning_dataset.py
# Author: stevenbucaille — commit 518d841 (2.55 kB)
# Commit message: Enhance app.py with improved user interface and instructions,
# update model ID in llm.py, and add image classification capabilities across
# various components. Introduce segment anything functionality and refine README
# for clarity on model capabilities.
import json
from datasets import Dataset, Features, Value
from huggingface_hub import HfApi, ModelCard, hf_hub_download
def get_model_ids(pipeline_tag: str) -> list[str]:
    """Return the ids of public ``transformers`` models tagged with *pipeline_tag*.

    The Hub is queried with a ``tags`` filter (not the ``pipeline_tag=`` search
    filter), restricted to the ``transformers`` library and to non-gated repos.

    Args:
        pipeline_tag: Hub tag to filter on, e.g. ``"image-captioning"``.

    Returns:
        Model ids, in the order yielded by ``HfApi.list_models``.
    """
    hf_api = HfApi()
    models = hf_api.list_models(
        library=["transformers"],
        # Honour the caller's tag instead of a hard-coded "image-captioning".
        tags=[pipeline_tag],
        gated=False,
        # No fetch_config: only ``model.id`` is read below, so fetching every
        # model's config would be wasted network traffic.
    )
    # list_models returns a lazy iterable; consume it directly.
    return [model.id for model in models]
def get_model_card(model_id: str, expected_tag: str = "image-captioning") -> str:
    """Load the model card of *model_id* and return its text.

    The card is accepted only if its metadata declares *expected_tag*, either as
    its ``pipeline_tag`` or among its ``tags``.

    Args:
        model_id: Hub repository id of the model.
        expected_tag: Tag the card metadata must carry. Defaults to
            ``"image-captioning"``.

    Returns:
        ``ModelCard.text`` for the loaded card.

    Raises:
        ValueError: if the card metadata carries *expected_tag* neither as
            ``pipeline_tag`` nor in ``tags``.
    """
    model_card = ModelCard.load(model_id, ignore_metadata_errors=True)
    card_data = model_card.data
    # ``tags`` may be absent/None in the card metadata; treat that as no tags
    # rather than letting the ``in`` test raise TypeError.
    tags = card_data.get("tags") or []
    if card_data.get("pipeline_tag") != expected_tag and expected_tag not in tags:
        # Raise explicitly: an ``assert`` is stripped under ``python -O`` and would
        # let untagged cards through silently.
        raise ValueError(
            f"Model card of {model_id!r} is not tagged {expected_tag!r}"
        )
    return model_card.text
def get_model_labels(model_id: str) -> list[str]:
    """Return the lower-cased class labels from *model_id*'s ``config.json``.

    Labels are the values of the config's ``id2label`` mapping, in the order they
    appear in that mapping, each converted with ``str()`` and lower-cased.

    Args:
        model_id: Hub repository id of the model.

    Returns:
        The labels. The sentinel ``[""]`` is returned when the repo has no
        ``config.json``, the file cannot be decoded/parsed as JSON, or it has no
        usable ``id2label`` dict.
    """
    hf_api = HfApi()
    if not hf_api.file_exists(model_id, filename="config.json"):
        return [""]

    config_path = hf_hub_download(model_id, filename="config.json")
    # JSON files are UTF-8; do not depend on the platform's default encoding,
    # which can garble non-ASCII label names.
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            model_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return [""]

    # Guard against a non-object config or a null / non-mapping ``id2label``,
    # either of which would otherwise crash on ``.values()``.
    id2label = model_config.get("id2label") if isinstance(model_config, dict) else None
    if not isinstance(id2label, dict):
        return [""]
    return [str(label).lower() for label in id2label.values()]
def create_dataset(pipeline_tag: str, *, num_proc: int = 12):
    """Build a ``datasets.Dataset`` of model ids and model-card texts.

    Model ids are collected with ``get_model_ids(pipeline_tag)``. Each record's card
    text is fetched with ``get_model_card`` inside the generator, so the downloads
    run in the ``Dataset.from_generator`` worker processes. ``datasets`` splits
    list-valued ``gen_kwargs`` (here ``model_ids``) into shards across them.

    Args:
        pipeline_tag: Hub tag used to select the models.
        num_proc: Number of worker processes for ``Dataset.from_generator``.

    Returns:
        A dataset with string columns ``model_id`` and ``model_card``.
    """

    def dataset_gen(model_ids: list[str]):
        # One record per model. Any exception raised by get_model_card propagates
        # and aborts the build.
        for model_id in model_ids:
            yield {
                "model_id": model_id,
                "model_card": get_model_card(model_id),
            }

    model_ids = get_model_ids(pipeline_tag)
    print(f"Found {len(model_ids)} models")

    # Declare exactly the columns the generator yields. A declared column that no
    # record provides is filled with nulls instead of raising, which silently
    # produced an empty ``model_labels`` column.
    features = Features(
        {
            "model_id": Value("string"),
            "model_card": Value("string"),
        }
    )
    return Dataset.from_generator(
        dataset_gen,
        gen_kwargs={"model_ids": model_ids},
        features=features,
        num_proc=num_proc,
    )
if __name__ == "__main__":
    # Script entry point: build the image-captioning model-card dataset, print a
    # summary of it, then upload it to the Hugging Face Hub.
    target_repo = "stevenbucaille/image-captioning-models-dataset"
    captioning_dataset = create_dataset("image-captioning")
    print(captioning_dataset)
    captioning_dataset.push_to_hub(target_repo)