# beam-app / utils.py
# Greums's picture
# first app version
# 06126dc
from pydantic import BaseModel
from transformers import (PreTrainedTokenizerFast, StoppingCriteria)
def fallback(value, fallback_value):
    """Return ``value`` unless it is ``None``; otherwise return ``fallback_value``.

    Only ``None`` triggers the substitution. Other falsy values such as
    ``0``, ``""`` or ``False`` are returned unchanged.
    """
    return fallback_value if value is None else value
class Body(BaseModel):
    """Request payload: a prompt, a post count and optional generation settings.

    ``prompt`` and ``posts_count`` are required. Every other field defaults
    to ``None``.

    NOTE(review): the optional field names mirror Hugging Face
    ``generate()`` / ``GenerationConfig`` parameters; presumably ``None``
    means "use the caller's default" -- confirm against the code that
    consumes this model.
    """

    prompt: str
    posts_count: int
    max_length: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    # Typed as int rather than float: a float field would coerce an incoming
    # JSON integer to e.g. 50.0, and transformers rejects a non-int ``top_k``.
    top_k: int | None = None
    repetition_penalty: float | None = None
    # Same reasoning as ``top_k``: an n-gram size is an integer count.
    no_repeat_ngram_size: int | None = None
    do_sample: bool | None = None
class MaxPostsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires after a number of ``<|end_of_post|>`` markers.

    Each call inspects the tail of every sequence in ``input_ids``. Whenever
    a tail equals the encoded ``<|end_of_post|>`` marker, an instance-level
    counter is incremented. The call returns ``True`` as soon as that counter
    reaches ``posts_count``.

    The counter lives on the instance, is shared by all sequences passed to a
    call, and is never reset by this class. An instance therefore keeps
    accumulating across calls for as long as it is reused.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerFast, posts_count: int):
        """Encode the end-of-post marker and initialise the running count.

        Args:
            tokenizer: Tokenizer used to turn ``<|end_of_post|>`` into token
                ids (encoded without special tokens).
            posts_count: Number of marker hits after which ``__call__``
                reports that generation should stop.
        """
        # Despite the singular name, this holds whatever ``encode`` returns
        # for the marker text, which is compared as a list in ``__call__``.
        self.end_of_post_token_id = tokenizer.encode("<|end_of_post|>", add_special_tokens=False)
        self.posts_count = posts_count
        self.counter = 0

    def __call__(self, input_ids, scores, **kwargs):
        """Count sequences ending in the marker; return ``True`` at the limit.

        Args:
            input_ids: Iterable of sequences; each must support negative
                slicing and ``.tolist()``.
            scores: Unused.
            **kwargs: Unused.

        Returns:
            ``True`` once the running count reaches ``posts_count`` during
            this call, otherwise ``False``.
        """
        for row in input_ids:
            marker = self.end_of_post_token_id
            tail = row[-len(marker):].tolist()
            if tail != marker:
                # This sequence does not currently end with the marker.
                continue
            self.counter += 1
            if self.counter >= self.posts_count:
                return True
        return False