File size: 1,002 Bytes
74b17e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import PreTrainedTokenizer

from .formatter import EmptyFormatter, StringFormatter
from .base import Template
from .formatter import Formatter
from ...utils.constants import *
from . import register_template
@register_template('pretrain')
@dataclass
class PretrainTemplate(Template):
    """Template registered under the name ``'pretrain'``.

    The user slot is the literal ``"<image>"`` placeholder and the assistant
    slot is ``"{{content}}\n"``; the image-token, system and separator
    formatters carry empty slots.
    """

    # Defaults are produced by factories so each template instance owns its
    # formatter objects instead of sharing one instance created at class
    # definition time (the separator slot in particular is a mutable list).
    # NOTE(review): if the formatter classes are plain (eq, non-frozen)
    # dataclasses they are unhashable, and Python 3.11+ rejects such objects
    # as direct dataclass defaults -- confirm against .formatter.
    format_image_token: "Formatter" = field(
        default_factory=lambda: EmptyFormatter(slot="")
    )
    format_user: "Formatter" = field(
        default_factory=lambda: EmptyFormatter(slot="<image>")
    )
    format_assistant: "Formatter" = field(
        default_factory=lambda: StringFormatter(slot="{{content}}\n")
    )
    system: "Formatter" = field(
        default_factory=lambda: EmptyFormatter(slot="")
    )
    separator: "Formatter" = field(
        default_factory=lambda: EmptyFormatter(slot=['', ''])
    )

    def make_labels(self, input_ids, prompt, tokenizer):
        """Return a copy of ``input_ids`` with the leading image span ignored.

        Args:
            input_ids: Token id sequence. It must accept scalar slice
                assignment (a plain ``list`` would not) -- presumably a
                tensor or array; confirm against the caller.
            prompt: Unused here; kept so the signature stays unchanged for
                existing callers.
            tokenizer: Passed through to ``self.tokenizer_image_token``.

        Returns:
            A deep copy of ``input_ids`` whose first ``mask_len`` entries are
            set to ``IGNORE_INDEX``; the input itself is not modified.
        """
        labels = copy.deepcopy(input_ids)
        # Number of ids the "<image>" placeholder expands to.
        # tokenizer_image_token is not defined in this class -- presumably
        # inherited from Template.
        mask_len = len(self.tokenizer_image_token("<image>", tokenizer))
        # IGNORE_INDEX presumably comes from the constants wildcard import.
        labels[:mask_len] = IGNORE_INDEX
        return labels
|