|
""" |
|
aggregate.py - module for aggregating text from multiple sources/multiple parts of a single source. |
|
Primary usage is through the BatchAggregator class. |
|
|
|
How it works: |
|
1. We tell the language model to do it. |
|
2. The language model does it. |
|
3. Yaay! |
|
""" |
|
import logging |
|
import pprint as pp |
|
import time |
|
|
|
import torch |
|
from transformers import GenerationConfig, pipeline |
|
|
|
from utils import compare_model_size |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" |
|
) |
|
|
|
|
|
class BatchAggregator: |
|
""" |
|
BatchAggregator is a class for aggregating text from multiple sources. |
|
|
|
Usage: |
|
>>> from aggregate import BatchAggregator |
|
>>> aggregator = BatchAggregator() |
|
>>> agg = aggregator.infer_aggregate(["This is a test", "This is another test"]) |
|
>>> print(agg) |
|
""" |
|
|
|
GENERIC_CONFIG = GenerationConfig( |
|
num_beams=8, |
|
early_stopping=True, |
|
do_sample=False, |
|
min_new_tokens=32, |
|
max_new_tokens=256, |
|
repetition_penalty=1.1, |
|
length_penalty=1.4, |
|
no_repeat_ngram_size=4, |
|
encoder_no_repeat_ngram_size=5, |
|
) |
|
CONFIGURED_MODELS = [ |
|
"pszemraj/bart-large-mnli-dolly_hhrlhf-v1", |
|
"pszemraj/bart-base-instruct-dolly_hhrlhf", |
|
"pszemraj/flan-t5-large-instruct-dolly_hhrlhf", |
|
"pszemraj/flan-t5-base-instruct-dolly_hhrlhf", |
|
] |
|
|
|
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:" |
|
|
|
def __init__( |
|
self, |
|
model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", |
|
force_cpu: bool = False, |
|
**kwargs, |
|
): |
|
""" |
|
__init__ initializes the BatchAggregator class. |
|
|
|
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1" |
|
:param bool force_cpu: force the model to run on CPU, default: False |
|
""" |
|
self.device = None |
|
self.is_compiled = False |
|
self.model_name = None |
|
self.aggregator = None |
|
self.force_cpu = force_cpu |
|
self.logger = logging.getLogger(__name__) |
|
self.init_model(model_name) |
|
|
|
def init_model(self, model_name: str) -> None: |
|
""" |
|
Initialize the model. |
|
|
|
:param model_name: The name of the model to use. |
|
""" |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
self.logger.info(f"Setting model to {model_name}") |
|
self.model_name = model_name |
|
self.aggregator = self._create_pipeline(model_name) |
|
self._configure_model() |
|
|
|
tokenizer_params = { |
|
"decoder_start_token_id": 0 |
|
if "t5" in model_name.lower() |
|
else self.aggregator.tokenizer.eos_token_id, |
|
"eos_token_id": 1 |
|
if "t5" in model_name.lower() |
|
else self.aggregator.tokenizer.eos_token_id, |
|
"pad_token_id": 0 |
|
if "t5" in model_name.lower() |
|
else self.aggregator.tokenizer.pad_token_id, |
|
} |
|
self.update_generation_config(**tokenizer_params) |
|
|
|
def _create_pipeline( |
|
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1" |
|
) -> pipeline: |
|
""" |
|
_create_pipeline creates a pipeline for the model. |
|
|
|
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1" |
|
:return pipeline: the pipeline for the model |
|
|
|
:raises Exception: if the pipeline cannot be created |
|
""" |
|
self.device = 0 if torch.cuda.is_available() and not self.force_cpu else -1 |
|
try: |
|
self.logger.info( |
|
f"Creating pipeline with model {model_name} on device {self.device}" |
|
) |
|
return pipeline( |
|
"text2text-generation", |
|
model_name, |
|
device=self.device, |
|
torch_dtype=torch.float32, |
|
) |
|
except Exception as e: |
|
self.logger.error(f"Failed to create pipeline: {e}") |
|
raise |
|
|
|
def _configure_model(self): |
|
""" |
|
Configure the model for generation. |
|
""" |
|
try: |
|
self.aggregator.model = torch.compile(self.aggregator.model) |
|
self.is_compiled = True |
|
except Exception as e: |
|
self.logger.warning(f"Could not compile model with Torch 2.0: {e}") |
|
|
|
if self.model_name not in self.CONFIGURED_MODELS: |
|
self.logger.info("Setting generation config to general defaults") |
|
self._set_default_generation_config() |
|
else: |
|
try: |
|
self.logger.info("Loading generation config from hub") |
|
self.aggregator.model.generation_config = ( |
|
GenerationConfig.from_pretrained(self.model_name) |
|
) |
|
except Exception as e: |
|
self.logger.warning( |
|
f"Could not load generation config, using defaults: {e}" |
|
) |
|
self._set_default_generation_config() |
|
|
|
self.logger.info(self.aggregator.model.generation_config.to_json_string()) |
|
|
|
def _set_default_generation_config(self): |
|
""" |
|
Set the default generation configuration for the model. |
|
""" |
|
self.aggregator.model.generation_config = self.GENERIC_CONFIG |
|
|
|
if ( |
|
"large" |
|
or "xl" in self.model_name.lower() |
|
or compare_model_size(self.model_name, 500) |
|
): |
|
upd = {"num_beams": 4} |
|
self.update_generation_config(**upd) |
|
|
|
def update_generation_config(self, **kwargs): |
|
""" |
|
Update the generation configuration with the specified parameters. |
|
|
|
Args: |
|
**kwargs: The parameters to update in the generation configuration. |
|
""" |
|
self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}") |
|
|
|
self.aggregator.model.generation_config.update(**kwargs) |
|
|
|
def get_generation_config(self) -> dict: |
|
""" |
|
Get the current generation configuration. |
|
|
|
Returns: |
|
dict: The current generation configuration. |
|
""" |
|
return self.aggregator.model.generation_config.to_dict() |
|
|
|
def update_loglevel(self, level: str = "INFO"): |
|
""" |
|
Update the log level. |
|
|
|
Args: |
|
level (str): The log level to set. Defaults to "INFO". |
|
""" |
|
self.logger.setLevel(level) |
|
|
|
def infer_aggregate( |
|
self, |
|
text_list: list, |
|
instruction: str = DEFAULT_INSTRUCTION, |
|
**kwargs, |
|
) -> str: |
|
f""" |
|
infer_aggregate - infers a consolidated summary from a list of texts. |
|
|
|
Args: |
|
text_list (list): The texts to summarize. |
|
instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}. |
|
**kwargs: Additional parameters to update in the generation configuration. |
|
|
|
Returns: |
|
The generated summary. |
|
""" |
|
joined_text = "\n".join(text_list) |
|
prompt = f"{instruction}\n\n{joined_text}\n" |
|
if kwargs: |
|
self.update_generation_config(**kwargs) |
|
st = time.perf_counter() |
|
self.logger.info(f"inference on {len(text_list)} texts ...") |
|
result = self.aggregator( |
|
prompt, |
|
generation_config=self.aggregator.model.generation_config, |
|
)[0]["generated_text"] |
|
self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s") |
|
self.logger.info( |
|
f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}" |
|
) |
|
self.logger.debug(f"Generated text:\n{result}") |
|
|
|
return result |
|
|
|
def count_tokens(self, text: str) -> int: |
|
"""count the number of tokens in a text""" |
|
return ( |
|
len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False)) |
|
if text |
|
else 0 |
|
) |
|
|