ScouterAI / rag /create_image_classification_dataset.py
stevenbucaille's picture
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce segment anything functionality and refine README for clarity on model capabilities.
518d841
raw
history blame
2.6 kB
import json
from datasets import Dataset, Features, Value
from huggingface_hub import HfApi, ModelCard, hf_hub_download
from huggingface_hub.utils import disable_progress_bars
import tqdm
def get_model_ids(pipeline_tag: str) -> list[str]:
    """List the ids of non-gated transformers models for a Hub pipeline tag.

    Args:
        pipeline_tag: Hub pipeline tag to filter on, e.g. "image-classification".

    Returns:
        The model ids, in the order the Hub listing yields them.
    """
    listing = HfApi().list_models(
        library=["transformers"],
        pipeline_tag=pipeline_tag,
        gated=False,
        fetch_config=True,
    )
    # Consume the whole listing here and keep only the repo ids.
    return [info.id for info in listing]
def get_model_card(model_id: str) -> str:
    """Return the ``text`` of a model's card on the Hub.

    Args:
        model_id: Hub repository id of the model.

    Returns:
        The card text, or an empty string if loading the card fails for any
        reason (callers treat "no card" and "could not fetch" the same way).
    """
    try:
        card_text = ModelCard.load(model_id).text
    except Exception:
        # Deliberate best-effort: any failure just means "no card text".
        card_text = ""
    return card_text
def get_model_labels(model_id: str) -> list[str]:
    """Return the lower-cased class labels declared in a model's config.json.

    Reads the ``id2label`` mapping from the repository's ``config.json`` and
    returns its values, stringified and lower-cased, in the order they appear
    in the file.

    Args:
        model_id: Hub repository id of the model.

    Returns:
        The labels, or an empty list when the repo has no ``config.json``,
        the file is not valid UTF-8 JSON, the top-level value is not an
        object, or ``id2label`` is missing or not a mapping.

    Raises:
        Errors from ``HfApi.file_exists`` / ``hf_hub_download`` (network or
        Hub failures) propagate to the caller.
    """
    hf_api = HfApi()
    if not hf_api.file_exists(model_id, filename="config.json"):
        return []
    config_path = hf_hub_download(model_id, filename="config.json")
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            model_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Treat undecodable bytes exactly like malformed JSON.
            return []
    # Guard against configs whose top level or id2label is not a JSON object
    # (e.g. "id2label": null), which would otherwise crash on .values().
    if not isinstance(model_config, dict):
        return []
    id2label = model_config.get("id2label")
    if not isinstance(id2label, dict):
        return []
    return [str(label).lower() for label in id2label.values()]
def create_dataset(pipeline_tag: str, num_proc: int = 24) -> Dataset:
    """Build a Dataset of (model id, model card, labels) for one Hub task.

    Lists the non-gated transformers models tagged ``pipeline_tag``, then
    fetches each model's card text and ``id2label`` labels via
    ``Dataset.from_generator``. Only models with a non-empty card and more
    than one label are kept. Models whose card/label lookup raises are
    reported on stdout and skipped.

    Args:
        pipeline_tag: Hub pipeline tag, e.g. "image-classification".
        num_proc: Number of processes passed to ``Dataset.from_generator``.
            Defaults to 24, the value previously hard-coded here.

    Returns:
        A ``datasets.Dataset`` with string columns ``model_id`` and
        ``model_card`` and a list-of-strings column ``model_labels``.
    """

    def dataset_gen(model_ids: list[str]):
        # With num_proc > 1, datasets may split the list passed through
        # gen_kwargs across worker processes (see datasets docs).
        for model_id in model_ids:
            try:
                model_card = get_model_card(model_id)
                model_labels = get_model_labels(model_id)
            except Exception as e:
                # Best-effort crawl: one bad repo must not abort the whole build.
                print(f"Error getting model card or labels for {model_id}: {e}")
                continue
            if len(model_labels) > 1 and model_card:
                yield {
                    "model_id": model_id,
                    "model_card": model_card,
                    "model_labels": model_labels,
                }

    model_ids = get_model_ids(pipeline_tag)
    print(f"Found {len(model_ids)} models")
    dataset = Dataset.from_generator(
        dataset_gen,
        gen_kwargs={"model_ids": model_ids},
        features=Features(
            {
                "model_id": Value("string"),
                "model_card": Value("string"),
                "model_labels": [Value("string")],
            }
        ),
        num_proc=num_proc,
    )
    return dataset
if __name__ == "__main__":
    # Script entry: crawl the Hub for image-classification models, show a
    # summary of the resulting dataset, and upload it to the Hub.
    disable_progress_bars()
    task = "image-classification"
    target_repo = "stevenbucaille/image-classification-models-dataset"
    dataset = create_dataset(task)
    print(dataset)
    dataset.push_to_hub(target_repo)