# File size: 5,722 Bytes
# 569f484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import gc
import copy
import time

import torch
import warnings
import transformers

import numpy as np

from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib

# Label value marking a position that the loss should skip. omni_preprocess in
# this file masks non-response tokens with this same value (-100). -100 is also
# the default `ignore_index` of torch.nn.CrossEntropyLoss, which is presumably
# the consumer of these labels; confirm.
IGNORE_INDEX = -100
# Image-related special-token strings. None of them is referenced in the code
# visible here; presumably they are used by model/tokenizer setup elsewhere.
# Verify before changing.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize every string on its own and gather ids plus non-pad lengths.

    Each string is encoded in a separate tokenizer call with
    ``return_tensors="pt"``. Each encoding is truncated to
    ``tokenizer.model_max_length``.

    Returns:
        A dict with four keys. ``input_ids`` and ``labels`` refer to the very
        same list, holding one id tensor per input string. ``input_ids_lens``
        and ``labels_lens`` likewise refer to one shared list of per-string
        token counts. Callers that mutate labels must copy them first.
    """
    encodings = []
    for text in strings:
        encodings.append(tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ))

    # Keep row 0 of each encoding (every call encoded a single string).
    token_ids = [enc.input_ids[0] for enc in encodings]

    # NOTE(review): this counts positions whose id differs from pad_token_id.
    # If the pad id is shared with a token that occurs in the text (e.g. EOS
    # reused as pad), those occurrences are excluded from the count too.
    token_counts = []
    for enc in encodings:
        token_counts.append(
            enc.input_ids.ne(tokenizer.pad_token_id).sum().item())

    # labels deliberately alias input_ids, and the two length lists alias each other.
    return {
        "input_ids": token_ids,
        "labels": token_ids,
        "input_ids_lens": token_counts,
        "labels_lens": token_counts,
    }



def _normalize_conversation(source):
    """Convert one raw conversation into a list of ``{'role', 'content'}`` dicts.

    Each turn may use either ``from``/``value`` keys or ``role``/``content``
    keys; ``from``/``value`` take precedence when present. ``human`` is renamed
    to ``user`` and ``gpt`` to ``assistant``.

    Raises:
        AssertionError: if a role is neither ``user`` nor ``assistant`` after
            renaming, or if two consecutive turns have the same role.
    """
    messages = []
    prev_role = 'unexpect'
    for conv_turn in source:
        role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
        content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

        role = 'user' if role == 'human' else role
        role = 'assistant' if role == 'gpt' else role

        assert role in ['user', 'assistant']
        assert role != prev_role, f'role={role}, prev_role={prev_role}'
        prev_role = role

        messages.append({'role': role, 'content': content})
    return messages


def _find_subsequence_starts(ids, pattern):
    """Return, ascending, every index ``i`` where ``ids[i:i + len(pattern)]`` equals ``pattern``.

    ``ids`` is a 1-D tensor of token ids and ``pattern`` a sequence of ints.
    An empty pattern yields no matches.
    """
    pattern = list(pattern)
    if not pattern:
        return []
    width = len(pattern)
    # Only positions holding the pattern's first id can start a match.
    candidates = (ids == pattern[0]).nonzero(as_tuple=True)[0].tolist()
    return [i for i in candidates if ids[i:i + width].tolist() == pattern]


def _warn_missing_template(kind, template, tokenizer, input_ids, res_text, new_source):
    """Warn that ``template`` was not found, so the instance gets no loss."""
    warnings.warn(
        f"Could not find {kind} key `{template}` in the "
        f'following instance: @===>{tokenizer.decode(input_ids)}<===@ '
        f'Raw text is @===>{res_text}<===@ '
        f'Raw source is @===>{new_source}<===@ '
        f"This instance will be ignored in loss calculation. "
        f"Note, if this happens often, consider increasing the `max_seq_length`.",
        # Report the warning at the omni_preprocess call site.
        stacklevel=2,
    )


def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    """Build chat-formatted input ids and response-only labels for conversations.

    Args:
        sources: sequence of conversations. Each conversation is a sequence of
            turn dicts with ``from``/``value`` or ``role``/``content`` keys.
            ``human``/``gpt`` are aliases for ``user``/``assistant``, and
            consecutive turns must not share a role.
        tokenizer: tokenizer providing ``encode``, ``apply_chat_template``,
            ``decode`` and ``__call__``. NOTE(review): the literal templates
            '\\n<|user|>\\n' and '\\n<|assistant|>\\n' are assumed to match
            what its chat template renders; confirm for the tokenizer in use.
        generation: if True, ask the chat template for a generation prompt and
            keep the rendered text unstripped.

    Returns:
        dict with ``input_ids`` and ``labels``, each a list with one id tensor
        per conversation. In ``labels``, every token outside an assistant reply
        is set to ``IGNORE_INDEX``. A conversation whose response or user
        template cannot be found is masked entirely, and a warning is issued.

    Raises:
        AssertionError: on an unknown role or two consecutive same-role turns.
    """
    system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
    ignore_index = IGNORE_INDEX

    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for source in sources:
        new_source = _normalize_conversation(source)
        # Normalized turns are only user/assistant, so a system prompt is always
        # prepended. `not new_source` guards against an empty conversation.
        if not new_source or new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # TODO: this automatically add '\n' to the end
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]
        # _tokenize_fn returns the same list for input_ids and labels, so copy
        # before masking to keep res_input_ids intact.
        res_labels = res_input_ids.clone()

        # Index of the first token of each assistant reply (just past its template).
        response_starts = [
            idx + len(response_token_ids)
            for idx in _find_subsequence_starts(res_input_ids, response_token_ids)
        ]
        # Index where each user-turn template begins.
        human_starts = _find_subsequence_starts(res_input_ids, instruction_token_ids)

        if not response_starts:
            _warn_missing_template('response', response_template, tokenizer,
                                   res_input_ids, res_text, new_source)
        if not human_starts:
            _warn_missing_template('instruction', instruction_template, tokenizer,
                                   res_input_ids, res_text, new_source)

        if response_starts and human_starts:
            # Make the loss ignore everything except assistant replies. Each reply
            # runs from its start to the next user-turn template (or the end).
            # Pairing each reply with the *next* user template, instead of zipping
            # the two index lists positionally, keeps the spans correct when the
            # counts differ, e.g. a conversation that opens with an assistant
            # turn, or template text appearing inside a message.
            ignored = torch.ones_like(res_labels, dtype=torch.bool)
            for start in response_starts:
                end = next((h for h in human_starts if h >= start), len(res_labels))
                ignored[start:end] = False
            res_labels[ignored] = ignore_index
        else:
            res_labels[:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)