from pydantic import BaseModel

from transformers import PreTrainedTokenizerFast, StoppingCriteria


def fallback(value, fallback_value):
    """Return ``fallback_value`` when ``value`` is None, otherwise return ``value`` unchanged."""
    if value is None:
        return fallback_value
    return value


class Body(BaseModel):
    """Request body for a generation call: the prompt, how many posts to produce,
    and optional overrides for the sampling parameters."""

    prompt: str
    posts_count: int
    max_length: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repetition_penalty: float | None = None
    no_repeat_ngram_size: int | None = None
    do_sample: bool | None = None
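

# Hedged usage sketch, not part of the original module: how `fallback` might be used to
# turn an incoming `Body` into keyword arguments for `model.generate`. The default
# values below are illustrative assumptions, not defaults taken from this codebase.
def generation_kwargs(body: Body) -> dict:
    return {
        "max_length": fallback(body.max_length, 512),
        "temperature": fallback(body.temperature, 1.0),
        "top_p": fallback(body.top_p, 0.95),
        "top_k": fallback(body.top_k, 50),
        "repetition_penalty": fallback(body.repetition_penalty, 1.0),
        "no_repeat_ngram_size": fallback(body.no_repeat_ngram_size, 0),
        "do_sample": fallback(body.do_sample, True),
    }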


class MaxPostsStoppingCriteria(StoppingCriteria):
    """Stop generation once ``posts_count`` "<|end_of_post|>" markers have been produced."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, posts_count: int):
        # The marker may be encoded as more than one token, so keep the full id sequence.
        self.end_of_post_token_ids = tokenizer.encode("<|end_of_post|>", add_special_tokens=False)
        self.posts_count = posts_count
        self.counter = 0

    def __call__(self, input_ids, scores, **kwargs):
        # Called after every decoding step: if a sequence now ends with the
        # end-of-post marker, count one finished post; stop once enough posts exist.
        for sequence in input_ids:
            if sequence[-len(self.end_of_post_token_ids):].tolist() == self.end_of_post_token_ids:
                self.counter += 1
                if self.counter >= self.posts_count:
                    return True
        return False
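

# Hedged end-to-end sketch, not part of the original module: `model` and `tokenizer` are
# assumed to be an already loaded causal LM and its fast tokenizer. It shows how
# MaxPostsStoppingCriteria would typically be wired into `generate` through a
# StoppingCriteriaList so decoding halts after `posts_count` posts.
def generate_posts(model, tokenizer: PreTrainedTokenizerFast, body: Body) -> str:
    from transformers import StoppingCriteriaList

    inputs = tokenizer(body.prompt, return_tensors="pt")
    stopping = StoppingCriteriaList([MaxPostsStoppingCriteria(tokenizer, body.posts_count)])
    output_ids = model.generate(
        **inputs,
        stopping_criteria=stopping,
        **generation_kwargs(body),
    )
    # Keep special tokens so the "<|end_of_post|>" markers stay visible to the caller.
    return tokenizer.decode(output_ids[0], skip_special_tokens=False)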