ScouterAI / rag /create_dataset.py
stevenbucaille's picture
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce segment anything functionality and refine README for clarity on model capabilities.
518d841
raw
history blame
2.35 kB
import json
import logging

from datasets import Dataset, Features, Value
from huggingface_hub import HfApi, hf_hub_download, ModelCard
from smolagents import Tool
def get_model_ids(pipeline_tag: str) -> list[str]:
    """List the ids of non-gated ``transformers`` models for a pipeline tag.

    Args:
        pipeline_tag: Hub pipeline tag to filter on, e.g. ``"image-segmentation"``.

    Returns:
        The model ids, in the order ``HfApi.list_models`` yields them.
    """
    api = HfApi()
    listing = api.list_models(
        library=["transformers"],
        pipeline_tag=pipeline_tag,
        gated=False,
        fetch_config=True,
    )
    return [info.id for info in listing]
def get_model_card(model_id: str) -> str:
    """Return the text of a model's card, or ``""`` if it cannot be loaded.

    Loading is deliberately best-effort: any exception raised by
    ``ModelCard.load`` is logged at DEBUG level (with traceback) and reported
    as an empty string, which ``create_dataset`` treats as "skip this model".

    Args:
        model_id: Hub repository id of the model.

    Returns:
        The card's ``text`` attribute, or ``""`` on failure.
    """
    try:
        model_card = ModelCard.load(model_id)
    except Exception:
        # Broad catch is intentional: one unloadable card must not abort a
        # dataset build that iterates over many models. Log instead of
        # swallowing silently so failures can still be diagnosed.
        logging.getLogger(__name__).debug(
            "Could not load model card for %s", model_id, exc_info=True
        )
        return ""
    return model_card.text
def get_model_labels(model_id: str) -> list[str]:
    """Return the lower-cased class labels declared in a model's ``config.json``.

    Labels are taken from the values of the config's ``id2label`` mapping, in
    the order the mapping stores them, converted with ``str()`` and lower-cased.

    Args:
        model_id: Hub repository id of the model.

    Returns:
        The list of labels, or the sentinel ``[""]`` when the repository has no
        ``config.json``, the file cannot be decoded/parsed as JSON, or it has no
        ``id2label`` object. ``create_dataset`` skips models whose result has
        fewer than two entries, so the sentinel means "no usable labels".
    """
    no_labels = [""]
    hf_api = HfApi()
    if not hf_api.file_exists(model_id, filename="config.json"):
        return no_labels
    config_path = hf_hub_download(model_id, filename="config.json")
    # JSON is UTF-8; don't depend on the platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            model_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return no_labels
    # Malformed configs (non-object top level, or e.g. "id2label": null) would
    # otherwise raise TypeError/AttributeError; treat them as "no labels".
    if not isinstance(model_config, dict):
        return no_labels
    id2label = model_config.get("id2label")
    if not isinstance(id2label, dict):
        return no_labels
    return [str(label).lower() for label in id2label.values()]
def create_dataset(pipeline_tag: str, repo_id: str, *, num_proc: int = 12) -> Dataset:
    """Build a dataset of Hub models for a pipeline tag and push it to the Hub.

    Each row holds a model id, the text of its model card, and the lower-cased
    labels from its ``config.json``. A model is kept only if it declares at
    least two labels and its card text is non-empty.

    Args:
        pipeline_tag: Hub pipeline tag used to select models, e.g.
            ``"image-segmentation"``.
        repo_id: Hub dataset repository the result is pushed to.
        num_proc: Number of worker processes passed to
            ``Dataset.from_generator``. Defaults to 12.

    Returns:
        The generated dataset (which has also been pushed to ``repo_id``).
    """

    def dataset_gen(model_ids: list[str]):
        for model_id in model_ids:
            # Check labels first so the model card is only fetched for models
            # that can actually end up in the dataset.
            model_labels = get_model_labels(model_id)
            if len(model_labels) <= 1:
                continue
            model_card = get_model_card(model_id)
            if not model_card:
                continue
            yield {
                "model_id": model_id,
                "model_card": model_card,
                "model_labels": model_labels,
            }

    model_ids = get_model_ids(pipeline_tag)
    dataset = Dataset.from_generator(
        dataset_gen,
        # With num_proc > 1, list values in gen_kwargs are split across the
        # worker processes (per the ``Dataset.from_generator`` docs).
        gen_kwargs={"model_ids": model_ids},
        features=Features(
            {
                "model_id": Value("string"),
                "model_card": Value("string"),
                "model_labels": [Value("string")],
            }
        ),
        num_proc=num_proc,
    )
    dataset.push_to_hub(repo_id)
    return dataset
if __name__ == "__main__":
    # Build the image-segmentation model index, push it, and show a summary.
    segmentation_dataset = create_dataset(
        pipeline_tag="image-segmentation",
        repo_id="stevenbucaille/image-segmentation-models-dataset",
    )
    print(segmentation_dataset)