import os
import platform
import re
from collections import defaultdict

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from cachetools import TTLCache, cached
from cytoolz import groupby
from huggingface_hub import CollectionItem, get_collection, list_datasets, list_models
from tqdm.auto import tqdm

# Use hf_transfer for faster Hub downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

is_macos = platform.system() == "Darwin"
LIMIT = 1000 if is_macos else None  # limit for local dev because slooow internet
CACHE_TIME = 60 * 15  # 15 minutes


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_models():
    print("getting models...")
    return list(tqdm(iter(list_models(full=True, limit=LIMIT))))


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def get_datasets():
    print("getting datasets...")
    return list(tqdm(iter(list_datasets(full=True, limit=LIMIT))))


get_models()  # warm up the cache
get_datasets()  # warm up the cache


def check_for_arxiv_id(model):
    """Return any arxiv tags attached to a model or dataset, or False if there are none."""
    return [tag for tag in model.tags if "arxiv" in tag] if model.tags else False


def extract_arxiv_id(input_string: str) -> str:
    """Extract the arXiv ID (e.g. 2106.01234) from a tag like 'arxiv:2106.01234'."""
    pattern = re.compile(r"\barxiv:(\d+\.\d+)\b")
    match = pattern.search(input_string)
    return match[1] if match else None


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_model_to_arxiv_id_dict():
    """Map model IDs on the Hub to the arXiv IDs found in their tags."""
    models = get_models()
    model_to_arxiv_id = {}
    for model in models:
        if arxiv_papers := check_for_arxiv_id(model):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            model_to_arxiv_id[model.modelId] = clean_arxiv_ids
    return model_to_arxiv_id


@cached(cache=TTLCache(maxsize=100, ttl=CACHE_TIME))
def create_dataset_to_arxiv_id_dict():
    """Map dataset IDs on the Hub to the arXiv IDs found in their tags."""
    datasets = get_datasets()
    dataset_to_arxiv_id = {}
    for dataset in datasets:
        if arxiv_papers := check_for_arxiv_id(dataset):
            clean_arxiv_ids = []
            for paper in arxiv_papers:
                if arxiv_id := extract_arxiv_id(paper):
                    clean_arxiv_ids.append(arxiv_id)
            dataset_to_arxiv_id[dataset.id] = clean_arxiv_ids
    return dataset_to_arxiv_id


def get_collection_type(collection_item: CollectionItem):
    """Return a plural item type ('models', 'datasets', 'papers') for grouping."""
    try:
        return f"{collection_item.item_type}s"
    except AttributeError:
        return None


def group_collection_items(collection_slug: str):
    """Fetch a collection and group its items by type."""
    collection = get_collection(collection_slug)
    items = collection.items
    return groupby(get_collection_type, items)


@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def get_papers_for_collection(collection_slug: str):
    """Return arXiv IDs and Hub paper links for the models, datasets, and papers in a collection."""
    dataset_to_arxiv_id = create_dataset_to_arxiv_id_dict()
    models_to_arxiv_id = create_model_to_arxiv_id_dict()
    collection = group_collection_items(collection_slug)
    collection_datasets = collection.get("datasets", None)
    collection_models = collection.get("models", None)
    papers = collection.get("papers", None)
    dataset_papers = defaultdict(dict)
    model_papers = defaultdict(dict)
    collection_papers = defaultdict(dict)
    if collection_datasets is not None:
        for dataset in collection_datasets:
            if arxiv_ids := dataset_to_arxiv_id.get(dataset.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                dataset_papers[dataset.item_id] = data
    if collection_models is not None:
        for model in collection_models:
            if arxiv_ids := models_to_arxiv_id.get(model.item_id, None):
                data = {
                    "arxiv_ids": arxiv_ids,
                    "hub_paper_links": [
                        f"https://huggingface.co/papers/{arxiv_id}"
                        for arxiv_id in arxiv_ids
                    ],
                }
                model_papers[model.item_id] = data
    if papers is not None:
        # Papers added directly to the collection already carry their arXiv ID as the item ID.
        for paper in papers:
            data = {
                "arxiv_ids": [paper.item_id],
                "hub_paper_links": [f"https://huggingface.co/papers/{paper.item_id}"],
            }
            collection_papers[paper.item_id] = data
    # Return None rather than empty dicts so the JSON output is easier to read.
    if not dataset_papers:
        dataset_papers = None
    if not model_papers:
        model_papers = None
    if not collection_papers:
        collection_papers = None
    return {
        "dataset papers": dataset_papers,
        "model papers": model_papers,
        "papers": collection_papers,
    }


# Refresh the model and dataset caches every 15 minutes in the background.
scheduler = BackgroundScheduler()
scheduler.add_job(get_datasets, "interval", minutes=15)
scheduler.add_job(get_models, "interval", minutes=15)
scheduler.start()

placeholder_url = "HF-IA-archiving/models-to-archive-65006a7fdadb8c628f33aac9"

slug_input = gr.Textbox(
    placeholder=placeholder_url,
    interactive=True,
    label="Collection slug",
    max_lines=1,
)

description = (
    "Enter a Collection slug to get the arXiv IDs and Hugging Face Paper links for"
    " papers associated with models and datasets in the collection. If the collection"
    " includes papers, the arXiv IDs and Hugging Face Paper links will be returned for"
    " those papers as well."
)

examples = [
    placeholder_url,
    "davanstrien/historic-language-modeling-64f99e243188ade79d7ad74b",
]

gr.Interface(
    get_papers_for_collection,
    slug_input,
    "json",
    title="📄🔗: Extract linked papers from a Hugging Face Collection",
    description=description,
    examples=examples,
    cache_examples=True,
).queue(concurrency_count=4).launch()