davanstrien (HF Staff) committed
Commit aaa2fc9 · 1 Parent(s): 58454eb

refactor: remove FLASHINFER environment variable and update LLM initialization for batch processing
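The removed VLLM_ATTENTION_BACKEND line was already commented out, so that part is dead-code cleanup; vLLM stays on its default attention backend. The more substantive change is the initialization: max_model_len caps the sequence length the engine reserves KV cache for, which bounds memory use at load time. A minimal standalone sketch of the post-commit setup (the model path and sampling values below are placeholders; only max_model_len=4096 is taken from the diff):

from vllm import LLM, SamplingParams

# Cap the context window at engine start-up instead of using the model's default.
llm = LLM(
    model="./local_model",  # placeholder for the script's local_model_path
    max_model_len=4096,     # adjust based on model capabilities
)
# Illustrative sampling settings; the script builds these from its own arguments.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)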

Files changed (1)
  1. generate_summaries_uv.py +20 -6
generate_summaries_uv.py CHANGED

@@ -20,7 +20,6 @@ from typing import Optional
 
 # Set environment variables to speed up model loading
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
 
 import polars as pl
 from datasets import Dataset, load_dataset

@@ -112,7 +111,10 @@ def generate_summaries(
 
     # Initialize model and tokenizer from local path
     logger.info(f"Initializing vLLM model from local path: {local_model_path}")
-    llm = LLM(model=local_model_path)
+    llm = LLM(
+        model=local_model_path,
+        max_model_len=4096,  # Adjust based on model capabilities
+    )
     tokenizer = AutoTokenizer.from_pretrained(local_model_path)
     sampling_params = SamplingParams(
         temperature=temperature,

@@ -131,10 +133,22 @@ def generate_summaries(
     logger.info(f"Generating summaries for {len(prompts)} items")
     all_outputs = []
 
-    for i in tqdm(range(0, len(prompts), batch_size), desc="Generating summaries"):
-        batch_prompts = prompts[i : i + batch_size]
-        outputs = llm.generate(batch_prompts, sampling_params)
-        all_outputs.extend(outputs)
+    # for i in tqdm(range(0, len(prompts), batch_size), desc="Generating summaries"):
+    #     batch_prompts = prompts[i : i + batch_size]
+    #     outputs = llm.generate(batch_prompts, sampling_params)
+    #     all_outputs.extend(outputs)
+    # try directly doing whole dataset
+    all_outputs = llm.generate(
+        prompts,
+        sampling_params,
+        batch_size=batch_size,
+        max_batch_size=batch_size,
+    )
+    logger.info(f"Generated {len(all_outputs)} summaries")
+    if len(all_outputs) != len(prompts):
+        logger.warning(
+            f"Generated {len(all_outputs)} summaries, but expected {len(prompts)}. Some prompts may have failed."
+        )
 
     # Extract clean results
     clean_results = [output.outputs[0].text.strip() for output in all_outputs]
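One caveat on the new generation path: as far as I know, vLLM's LLM.generate() performs batching internally (continuous batching) and its signature does not accept batch_size or max_batch_size keyword arguments, so the call added in this commit may raise a TypeError. A minimal sketch of the whole-dataset call without those kwargs, keeping the length check from the diff (variable names as in the script):

# Pass all prompts at once and let vLLM schedule batches internally.
all_outputs = llm.generate(prompts, sampling_params)
logger.info(f"Generated {len(all_outputs)} summaries")
if len(all_outputs) != len(prompts):
    logger.warning(
        f"Generated {len(all_outputs)} summaries, but expected {len(prompts)}. "
        "Some prompts may have failed."
    )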