import codecs
import json
import logging
import os

from pandas import DataFrame
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llm_studio.python_configs.base import DefaultConfigProblemBase

logger = logging.getLogger(__name__)


def get_texts(df: DataFrame, cfg: DefaultConfigProblemBase):
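    """Extract the prompt texts from ``df`` as configured in ``cfg``.

    Values of a single prompt column are cast to string; multiple prompt
    columns are joined row-wise with ``cfg.dataset.prompt_column_separator``.
    Returns a numpy array of strings.
    """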
    if isinstance(cfg.dataset.prompt_column, str):
        # single column dataset
        texts = df[cfg.dataset.prompt_column].astype(str)
        texts = texts.values
    else:
        # multi-column dataset: cast each prompt column to string and join
        # the columns row-wise with the configured separator
        columns = list(cfg.dataset.prompt_column)

        for column in columns:
            df[column] = df[column].astype(str)

        join_str = codecs.decode(cfg.dataset.prompt_column_separator, "unicode_escape")

        texts = df[columns].apply(lambda x: join_str.join(x), axis=1).values

    return texts


def get_tokenizer(cfg: DefaultConfigProblemBase) -> PreTrainedTokenizerBase:
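    """Load and configure the tokenizer for ``cfg.llm_backbone``.

    Applies user-provided tokenizer kwargs, guarantees that eos, pad, bos
    and cls tokens are set, and stores derived attributes (mask token id,
    eos token, stop word ids, vocabulary length) back on ``cfg.tokenizer``.
    """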

    kwargs = dict(
        revision=cfg.environment.huggingface_branch,
        trust_remote_code=cfg.environment.trust_remote_code,
        token=os.getenv("HF_TOKEN"),
    )

    # Load once to determine the tokenizer class; Llama tokenizers are
    # rebuilt from the slow implementation. This workaround can be removed
    # once https://github.com/huggingface/transformers/pull/30964 is merged.
    tokenizer_class = AutoTokenizer.from_pretrained(
        cfg.llm_backbone, **kwargs
    ).__class__
    if tokenizer_class.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
        kwargs["from_slow"] = True

    # user-provided tokenizer kwargs (a JSON string in the config) are
    # merged last so they override the defaults above
    kwargs.update(json.loads(cfg.tokenizer.tokenizer_kwargs.strip()))

    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
    except TypeError as e:
        error_message = str(e)
        if "token" in error_message:
            # TypeError: RWForCausalLM.__init__() got
            # an unexpected keyword argument 'token'
            kwargs.pop("token")
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        elif "not a string" in error_message:
            # https://github.com/h2oai/h2o-llmstudio/issues/623
            kwargs.pop("add_prefix_space", None)
            kwargs.pop("from_slow", None)
            tokenizer = AutoTokenizer.from_pretrained(cfg.llm_backbone, **kwargs)
        else:
            raise

    # respect a padding side override from the config, if present
    tokenizer.padding_side = getattr(
        cfg.tokenizer, "_padding_side", tokenizer.padding_side
    )

    # do not let the tokenizer insert bos/eos tokens automatically
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    # if the eos token is an empty string, fall back to the conventional </s>
    if tokenizer.eos_token == "":
        tokenizer.add_special_tokens({"eos_token": "</s>"})
        tokenizer.eos_token = "</s>"

    # ensure pad, bos and cls tokens are defined by falling back to
    # existing special tokens
    if tokenizer.pad_token is None:
        if tokenizer.unk_token is not None:
            tokenizer.pad_token = tokenizer.unk_token
        else:
            tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.bos_token is None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.cls_token is None:
        tokenizer.cls_token = tokenizer.eos_token

    # select the token id used for masking, preferring unk, then mask,
    # then pad tokens
    if tokenizer.unk_token_id is not None:
        cfg.tokenizer._tokenizer_mask_token_id = tokenizer.unk_token_id
    elif tokenizer.mask_token_id is not None:
        cfg.tokenizer._tokenizer_mask_token_id = tokenizer.mask_token_id
    elif tokenizer.pad_token_id is not None:
        cfg.tokenizer._tokenizer_mask_token_id = tokenizer.pad_token_id
    else:
        # fall back to the last token in the vocabulary; this is usually a
        # safe choice, as it typically corresponds to the eos token
        cfg.tokenizer._tokenizer_mask_token_id = len(tokenizer) - 1

    cfg.tokenizer._tokenizer_eos_token = tokenizer.eos_token

    if hasattr(cfg.prediction, "stop_tokens"):
        set_stop_token_ids(cfg, tokenizer)
    cfg.tokenizer._vocab_length = len(tokenizer)

    return tokenizer


def set_stop_token_ids(
    cfg: DefaultConfigProblemBase, tokenizer: PreTrainedTokenizerBase
) -> None:
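    """Collect stop words and store their token ids on ``cfg.tokenizer``.

    Stop words come from ``cfg.prediction.stop_tokens`` and the prompt
    template delimiters; their token id sequences are stored in
    ``cfg.tokenizer._stop_words_ids``.
    """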
    # user-defined stop tokens, comma-separated in the config
    cfg.tokenizer._stop_words = list(
        filter(None, cfg.prediction.stop_tokens.split(","))
    )
    # the prompt template delimiters also act as stop words
    for stop_word in [
        cfg.dataset.text_system_start,
        cfg.dataset.text_prompt_start,
        cfg.dataset.text_answer_separator,
    ]:
        stop_word = codecs.decode(stop_word, "unicode_escape").strip()
        if (
            stop_word != ""
            and cfg.tokenizer.add_prompt_answer_tokens
            and (stop_word not in tokenizer.get_vocab())
        ):
            tokenizer.add_tokens([stop_word])
        cfg.tokenizer._stop_words.append(stop_word)
    cfg.tokenizer._stop_words = [
        stop_word for stop_word in cfg.tokenizer._stop_words if stop_word != ""
    ]
    # pre-compute the token id sequences of all stop words
    cfg.tokenizer._stop_words_ids = []
    for stop_word in set(cfg.tokenizer._stop_words):
        cfg.tokenizer._stop_words_ids.append(
            tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
                "input_ids"
            ][0]
        )
    logger.info(f"Stop token ids: {cfg.tokenizer._stop_words_ids}")
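

# A minimal end-to-end sketch (not executed here; assumes a populated
# experiment config ``cfg`` and a DataFrame ``df`` with the configured
# prompt column(s); the ``max_length`` value is shown for illustration):
#
#     tokenizer = get_tokenizer(cfg)
#     encodings = tokenizer(
#         list(get_texts(df, cfg)),
#         truncation=True,
#         padding=True,
#         max_length=cfg.tokenizer.max_length,
#         return_tensors="pt",
#     )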