davanstrien HF Staff commited on
Commit
0c77e14
·
1 Parent(s): 58fabfc

summaries script

Browse files
Files changed (1) hide show
  1. generate_summaries_uv.py +241 -0
generate_summaries_uv.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "datasets",
5
+ # "flashinfer-python>=0.2.3",
6
+ # "huggingface-hub[hf_xet]",
7
+ # "polars",
8
+ # "stamina",
9
+ # "transformers",
10
+ # "vllm",
11
+ # "tqdm",
12
+ # ]
13
+ # ///
14
+
15
+ import argparse
16
+ import logging
17
+ import os
18
+ import sys
19
+ from typing import Optional
20
+
21
+ import polars as pl
22
+ from datasets import Dataset, load_dataset
23
+ from huggingface_hub import login, dataset_info
24
+ from tqdm.auto import tqdm
25
+ from transformers import AutoTokenizer
26
+ from vllm import LLM, SamplingParams
27
+
28
+ # Setup logging
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format="%(asctime)s - %(levelname)s - %(message)s",
32
+ datefmt="%Y-%m-%d %H:%M:%S",
33
+ )
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ def format_prompt(content: str, card_type: str, tokenizer) -> str:
38
+ """Format content as a prompt for the model."""
39
+ if card_type == "model":
40
+ messages = [{"role": "user", "content": f"<MODEL_CARD>{content[:4000]}"}]
41
+ else:
42
+ messages = [{"role": "user", "content": f"<DATASET_CARD>{content[:4000]}"}]
43
+
44
+ return tokenizer.apply_chat_template(
45
+ messages, add_generation_prompt=True, tokenize=False
46
+ )
47
+
48
+
49
+ def load_and_filter_data(
50
+ dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1
51
+ ) -> pl.DataFrame:
52
+ """Load and filter dataset/model data."""
53
+ logger.info(f"Loading data from {dataset_id}")
54
+ ds = load_dataset(dataset_id, split="train")
55
+ df = ds.to_polars().lazy()
56
+
57
+ # Extract content after YAML frontmatter
58
+ df = df.with_columns(
59
+ [
60
+ pl.col("card")
61
+ .str.replace_all(r"^---\n[\s\S]*?\n---\n", "", literal=False)
62
+ .str.strip_chars()
63
+ .alias("post_yaml_content")
64
+ ]
65
+ )
66
+
67
+ # Apply filters
68
+ df = df.filter(pl.col("post_yaml_content").str.len_bytes() > 200)
69
+ df = df.filter(pl.col("post_yaml_content").str.len_bytes() < 120_000)
70
+
71
+ if card_type == "model":
72
+ df = df.filter(pl.col("likes") >= min_likes)
73
+ df = df.filter(pl.col("downloads") >= min_downloads)
74
+
75
+ df_filtered = df.collect()
76
+ logger.info(f"Filtered dataset has {len(df_filtered)} items")
77
+ return df_filtered
78
+
79
+
80
+ def generate_summaries(
81
+ model_id: str,
82
+ input_dataset_id: str,
83
+ output_dataset_id: str,
84
+ card_type: str = "dataset",
85
+ max_tokens: int = 120,
86
+ temperature: float = 0.6,
87
+ batch_size: int = 1000,
88
+ min_likes: int = 1,
89
+ min_downloads: int = 1,
90
+ hf_token: Optional[str] = None,
91
+ ):
92
+ """Main function to generate summaries."""
93
+
94
+ # Login if token provided
95
+ HF_TOKEN = hf_token or os.environ.get("HF_TOKEN")
96
+ if HF_TOKEN:
97
+ login(token=HF_TOKEN)
98
+
99
+ # Load and filter data
100
+ df_filtered = load_and_filter_data(
101
+ input_dataset_id, card_type, min_likes, min_downloads
102
+ )
103
+
104
+ # Initialize model and tokenizer
105
+ logger.info(f"Initializing vLLM model: {model_id}")
106
+ llm = LLM(model=model_id)
107
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
108
+ sampling_params = SamplingParams(
109
+ temperature=temperature,
110
+ max_tokens=max_tokens,
111
+ )
112
+
113
+ # Prepare prompts
114
+ logger.info("Preparing prompts")
115
+ post_yaml_contents = df_filtered["post_yaml_content"].to_list()
116
+ prompts = [
117
+ format_prompt(content, card_type, tokenizer)
118
+ for content in tqdm(post_yaml_contents, desc="Formatting prompts")
119
+ ]
120
+
121
+ # Generate summaries in batches
122
+ logger.info(f"Generating summaries for {len(prompts)} items")
123
+ all_outputs = []
124
+
125
+ for i in tqdm(range(0, len(prompts), batch_size), desc="Generating summaries"):
126
+ batch_prompts = prompts[i : i + batch_size]
127
+ outputs = llm.generate(batch_prompts, sampling_params)
128
+ all_outputs.extend(outputs)
129
+
130
+ # Extract clean results
131
+ clean_results = [output.outputs[0].text.strip() for output in all_outputs]
132
+
133
+ # Create dataset and add summaries
134
+ ds = Dataset.from_polars(df_filtered)
135
+ ds = ds.add_column("summary", clean_results)
136
+
137
+ # Push to hub
138
+ logger.info(f"Pushing dataset to hub: {output_dataset_id}")
139
+ ds.push_to_hub(output_dataset_id, token=HF_TOKEN)
140
+ logger.info("Dataset successfully pushed to hub")
141
+
142
+
143
+ def main():
144
+ parser = argparse.ArgumentParser(
145
+ description="Generate summaries for Hugging Face datasets or models using vLLM"
146
+ )
147
+ parser.add_argument(
148
+ "model_id",
149
+ help="Model ID for summary generation (e.g., davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02)",
150
+ )
151
+ parser.add_argument(
152
+ "input_dataset_id",
153
+ help="Input dataset ID (e.g., librarian-bots/dataset_cards_with_metadata)",
154
+ )
155
+ parser.add_argument(
156
+ "output_dataset_id", help="Output dataset ID where results will be saved"
157
+ )
158
+ parser.add_argument(
159
+ "--card-type",
160
+ choices=["dataset", "model"],
161
+ default="dataset",
162
+ help="Type of cards to process (default: dataset)",
163
+ )
164
+ parser.add_argument(
165
+ "--max-tokens",
166
+ type=int,
167
+ default=120,
168
+ help="Maximum tokens for summary generation (default: 120)",
169
+ )
170
+ parser.add_argument(
171
+ "--temperature",
172
+ type=float,
173
+ default=0.6,
174
+ help="Temperature for generation (default: 0.6)",
175
+ )
176
+ parser.add_argument(
177
+ "--batch-size",
178
+ type=int,
179
+ default=1000,
180
+ help="Batch size for processing (default: 1000)",
181
+ )
182
+ parser.add_argument(
183
+ "--min-likes",
184
+ type=int,
185
+ default=1,
186
+ help="Minimum likes filter for models (default: 1)",
187
+ )
188
+ parser.add_argument(
189
+ "--min-downloads",
190
+ type=int,
191
+ default=1,
192
+ help="Minimum downloads filter for models (default: 1)",
193
+ )
194
+ parser.add_argument(
195
+ "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
196
+ )
197
+
198
+ args = parser.parse_args()
199
+
200
+ generate_summaries(
201
+ model_id=args.model_id,
202
+ input_dataset_id=args.input_dataset_id,
203
+ output_dataset_id=args.output_dataset_id,
204
+ card_type=args.card_type,
205
+ max_tokens=args.max_tokens,
206
+ temperature=args.temperature,
207
+ batch_size=args.batch_size,
208
+ min_likes=args.min_likes,
209
+ min_downloads=args.min_downloads,
210
+ hf_token=args.hf_token,
211
+ )
212
+
213
+
214
+ if __name__ == "__main__":
215
+ if len(sys.argv) == 1:
216
+ # Show example hfjobs command when run without arguments
217
+ print("Example hfjobs command:")
218
+ print(
219
+ "hfjobs run --flavor l4x1 --secret HF_TOKEN=hf_*** ghcr.io/astral-sh/uv:debian /bin/bash -c '"
220
+ )
221
+ print("export HOME=/tmp && \\")
222
+ print("export USER=dummy && \\")
223
+ print("export TORCHINDUCTOR_CACHE_DIR=/tmp/torch-inductor && \\")
224
+ print("uv run generate_summaries_uv.py \\")
225
+ print(" davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02 \\")
226
+ print(" librarian-bots/dataset_cards_with_metadata \\")
227
+ print(" your-username/datasets_with_summaries \\")
228
+ print(" --card-type dataset \\")
229
+ print(" --batch-size 2000")
230
+ print("' --project summary-generation --name dataset-summaries")
231
+ print()
232
+ print("For models:")
233
+ print("uv run generate_summaries_uv.py \\")
234
+ print(" davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02 \\")
235
+ print(" librarian-bots/model_cards_with_metadata \\")
236
+ print(" your-username/models_with_summaries \\")
237
+ print(" --card-type model \\")
238
+ print(" --min-likes 5 \\")
239
+ print(" --min-downloads 1000")
240
+ else:
241
+ main()