from dataclasses import dataclass
import copy

import torch
from transformers import PreTrainedTokenizer

from . import register_template
from .base import Template
from .formatter import EmptyFormatter, Formatter, StringFormatter
from ...utils.constants import *


@register_template('pretrain')
@dataclass
class PretrainTemplate(Template):
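    # Pretraining template: the user turn is just the "<image>" placeholder,
    # the assistant turn is the raw caption followed by a newline, and there
    # is no system prompt and no turn separator.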
    format_image_token: "Formatter" = EmptyFormatter(slot="")
    format_user: "Formatter" = EmptyFormatter(slot="<image>")
    format_assistant: "Formatter" = StringFormatter(slot="{{content}}\n")
    system: "Formatter" = EmptyFormatter(slot="")
    separator: "Formatter" = EmptyFormatter(slot=['', ''])
    
    def make_labels(
        self, input_ids: torch.Tensor, prompt: str, tokenizer: PreTrainedTokenizer
    ) -> torch.Tensor:
        # Copy the prompt ids and mask the leading "<image>" placeholder tokens
        # so that only the caption is supervised by the language-modelling loss.
        labels = copy.deepcopy(input_ids)
        mask_len = len(self.tokenizer_image_token("<image>", tokenizer))
        labels[:mask_len] = IGNORE_INDEX
        return labels
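

# Illustrative usage sketch (not part of the upstream file; `input_ids`,
# `prompt`, and `tokenizer` are assumed to come from the surrounding data
# pipeline). During pretraining the rendered prompt is "<image>" + caption
# + "\n", so only the caption tokens remain supervised:
#
#   template = PretrainTemplate()
#   labels = template.make_labels(input_ids, prompt, tokenizer)
#   # The first len(template.tokenizer_image_token("<image>", tokenizer))
#   # entries of `labels` are set to IGNORE_INDEX (typically -100), excluding
#   # the image placeholder prefix from the language-modelling loss.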