import re

import torch

from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    PreTrainedTokenizerBase,
)
from typing import List, Tuple, Optional

from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator.manual_seed(seed)
        self.seed = seed

    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
        # sample one token id from the distribution, using the per-request
        # generator so results are reproducible for a given seed
        next_tokens = torch.multinomial(
            probs, num_samples=1, generator=self.generator
        )
        return next_tokens
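
# Illustrative sketch (not part of the original module): with the same seed,
# Sampling is reproducible, which is what makes seeded generation requests
# deterministic.
#
#     sampler = Sampling(seed=42, device="cpu")
#     logits = torch.randn(32_000)
#     assert torch.equal(sampler(logits), Sampling(seed=42, device="cpu")(logits))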


class Greedy:
    def __call__(self, logits):
        # deterministic decoding: pick the highest-scoring token
        return logits.argmax()


class NextTokenChooser:
    def __init__(
        self,
        watermark=False,
        temperature=1.0,
        repetition_penalty=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
        do_sample=False,
        seed=0,
        device="cpu",
    ):
        warpers = LogitsProcessorList()
        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        sampling = do_sample
        if watermark:
            warpers.append(WatermarkLogitsProcessor(device=device))
        if repetition_penalty is not None and repetition_penalty != 1.0:
            warpers.append(
                RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            )
        if temperature is not None and temperature != 1.0:
            temperature = float(temperature)
            warpers.append(TemperatureLogitsWarper(temperature))
            sampling = True
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
            sampling = True
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
            sampling = True
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
            sampling = True

        self.warpers = warpers
        # any non-default warper implies sampling; otherwise fall back to greedy
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
        # Warp logits
        if scores.shape[0] > 1:
            # only warp the last token logits
            scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
        else:
            scores = self.warpers(input_ids, scores)

        # Compute logprobs
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
        next_id = self.choice(scores[-1])

        return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
    ) -> "NextTokenChooser":
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
        )
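
# Illustrative sketch (not part of the original module): warp then choose on a
# dummy score matrix. The vocabulary size and input_ids shape are assumptions
# for the example only; input_ids is unused by the temperature/top-k warpers.
#
#     chooser = NextTokenChooser(temperature=0.7, top_k=50, seed=0)
#     input_ids = torch.tensor([[1, 2, 3]])
#     scores = torch.randn(1, 32_000)
#     next_id, logprobs = chooser(input_ids, scores)
#     # next_id has shape [1, 1]; logprobs has the same shape as scores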


class StopSequenceCriteria:
    def __init__(self, stop_sequence: str):
        stop_sequence = re.escape(stop_sequence)
        # matches only when the output ends with the stop sequence
        self.regex = re.compile(f".*{stop_sequence}$")

    def __call__(self, output: str) -> bool:
        if self.regex.findall(output):
            return True
        return False
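
# Illustrative sketch (not part of the original module): the criteria fires
# only when the accumulated output ends with the stop sequence.
#
#     criteria = StopSequenceCriteria("###")
#     criteria("some text")       # False
#     criteria("some text ###")   # True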


class StoppingCriteria:
    def __init__(
        self,
        eos_token_id: int,
        stop_sequence_criterias: List[StopSequenceCriteria],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_id = eos_token_id
        self.stop_sequence_criterias = stop_sequence_criterias
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.current_output = ""
        self.ignore_eos_token = ignore_eos_token

    def __call__(
        self, last_token: int, last_output: str
    ) -> Tuple[bool, Optional[FinishReason]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH

        if not self.ignore_eos_token and last_token == self.eos_token_id:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN

        self.current_output += last_output
        for stop_sequence_criteria in self.stop_sequence_criterias:
            if stop_sequence_criteria(self.current_output):
                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

        return False, None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.StoppingCriteriaParameters,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "StoppingCriteria":
        stop_sequence_criterias = [
            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
        return StoppingCriteria(
            tokenizer.eos_token_id,
            stop_sequence_criterias,
            pb.max_new_tokens,
            pb.ignore_eos_token,
        )
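
# Illustrative sketch (not part of the original module): the criteria is
# stateful, so one instance tracks a single generation. Token id 2 as EOS is
# an assumption for the example only.
#
#     criteria = StoppingCriteria(
#         eos_token_id=2,
#         stop_sequence_criterias=[StopSequenceCriteria("###")],
#         max_new_tokens=4,
#     )
#     criteria(17, "Hello")   # (False, None)
#     criteria(2, "")         # (True, FinishReason.FINISH_REASON_EOS_TOKEN)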