import copy
import gc

from tenacity import RetryError, retry, stop_after_attempt, wait_fixed

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)


def get_output_batch(model, tokenizer, prompts, generation_config):
    """Generate completions for a batch of prompts and return the decoded texts."""
    if len(prompts) == 1:
        # Single prompt: no padding needed.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # Multiple prompts: pad to a common length before generating.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
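

# Example (sketch): batched decoding with get_output_batch. This assumes a CUDA
# device is available and uses "gpt2" with left padding purely for illustration;
# decoder-only models need a pad token and left padding for batched generation.
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
#   tokenizer.pad_token = tokenizer.eos_token
#   model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
#   prompts = ["Hello, my name is", "The capital of France is"]
#   texts = get_output_batch(model, tokenizer, prompts, model.generation_config)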


class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Accumulate generated tokens and emit the decoded text every
        # `chunk_size` tokens to reduce per-token decoding overhead.
        chunk_size = 2
        chunk_count = 0
        final_tokens = torch.empty(0, dtype=torch.long, device=self.device)

        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                chunk_count += 1
                final_tokens = torch.cat((final_tokens, tokens))

                if chunk_count == chunk_size:
                    chunk_count = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            # Flush any tokens left over from an incomplete chunk.
            if chunk_count > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        except RetryError as e:
            print(e)
            del input_ids
            gc.collect()

        del final_tokens
        if self.device == "cuda":
            torch.cuda.empty_cache()

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
    def _infer(self, model_fn, **kwargs):
        """Call a model function in inference mode with auto retrying."""
        with torch.inference_mode():
            return model_fn(**kwargs)

    def _logits_processor(self, config, input_length):
        """Set up logits processors based on the generation config."""
        processor = LogitsProcessorList()

        # Enforce a minimum number of new tokens by suppressing EOS early on.
        if (
            config.min_new_tokens is not None
            and config.min_new_tokens > 0
            and config.eos_token_id is not None
        ):
            processor.append(
                MinNewTokensLengthLogitsProcessor(
                    prompt_length_to_skip=input_length,
                    min_new_tokens=config.min_new_tokens,
                    eos_token_id=config.eos_token_id,
                )
            )

        # Rescale logits by temperature (skip the no-op value of 1.0).
        if (
            config.temperature is not None
            and config.temperature > 0
            and config.temperature != 1.0
        ):
            processor.append(TemperatureLogitsWarper(config.temperature))

        # Nucleus (top-p) filtering for 0 < top_p < 1.
        if config.top_p is not None and 0 < config.top_p < 1:
            processor.append(TopPLogitsWarper(config.top_p))

        return processor

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Merge the call-time options into a copy of the model's generation
        # config; update() returns any kwargs the config does not recognize.
        config = copy.deepcopy(self.model.generation_config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        # Resolve special token IDs, falling back to EOS as the pad token.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # An empty prompt is replaced with a single EOS token.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # For encoder-decoder models, run the encoder once and start the
        # decoder from its start token.
        if self.model.config.is_encoder_decoder:
            encoder = self.model.get_encoder()
            encoder_kwargs = kwargs.copy()
            encoder_kwargs.pop("use_cache", None)
            encoder_kwargs["input_ids"] = input_ids
            encoder_kwargs["return_dict"] = True
            encoder_outputs = self._infer(encoder, **encoder_kwargs)
            kwargs["encoder_outputs"] = encoder_outputs

            decoder_start_token_id = config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id
            input_ids = input_ids.new_ones((batch_size, 1))
            input_ids = input_ids * decoder_start_token_id
            input_length = 1

        processor = self._logits_processor(config, input_length)

        # Track which sequences in the batch are still generating (1) or done (0).
        unfinished = input_ids.new_ones(batch_size)

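        # Decoding loop: each iteration runs one forward pass, applies the
        # logits processors, picks the next token for every sequence, appends
        # it to input_ids, and yields the newly chosen tokens. Sequences that
        # have already finished keep receiving the pad token. Past key/values
        # are not carried between iterations, so the full sequence is
        # re-encoded on every step.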
        while True:
            inputs = self.model.prepare_inputs_for_generation(input_ids, **kwargs)
            outputs = self._infer(
                self.model,
                **inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Only the logits for the last position are needed.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = processor(input_ids, logits)
                probs = torch.nn.functional.softmax(logits, dim=-1)

            # Greedy decoding when sampling is effectively disabled, otherwise
            # sample from the (warped) distribution.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)
            tokens = tokens.squeeze(1)

            # Finished sequences continue with the pad token.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark a sequence as finished once it emits any EOS token.
            if eos_token_id is not None:
                not_eos = tokens.new_ones(tokens.shape)
                for i in eos_token_id:
                    not_eos = not_eos * (tokens != i)
                unfinished = unfinished.mul(not_eos.long())

            # Flip the status to negative once the length budget is exhausted.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            yield tokens

            # Stop when every sequence has finished or hit max_new_tokens.
            if status.max() <= 0:
                break
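

# Minimal usage sketch for StreamModel: stream a completion chunk by chunk.
# The model name "gpt2" and the sampling settings below are illustrative
# assumptions; any causal LM with a matching tokenizer should work the same way.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    if torch.cuda.is_available():
        model = model.cuda()

    stream_model = StreamModel(model, tokenizer)
    for partial_text in stream_model(
        "The quick brown fox",
        max_tokens=32,
        temperature=0.7,
        top_p=0.9,
    ):
        print(partial_text)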