from ..models.model_manager import ModelManager
import torch



def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    # Use the tokenizer's model_max_length unless an explicit max_length is given.
    length = tokenizer.model_max_length if max_length is None else max_length

    # Temporarily lift model_max_length so the first pass can tokenize the full
    # prompt without triggering the sequence-length warning.
    tokenizer.model_max_length = 99999999

    # First pass: tokenize the full prompt to measure its length.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Round the tokenized length up to the nearest multiple of `length`.
    max_length = (input_ids.shape[1] + length - 1) // length * length

    # Restore tokenizer.model_max_length
    tokenizer.model_max_length = length
    
    # Second pass: tokenize again, padded (and truncated) to the rounded-up length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids

    # Split the flat sequence into rows of `length` tokens so each chunk fits the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))
    
    return input_ids
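
# Illustrative usage (a sketch, not part of this module): with a CLIP tokenizer
# from Hugging Face transformers (model_max_length == 77), a long prompt is
# padded up to a multiple of 77 tokens and split into encoder-sized rows.
# The model id below is an assumption chosen for the example only.
#
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   input_ids = tokenize_long_prompt(tokenizer, "a watercolor painting of " * 40)
#   # input_ids has shape (num_chunks, 77); each row can be fed to the text
#   # encoder separately and the chunk embeddings concatenated afterwards.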



class BasePrompter:
    def __init__(self, refiners=None):
        # Use a fresh list per instance; a mutable default argument would be
        # shared across all BasePrompter instances.
        self.refiners = [] if refiners is None else refiners


    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=()):
        # Instantiate each refiner from the shared ModelManager and register it.
        for refiner_class in refiner_classes:
            refiner = refiner_class.from_model_manager(model_manager)
            self.refiners.append(refiner)


    @torch.no_grad()
    def process_prompt(self, prompt, positive=True):
        # A list of prompts is refined element-wise; a single prompt is passed
        # through each refiner in order.
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
        else:
            for refiner in self.refiners:
                prompt = refiner(prompt, positive=positive)
        return prompt
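
# Sketch of how a refiner plugs in (UppercaseRefiner is hypothetical, for
# illustration only; real refiners are implemented elsewhere in the package
# and constructed through from_model_manager). Assumes `model_manager` is an
# existing ModelManager instance.
#
#   class UppercaseRefiner:
#       @classmethod
#       def from_model_manager(cls, model_manager):
#           return cls()  # a real refiner would pull models/weights here
#
#       def __call__(self, prompt, positive=True):
#           return prompt.upper() if positive else prompt
#
#   prompter = BasePrompter()
#   prompter.load_prompt_refiners(model_manager, [UppercaseRefiner])
#   prompter.process_prompt(["a cat", "a dog"])  # -> ["A CAT", "A DOG"]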