# imports
import logging
import time

import torch
from transformers import GenerationConfig, pipeline

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
        **kwargs,
    ):
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info(f"Initializing aggregator with model {model_name}")
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                f"Could not load generation config, using defaults: {e}"
            )
            self.aggregator.model.generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                do_sample=False,
                min_new_tokens=32,
                max_new_tokens=192,
                repetition_penalty=1.1,
                length_penalty=1.5,
                no_repeat_ngram_size=4,
                encoder_no_repeat_ngram_size=5,
                decoder_start_token_id=0,
                eos_token_id=1,
                pad_token_id=0,
            )
        if "bart" in model_name.lower():
            self.logger.info("Using BART model, updating generation config")
            upd = {
                "num_beams": 8,
                "repetition_penalty": 1.3,
                "length_penalty": 1.0,
                "_from_model_config": False,
                "max_new_tokens": 256,
                "min_new_tokens": 32,
                "no_repeat_ngram_size": 3,
                "encoder_no_repeat_ngram_size": 6,
            }
            self.aggregator.model.generation_config.update(**upd)
        if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
            self.logger.info("Updating generation config with defaults")
            self.update_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def update_generation_config(self, **kwargs):
        """Apply the default generation settings, letting kwargs override them."""
        self.logger.info(f"Updating generation config with {kwargs}")
        config = GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        ).to_dict()
        # User-supplied kwargs take precedence over the defaults above
        config.update(**kwargs)
        self.aggregator.model.generation_config.update(**config)

    def _replace_pipeline(self, model_name: str) -> None:
        """Swap the underlying pipeline for a new model and reset its generation config."""
        self.logger.info(f"Replacing pipeline with model {model_name}")
        self.model_name = model_name
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )
        self.update_generation_config()

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = (
            "Write a comprehensive yet concise summary in paragraph form "
            "that pulls together the main points of the following text:"
        ),
        **kwargs,
    ):
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts")
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. "
            f"Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Count tokens in `text` using the pipeline's tokenizer."""
        return (
            len(
                self.aggregator.tokenizer.encode(
                    text, truncation=False, padding=False
                )
            )
            if text
            else 0
        )
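

# Minimal usage sketch (not part of the class above): assumes the default
# checkpoint is reachable locally or on the Hugging Face Hub, and uses
# hypothetical chunk summaries purely for illustration.
if __name__ == "__main__":
    chunk_summaries = [
        "The report opens with an overview of Q3 revenue growth.",
        "Later sections attribute the growth to new enterprise contracts.",
    ]
    aggregator = BatchAggregator()
    # Extra kwargs are forwarded to update_generation_config as overrides
    combined = aggregator.infer_aggregate(chunk_summaries, max_new_tokens=128)
    print(combined)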