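# Prompt utilities for LLaVA (LLaMA-2): build conversation-formatted text prompts
# and wrap them, together with (optional) image embeddings, into the tokenized
# inputs and context embeddings the model consumes.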
import torch
from llava.conversation import conv_llava_llama_2
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token


def prepare_text_prompt(user_prompt):

    # Prepend the image placeholder token, then wrap the query in the
    # LLaVA-LLaMA-2 conversation template.
    qs = DEFAULT_IMAGE_TOKEN + '\n' + user_prompt

    conv = conv_llava_llama_2.copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    return prompt

# supports a batched implementation
class Prompt:
    # pipeline: tokenize the text prompts, then turn them into embeddings.
    # padding and label preparation are deferred until the targets have been appended.

    def __init__(self, model, tokenizer, text_prompts=None, device='cuda:0', max_new_tokens=300, max_length=2000):

        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        self.text_prompts = text_prompts
        self.img_prompts = [[]]

        self.context_length = []
        self.input_ids = []
        self.do_tokenization(self.text_prompts)

        self.max_new_tokens = max_new_tokens
        self.max_length = max_length

        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.img_embs = [[]]
        self.update_context_embs()


    def do_tokenization(self, text_prompts):

        if text_prompts is None:
            self.input_ids = []
            self.context_length = []
            return
        if isinstance(text_prompts, list):
            # only the first prompt of the batch is tokenized here
            text_prompts = text_prompts[0]

        input_ids = tokenizer_image_token(
            text_prompts, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.device)

        self.input_ids = [input_ids]
        self.context_length = [input_ids.shape[1]]

    def update_context_embs(self):

        if len(self.text_embs) == len(self.img_embs):
            self.context_embs = self.generate_context_embedding(
                    self.text_embs, self.img_embs
                )
        else:
            self.context_embs = []

    def update_text_prompt(self, text_prompts):
        self.text_prompts = text_prompts
        self.text_embs = self.generate_text_embedding(self.text_prompts)
        self.update_context_embs()

    def generate_text_embedding(self, text_prompts):

        if text_prompts is None:
            return []

        text_embs = []
        for item in text_prompts:  # for each prompt within the batch
            prompt_segs = item.split('<image>')  # each '<image>' placeholder corresponds to one image
            seg_tokens = [
                self.tokenizer(
                    seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
                # only add bos to the first seg
                for i, seg in enumerate(prompt_segs)
            ]
            embs = [self.model.model.embed_tokens(seg_t) for seg_t in seg_tokens] # text to embeddings
            text_embs.append(embs)

        return text_embs

    def generate_context_embedding(self, batch_text_embs, batch_img_embs):
        #assert len(text_embs) == len(img_embs) + 1, "Unmatched numbers of image placeholders and images."

        assert len(batch_text_embs) == len(batch_img_embs), "Unmatched batch size of text and image prompts"

        batch_size = len(batch_text_embs)
        batch_context_embs = []

        for i in range(batch_size):

            # concatenate the text segments of prompt i (image embeddings are not interleaved here)
            mixed_embs = torch.cat(batch_text_embs[i], dim=1)
            current_max_len = mixed_embs.shape[1] + self.max_new_tokens
            if current_max_len - self.max_length > 0:
                print('Warning: The number of tokens in the current conversation exceeds the max length. '
                      'The model will not see the context outside this range.')
            begin_idx = max(0, current_max_len - self.max_length)
            mixed_embs = mixed_embs[:, begin_idx:]
            
            batch_context_embs.append(mixed_embs)

        return batch_context_embs
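

# --- minimal usage sketch ---
# Assumes a LLaVA-LLaMA-2 `model` and `tokenizer` have already been loaded
# elsewhere (e.g. via llava.model.builder); the loading step itself is not
# part of this module, and the helper name below is only illustrative.
def build_prompt_example(model, tokenizer, user_prompt, device='cuda:0'):
    # wrap the raw question in the conversation template, then build a
    # batched (size-1) Prompt object holding its tokens and embeddings
    text_prompt = prepare_text_prompt(user_prompt)
    prompt = Prompt(model=model, tokenizer=tokenizer,
                    text_prompts=[text_prompt], device=device)
    # prompt.input_ids / prompt.context_length hold the tokenized prompt;
    # prompt.context_embs holds the concatenated text embeddings.
    # Use prompt.update_text_prompt([...]) to swap in a new prompt later.
    return prompt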