import logging
import time

import torch
from transformers import GenerationConfig, pipeline

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """Aggregate a batch of texts into a single summary via a text2text model."""

    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info(f"Initializing aggregator with model {model_name}")
        # run on the first GPU when available, otherwise on CPU
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )

        # compiling with torch.compile (PyTorch 2.0+) is best-effort
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        # prefer the checkpoint's own generation config; fall back to defaults
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                f"Could not load generation config, using defaults: {e}"
            )
            self.aggregator.model.generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                do_sample=False,
                min_new_tokens=32,
                max_new_tokens=192,
                repetition_penalty=1.1,
                length_penalty=1.5,
                no_repeat_ngram_size=4,
                encoder_no_repeat_ngram_size=5,
                decoder_start_token_id=0,
                eos_token_id=1,
                pad_token_id=0,
            )

        # BART checkpoints get a tuned config (wider beam search, adjusted penalties)
        if "bart" in model_name.lower():
            self.logger.info("Using BART model, updating generation config")
            upd = {
                "num_beams": 8,
                "repetition_penalty": 1.3,
                "length_penalty": 1.0,
                "_from_model_config": False,
                "max_new_tokens": 256,
                "min_new_tokens": 32,
                "no_repeat_ngram_size": 3,
                "encoder_no_repeat_ngram_size": 6,
            }
            self.aggregator.model.generation_config.update(**upd)
        if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
            self.logger.info("Updating generation config with defaults")
            # forward any construction-time overrides to the config update
            self.update_generation_config(**kwargs)
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def update_generation_config(self, **kwargs):
        """Reset generation settings to the class defaults, then apply `kwargs`."""
        self.logger.info(f"Updating generation config with {kwargs}")
        default = GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        ).to_dict()
        default.update(kwargs)  # caller overrides take precedence over the defaults
        self.aggregator.model.generation_config.update(**default)

    def _replace_pipeline(self, model_name: str) -> None:
        """Rebuild the pipeline around a different model checkpoint."""
        self.model_name = model_name
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
        **kwargs,
    ):
        """Aggregate a list of texts into one summary under `instruction`."""
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts")
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Return the token count of `text` via the pipeline's tokenizer (0 if empty)."""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
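

# Minimal usage sketch (illustrative only): assumes the default checkpoint is
# reachable on the Hugging Face Hub; the input texts below are made-up
# placeholders, and runtime/output quality will vary with hardware and model.
if __name__ == "__main__":
    agg = BatchAggregator()
    chunk_summaries = [
        "The report opens by outlining last quarter's revenue growth.",
        "It attributes most of that growth to the new subscription tier.",
        "It closes by flagging rising infrastructure costs as a risk.",
    ]
    # extra kwargs are forwarded to update_generation_config() before generating
    print(agg.infer_aggregate(chunk_summaries, max_new_tokens=128))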