File size: 4,318 Bytes
0523803
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def init(cfg):
    """Register the chat streaming callback, session start/finish handlers
    and the settings-component list on ``cfg``.

    Reads from ``cfg``: ``'chat_template'``, ``'model'``, ``'gr'`` (a UI
    module exposing ``update()`` -- presumably gradio), ``'session_lock'``,
    ``'btn_submit'``, ``'btn_vo'``, ``'btn_suggest'`` and the ``setting_*`` /
    ``role_*`` / ``rag`` entries collected at the end of this function.

    Writes to ``cfg``: ``'btn_com'``, ``'btn_start'``, ``'btn_finish'`` and
    ``'setting'``. The start/finish handlers additionally read and toggle
    ``cfg['session_active']`` each time they are called.
    """
    chat_template = cfg['chat_template']
    model = cfg['model']
    gr = cfg['gr']
    lock = cfg['session_lock']

    # ========== Streaming output function ==========
    def btn_com(_n_keep, _n_discard,
                _temperature, _repeat_penalty, _frequency_penalty,
                _presence_penalty, _repeat_last_n, _top_k,
                _top_p, _min_p, _typical_p,
                _tfs_z, _mirostat_mode, _mirostat_eta,
                _mirostat_tau, _role, _max_tokens):
        """Generate the reply for ``_role`` and yield the accumulated text.

        The sampling arguments are forwarded unchanged to
        ``model.generate_t``. Each time buffered tokens detokenize to
        non-empty text, the full text generated so far is yielded (not only
        the new piece).

        Generation stops when:
          * the token is in ``chat_template.eos`` or equals
            ``chat_template.nlnl``;
          * an ``onenl`` token directly follows another ``onenl`` token, or
            follows an ``onerl`` token that itself follows ``onenl``; in these
            cases ``model.venv_pop_token()`` is called first (presumably to
            drop the trailing newline token from the model state -- confirm);
          * the accumulated text ends with two newline characters;
          * ``t_bot`` grows past ``_max_tokens``.

        Afterwards ``chat_template.im_end_nl`` is evaluated into the model
        via ``model.eval_t`` to close the turn.

        NOTE(review): if the consumer stops iterating early (the generator is
        closed at a ``yield``), the code after the loop -- including
        ``eval_t`` -- does not run.
        """
        # ========== Initialise the output template ==========
        # Token list for the bot turn: starts with the prefix returned by
        # chat_template(_role) and is extended with every generated token.
        t_bot = chat_template(_role)
        # Several tokens may be needed to form one UTF-8 encoded character,
        # so tokens are buffered here until they detokenize to non-empty text.
        completion_tokens = []
        history = ''
        # ========== Streaming output ==========
        for token in model.generate_t(
                tokens=t_bot,
                n_keep=_n_keep,
                n_discard=_n_discard,
                im_start=chat_template.im_start_token,
                top_k=_top_k,
                top_p=_top_p,
                min_p=_min_p,
                typical_p=_typical_p,
                temp=_temperature,
                repeat_penalty=_repeat_penalty,
                repeat_last_n=_repeat_last_n,
                frequency_penalty=_frequency_penalty,
                presence_penalty=_presence_penalty,
                tfs_z=_tfs_z,
                mirostat_mode=_mirostat_mode,
                mirostat_tau=_mirostat_tau,
                mirostat_eta=_mirostat_eta,
        ):
            if token in chat_template.eos or token == chat_template.nlnl:
                t_bot.extend(completion_tokens)
                print('token in eos', token)
                break
            completion_tokens.append(token)
            all_text = model.str_detokenize(completion_tokens)
            if not all_text:
                # Incomplete multi-byte character: keep buffering tokens.
                continue
            t_bot.extend(completion_tokens)
            history += all_text
            yield history
            if token in chat_template.onenl:
                # ========== Remove the trailing newline ==========
                if t_bot[-2] in chat_template.onenl:
                    model.venv_pop_token()
                    break
                if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
                    model.venv_pop_token()
                    break
            if history[-2:] == '\n\n':  # any token of the form 'x\n\n', e.g. '。\n\n'
                print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
                      repr(model.str_detokenize(t_bot[-1:])))
                break
            # NOTE(review): t_bot also holds the chat_template(_role) prefix,
            # so this limit counts those tokens too -- confirm that is intended.
            if len(t_bot) > _max_tokens:
                break
            completion_tokens = []
        # ========== Inspect the trailing newline ==========
        print('history', repr(history))
        # ========== Append the end-of-output marker to the kv_cache ==========
        model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
        t_bot.extend(chat_template.im_end_nl)

    cfg['btn_com'] = btn_com

    def btn_start_or_finish(finish):
        """Build a handler that starts (``finish=False``) or finishes
        (``finish=True``) a session.

        The handler flips ``cfg['session_active']`` while holding
        ``session_lock`` and returns three identical
        ``gr.update(interactive=finish)`` values for the submit / vo /
        suggest buttons: non-interactive on start, interactive on finish.

        Raises:
            RuntimeError: if ``cfg['session_active']`` does not equal
                ``finish``, i.e. starting a session that is already active
                or finishing one that is not.
        """
        tmp = gr.update(interactive=finish)

        def _inner():
            with lock:
                if cfg['session_active'] != finish:
                    # Name the rejected transition and the current state so
                    # the failure is diagnosable from the traceback alone.
                    action = 'finish' if finish else 'start'
                    raise RuntimeError(
                        f"cannot {action} session: "
                        f"session_active is {cfg['session_active']!r}"
                    )
                cfg['session_active'] = not cfg['session_active']
            return tmp, tmp, tmp

        return _inner

    btn_start_or_finish_outputs = [cfg['btn_submit'], cfg['btn_vo'], cfg['btn_suggest']]

    cfg['btn_start'] = {
        'fn': btn_start_or_finish(False),
        'outputs': btn_start_or_finish_outputs
    }

    cfg['btn_finish'] = {
        'fn': btn_start_or_finish(True),
        'outputs': btn_start_or_finish_outputs
    }

    # Settings components, in order. NOTE(review): this list has 19 entries
    # (both role components plus 'rag'), while btn_com takes 17 positional
    # parameters, so it is presumably consumed by handlers defined elsewhere
    # -- confirm before wiring it straight into btn_com.
    cfg['setting'] = [cfg[x] for x in ('setting_n_keep', 'setting_n_discard',
                                       'setting_temperature', 'setting_repeat_penalty', 'setting_frequency_penalty',
                                       'setting_presence_penalty', 'setting_repeat_last_n', 'setting_top_k',
                                       'setting_top_p', 'setting_min_p', 'setting_typical_p',
                                       'setting_tfs_z', 'setting_mirostat_mode', 'setting_mirostat_eta',
                                       'setting_mirostat_tau', 'role_usr', 'role_char',
                                       'rag', 'setting_max_tokens')]