# Source: huggingface-datasets-search-v2 / generate_summaries_uv.py
# Last commit d7ac8e4 by davanstrien (HF Staff):
#   "fix: Remove custom index URLs to resolve dependency conflicts"
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets",
# "flashinfer-python",
# "hf_transfer",
# "huggingface-hub[hf_xet]",
# "polars",
# "stamina",
# "transformers",
# "vllm",
# "tqdm",
# "setuptools",
# ]
# ///
import argparse
import logging
import os
import sys
from typing import Optional
# Configure the Hugging Face Hub client and vLLM through environment variables.
# They are assigned here, ahead of the third-party imports that follow --
# presumably so those libraries see the values when first imported (TODO confirm).
# Opt in to hf_transfer-backed downloads from the Hub.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Ask vLLM to use the FlashInfer attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
import polars as pl
from datasets import Dataset, load_dataset
from huggingface_hub import login, dataset_info, snapshot_download
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# Logging: INFO-level records formatted as "<timestamp> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def format_prompt(
    content: str, card_type: str, tokenizer, max_chars: int = 4000
) -> str:
    """Build the chat-formatted prompt for a single card.

    The card text is cut to its first ``max_chars`` characters, prefixed with
    ``<MODEL_CARD>`` when ``card_type`` is ``"model"`` (``<DATASET_CARD>`` for
    any other value) and wrapped as a single ``user`` chat message.

    Args:
        content: Card text to summarise.
        card_type: ``"model"`` selects the model tag; every other value
            selects the dataset tag.
        tokenizer: Object exposing ``apply_chat_template`` (the script passes
            a ``transformers`` tokenizer).
        max_chars: Maximum number of characters of ``content`` kept in the
            prompt. Defaults to 4000.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns when called with
        ``add_generation_prompt=True`` and ``tokenize=False``.
    """
    # Both card types share the same message layout; only the tag differs.
    tag = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
    messages = [{"role": "user", "content": f"{tag}{content[:max_chars]}"}]
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
def load_and_filter_data(
    dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1
) -> pl.DataFrame:
    """Load the ``train`` split of ``dataset_id`` and keep only usable cards.

    A ``post_yaml_content`` column is derived from ``card`` by removing a
    leading ``---`` ... ``---`` front-matter block and trimming surrounding
    whitespace. Rows are kept when that text is more than 200 and fewer than
    120,000 bytes long. When ``card_type`` is ``"model"`` the ``likes`` and
    ``downloads`` columns must additionally be at least ``min_likes`` and
    ``min_downloads``.

    Returns:
        The surviving rows as an eager polars ``DataFrame``.
    """
    logger.info(f"Loading data from {dataset_id}")
    cards = load_dataset(dataset_id, split="train").to_polars().lazy()

    # Card body with the YAML front matter stripped off.
    body = (
        pl.col("card")
        .str.replace_all(r"^---\n[\s\S]*?\n---\n", "", literal=False)
        .str.strip_chars()
    )
    cards = cards.with_columns(body.alias("post_yaml_content"))

    # Size window first, then (models only) the popularity thresholds.
    body_bytes = pl.col("post_yaml_content").str.len_bytes()
    keep = (body_bytes > 200) & (body_bytes < 120_000)
    if card_type == "model":
        popular = (pl.col("likes") >= min_likes) & (
            pl.col("downloads") >= min_downloads
        )
        keep = keep & popular

    result = cards.filter(keep).collect()
    logger.info(f"Filtered dataset has {len(result)} items")
    return result
def generate_summaries(
    model_id: str,
    input_dataset_id: str,
    output_dataset_id: str,
    card_type: str = "dataset",
    max_tokens: int = 120,
    temperature: float = 0.6,
    batch_size: int = 1000,
    min_likes: int = 1,
    min_downloads: int = 1,
    hf_token: Optional[str] = None,
):
    """Generate one summary per card with vLLM and push the result to the Hub.

    Steps: optionally log in to the Hub, load and filter the input cards,
    download the model snapshot, build one chat prompt per card, generate in
    slices of ``batch_size`` prompts, then push the filtered rows plus a new
    ``summary`` column to ``output_dataset_id``.

    Args:
        model_id: Hub repo id of the summarisation model.
        input_dataset_id: Hub dataset id holding the cards.
        output_dataset_id: Hub dataset id the results are pushed to.
        card_type: ``"dataset"`` or ``"model"``; forwarded to the filtering
            and prompt-formatting helpers.
        max_tokens: ``max_tokens`` passed to ``SamplingParams``.
        temperature: ``temperature`` passed to ``SamplingParams``.
        batch_size: Number of prompts handed to ``llm.generate`` per call.
        min_likes: Forwarded to ``load_and_filter_data``.
        min_downloads: Forwarded to ``load_and_filter_data``.
        hf_token: Hub token; falls back to the ``HF_TOKEN`` environment
            variable. When neither is set no login is performed and
            ``push_to_hub`` receives ``token=None``.

    Returns:
        None. Returns early, without loading the model or pushing anything,
        when no rows survive filtering.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail fast: a zero step crashes range() and a negative one would silently
    # produce no outputs -- both only after the expensive model load.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Login if a token is available
    token = hf_token or os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    # Load and filter data
    df_filtered = load_and_filter_data(
        input_dataset_id, card_type, min_likes, min_downloads
    )
    if len(df_filtered) == 0:
        # Nothing to summarise: skip the model download/initialisation and
        # avoid pushing an empty dataset.
        logger.warning("No items left after filtering; nothing to summarise")
        return

    # Download model to local directory first.
    # `resume_download` is deprecated in huggingface_hub (interrupted downloads
    # are resumed automatically), so it is no longer passed.
    logger.info(f"Downloading model {model_id} to local directory...")
    local_model_path = snapshot_download(repo_id=model_id)
    logger.info(f"Model downloaded to: {local_model_path}")

    # Initialize model and tokenizer from local path
    logger.info(f"Initializing vLLM model from local path: {local_model_path}")
    llm = LLM(model=local_model_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path)
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Prepare prompts
    logger.info("Preparing prompts")
    post_yaml_contents = df_filtered["post_yaml_content"].to_list()
    prompts = [
        format_prompt(content, card_type, tokenizer)
        for content in tqdm(post_yaml_contents, desc="Formatting prompts")
    ]

    # Generate summaries in batches, keeping only the stripped text of the
    # first completion of each request rather than the full output objects.
    logger.info(f"Generating summaries for {len(prompts)} items")
    summaries = []
    for start in tqdm(
        range(0, len(prompts), batch_size), desc="Generating summaries"
    ):
        batch_outputs = llm.generate(
            prompts[start : start + batch_size], sampling_params
        )
        summaries.extend(out.outputs[0].text.strip() for out in batch_outputs)

    # Create dataset and add summaries
    ds = Dataset.from_polars(df_filtered)
    ds = ds.add_column("summary", summaries)

    # Push to hub
    logger.info(f"Pushing dataset to hub: {output_dataset_id}")
    ds.push_to_hub(output_dataset_id, token=token)
    logger.info("Dataset successfully pushed to hub")
def main():
    """Parse the command line and forward every option to ``generate_summaries``."""
    cli = argparse.ArgumentParser(
        description="Generate summaries for Hugging Face datasets or models using vLLM"
    )

    # Required positional arguments.
    cli.add_argument(
        "model_id",
        help="Model ID for summary generation (e.g., davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02)",
    )
    cli.add_argument(
        "input_dataset_id",
        help="Input dataset ID (e.g., librarian-bots/dataset_cards_with_metadata)",
    )
    cli.add_argument(
        "output_dataset_id", help="Output dataset ID where results will be saved"
    )

    cli.add_argument(
        "--card-type",
        choices=["dataset", "model"],
        default="dataset",
        help="Type of cards to process (default: dataset)",
    )

    # Numeric options, registered in the order they appear in --help:
    # (flag, type, default, help text).
    numeric_options = (
        (
            "--max-tokens",
            int,
            120,
            "Maximum tokens for summary generation (default: 120)",
        ),
        ("--temperature", float, 0.6, "Temperature for generation (default: 0.6)"),
        ("--batch-size", int, 1000, "Batch size for processing (default: 1000)"),
        ("--min-likes", int, 1, "Minimum likes filter for models (default: 1)"),
        (
            "--min-downloads",
            int,
            1,
            "Minimum downloads filter for models (default: 1)",
        ),
    )
    for flag, value_type, default_value, help_text in numeric_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    cli.add_argument(
        "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
    )

    options = cli.parse_args()
    # Each argparse destination has the same name as the matching keyword
    # parameter of generate_summaries, so the namespace can be splatted.
    generate_summaries(**vars(options))
if __name__ == "__main__":
    if len(sys.argv) != 1:
        main()
    else:
        # Run without arguments: print example invocations instead of
        # letting argparse report the missing positionals.
        example_lines = (
            "Example hfjobs command:",
            "hfjobs run --flavor l4x1 --secret HF_TOKEN=hf_*** ghcr.io/astral-sh/uv:debian /bin/bash -c '",
            "apt-get update && apt-get install -y python3-dev gcc && \\",
            "export HOME=/tmp && \\",
            "export USER=dummy && \\",
            "export TORCHINDUCTOR_CACHE_DIR=/tmp/torch-inductor && \\",
            "uv run generate_summaries_uv.py \\",
            " davanstrien/Smol-Hub-tldr \\",
            " librarian-bots/dataset_cards_with_metadata \\",
            " your-username/datasets_with_summaries \\",
            " --card-type dataset \\",
            " --batch-size 2000",
            "' --project summary-generation --name dataset-summaries",
            "",
            "For models:",
            "uv run generate_summaries_uv.py \\",
            " davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02 \\",
            " librarian-bots/model_cards_with_metadata \\",
            " your-username/models_with_summaries \\",
            " --card-type model \\",
            " --min-likes 5 \\",
            " --min-downloads 1000",
        )
        print("\n".join(example_lines))