from pydantic import BaseModel
from transformers import PreTrainedTokenizerFast, StoppingCriteria


def fallback(value, fallback_value):
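    """Return `value`, or `fallback_value` if `value` is None."""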
    if value is None:
        return fallback_value
    return value


class Body(BaseModel):
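    """Generation request parameters (likely an API request body); optional fields fall back to defaults when unset."""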
    prompt: str
    posts_count: int
    max_length: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    repetition_penalty: float | None = None
    no_repeat_ngram_size: int | None = None
    do_sample: bool | None = None


class MaxPostsStoppingCriteria(StoppingCriteria):
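    """Stop generation once `posts_count` occurrences of the <|end_of_post|> marker have been produced."""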
    def __init__(self, tokenizer: PreTrainedTokenizerFast, posts_count: int):
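        # encode() returns a list of token IDs, since the marker may map to more than one token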
        self.end_of_post_token_id = tokenizer.encode("<|end_of_post|>", add_special_tokens=False)
        self.posts_count = posts_count
        self.counter = 0

    def __call__(self, input_ids, scores, **kwargs):
        # Check whether any sequence now ends with the full <|end_of_post|> token ID sequence
        for sequence in input_ids:
            if sequence[-len(self.end_of_post_token_id):].tolist() == self.end_of_post_token_id:
                self.counter += 1
                if self.counter >= self.posts_count:
                    return True
        return False
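

# --- Illustrative sketch (assumption, not part of the original file) ---------
# One way the pieces above could be wired together: `Body` carries the request
# parameters, `fallback` supplies defaults for omitted fields, and the stopping
# criterion ends generation after the requested number of posts. The function
# name and the default values below are placeholders for demonstration only.
def _example_generate(model, tokenizer: PreTrainedTokenizerFast, body: Body) -> str:
    from transformers import StoppingCriteriaList

    inputs = tokenizer(body.prompt, return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_length=fallback(body.max_length, 512),
        temperature=fallback(body.temperature, 1.0),
        top_p=fallback(body.top_p, 1.0),
        top_k=fallback(body.top_k, 50),
        repetition_penalty=fallback(body.repetition_penalty, 1.0),
        no_repeat_ngram_size=fallback(body.no_repeat_ngram_size, 0),
        do_sample=fallback(body.do_sample, True),
        stopping_criteria=StoppingCriteriaList(
            [MaxPostsStoppingCriteria(tokenizer, body.posts_count)]
        ),
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=False)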