Spaces:
Runtime error
Runtime error
File size: 6,126 Bytes
77d5469 87e5c9c 77d5469 87e5c9c c006617 87e5c9c 8312087 87e5c9c c006617 87e5c9c c006617 87e5c9c c006617 87e5c9c 5407b63 9ebcabc 5f857f0 9ebcabc 5407b63 048132e 9350787 87e5c9c 5407b63 9350787 2956200 5407b63 87e5c9c 5407b63 2956200 8312087 87e5c9c 7813441 c006617 87e5c9c 9350787 c006617 87e5c9c c006617 87e5c9c 9350787 87e5c9c 77d5469 87e5c9c c006617 87e5c9c c006617 87e5c9c c006617 87e5c9c 55b49e6 87e5c9c 55b49e6 87e5c9c c006617 87e5c9c c006617 87e5c9c 77d5469 c006617 77d5469 c006617 77d5469 87e5c9c 5dfe75e 87e5c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
summarize - a module for summarizing text using a model from the Hugging Face model hub

Public helpers:
    load_model_and_tokenizer   - resolve a model name, load model + tokenizer, optionally torch.compile it
    summarize_and_score        - run model.generate on one token window and return (summary, score)
    summarize_via_tokenbatches - split long text into overlapping token windows and summarize each one
"""
import logging
import pprint as pp
# Root logger: INFO level with timestamped messages. This runs at import time and
# (per logging.basicConfig semantics) is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# project-local helper; presumably reports whether PyTorch >= 2.0 (torch.compile) is usable — confirm in utils
from utils import validate_pytorch2
def load_model_and_tokenizer(model_name: str) -> tuple:
    """
    load_model_and_tokenizer - load a model and tokenizer from a model name/ID on the hub

    ``model_name`` may be one of the display names in ``MODEL_OPTIONS`` below
    (e.g. "Text Summarizer") or a raw Hugging Face hub model ID
    (e.g. "pszemraj/long-t5-tglobal-base-16384-book-summary"). Names that are
    not in ``MODEL_OPTIONS`` are passed to the hub unchanged.

    The model is moved to CUDA when available (otherwise CPU) and put in eval
    mode. It is wrapped with ``torch.compile`` when ``validate_pytorch2()``
    returns truthy. Compilation failures are logged and the uncompiled model is
    returned.

    :param str model_name: a display name from MODEL_OPTIONS or a model ID on the hub
    :raises ValueError: if ``model_name`` is empty or None
    :return tuple: a tuple ``(model, tokenizer)``
    """
    logger = logging.getLogger(__name__)
    MODEL_OPTIONS = {
        "Text Summarizer": "pszemraj/long-t5-tglobal-base-16384-book-summary",
        "News Article Summarizer Alpha": "pszemraj/long-t5-tglobal-base-sci-simplify",
        "News Article Summarizer Beta": "pszemraj/long-t5-tglobal-base-sci-simplify-elife",
        "Scientific Document Summarizer Alpha": "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
        "Scientific Document Summarizer Beta": "pszemraj/pegasus-x-large-book-summary",
    }
    if not model_name:
        raise ValueError("model_name must be a non-empty display name or hub model ID")

    # Display names map to hub IDs; anything else is treated as a hub ID itself,
    # so callers can pass e.g. "org/model" directly as the docstring promises.
    selected_model_identifier = MODEL_OPTIONS.get(model_name, model_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(
        selected_model_identifier,
    ).to(device)
    model = model.eval()  # inference only: disables dropout etc.
    tokenizer = AutoTokenizer.from_pretrained(selected_model_identifier)
    logger.info(f"Loaded model {selected_model_identifier} to {device}")

    if validate_pytorch2():
        # Best-effort: compilation is an optimization, so a failure must not
        # prevent the (uncompiled) model from being returned.
        try:
            logger.info("Compiling model with Torch 2.0")
            model = torch.compile(model)
        except Exception as e:
            logger.warning(f"Could not compile model with Torch 2.0: {e}")
    else:
        logger.info("Torch 2.0 not detected, skipping compilation")
    return model, tokenizer
def _resolve_model_device(model):
    """Return the device that ``model``'s weights live on (helper for summarize_and_score)."""
    # transformers PreTrainedModel (and torch.compile wrappers, which forward
    # attribute access) expose ``.device`` directly.
    device = getattr(model, "device", None)
    if device is not None:
        return device
    # Plain nn.Module: use the device of its first parameter.
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            return first_param.device
    # Last resort: same policy as load_model_and_tokenizer.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def summarize_and_score(
    ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs
) -> tuple:
    """
    summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary

    Args:
        ids (torch.Tensor): token ids for one window; a leading batch dimension is
            added here, so presumably 1-D (seq_len,) — confirm against callers
        mask (torch.Tensor): the attention mask matching ``ids``
        model (): the model to use for summarization (must provide ``generate``)
        tokenizer (): the tokenizer to use for decoding (must provide ``batch_decode``)
        is_general_attention_model (bool, optional): if False, a
            ``global_attention_mask`` with global attention on the first token is
            also passed to ``generate``. Defaults to True.
        **kwargs: any additional arguments to pass to ``model.generate``

    Returns:
        tuple (list[str], float): the decoded summary (a list with one string per
        generated sequence) and the first sequence's ``sequences_scores`` value
        rounded to 4 decimals. The score is NaN when ``generate`` returned no
        ``sequences_scores``, which happens when beam search is not used.
    """
    logger = logging.getLogger(__name__)

    # Move inputs to wherever the model actually lives rather than assuming CUDA.
    device = _resolve_model_device(model)
    input_ids = ids[None, :].to(device)  # add batch dimension
    attention_mask = mask[None, :].to(device)

    generate_kwargs = {
        "attention_mask": attention_mask,
        "output_scores": True,
        "return_dict_in_generate": True,
    }
    if not is_general_attention_model:
        # put global attention on <s> (first) token only
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        generate_kwargs["global_attention_mask"] = global_attention_mask

    summary_pred_ids = model.generate(input_ids, **generate_kwargs, **kwargs)

    # note: remove_invalid_values is a generate() option, not a decode option;
    # pass it through **kwargs if it is wanted.
    summary = tokenizer.batch_decode(
        summary_pred_ids.sequences,
        skip_special_tokens=True,
    )

    sequences_scores = getattr(summary_pred_ids, "sequences_scores", None)
    if sequences_scores is None:
        logger.warning(
            "generate() returned no sequences_scores (beam search not used?); score set to NaN"
        )
        score = float("nan")
    else:
        # .item() also works for bf16/CUDA tensors, unlike .numpy()
        score = round(float(sequences_scores[0].item()), 4)
    return summary, score
def summarize_via_tokenbatches(
    input_text: str,
    model,
    tokenizer,
    batch_length=2048,
    batch_stride=16,
    min_batch_length=512,
    **kwargs,
) -> list:
    """
    summarize_via_tokenbatches - summarize a long string via batches of tokens

    The text is tokenized once into windows of ``batch_length`` tokens that
    overlap by ``batch_stride`` tokens. Each window is summarized independently
    with ``summarize_and_score``. The final window is padded to ``batch_length``;
    padded positions are excluded through the attention mask.

    Args:
        input_text (str): the text to summarize. Empty or whitespace-only
            strings produce an empty result without calling the model.
        model (): the model to use for summarization
        tokenizer (): the tokenizer to use for summarization
        batch_length (int, optional): the length of each batch. Defaults to 2048.
            Values below ``min_batch_length`` are raised to it (with a warning).
        batch_stride (int, optional): the stride of each batch. Defaults to 16.
            The stride is the number of tokens that overlap between batches.
        min_batch_length (int, optional): the minimum length of each batch. Defaults to 512.
        **kwargs: any additional arguments to pass to the model for inference

    Returns:
        list: one dict per window, with keys
            "input_tokens"  - the window's token-id tensor,
            "summary"       - the decoded summary (list of strings from batch_decode),
            "summary_score" - the score as a float rounded to 4 decimals.
    """
    logger = logging.getLogger(__name__)

    # An empty string would otherwise tokenize to a single all-padding window
    # and be "summarized" anyway.
    if isinstance(input_text, str) and not input_text.strip():
        logger.warning("input_text is empty; nothing to summarize")
        return []

    if batch_length < min_batch_length:
        logger.warning(
            f"batch_length must be at least {min_batch_length}. Setting batch_length to {min_batch_length}"
        )
        batch_length = min_batch_length

    # log all input parameters
    logger.info(f"input parameters:\n{pp.pformat(kwargs)}")
    logger.info(f"batch_length: {batch_length}, batch_stride: {batch_stride}")

    # return_overflowing_tokens + stride splits the text into overlapping windows
    encoded_input = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=batch_length,
        stride=batch_stride,
        return_overflowing_tokens=True,
        add_special_tokens=False,
        return_tensors="pt",
    )
    in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask

    gen_summaries = []
    # context manager guarantees the progress bar is closed even if generation raises
    with tqdm(total=len(in_id_arr)) as pbar:
        for _id, _mask in zip(in_id_arr, att_arr):
            result, score = summarize_and_score(
                ids=_id,
                mask=_mask,
                model=model,
                tokenizer=tokenizer,
                **kwargs,
            )
            score = round(float(score), 4)
            gen_summaries.append(
                {
                    "input_tokens": _id,
                    "summary": result,
                    "summary_score": score,
                }
            )
            logger.debug("Score for batch: %s. num chars: %s", score, len(repr(result)))
            logger.debug("Summary:\n\t%s", result)
            pbar.update()
    return gen_summaries
|