File size: 6,078 Bytes
7e35601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import random
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from util_funcs import getLengthParam, calcAnswerLengthByProbability, cropContext

def chat_function(Message):   # model, tokenizer

    input_user = Message

    history = gr.get_state() or []

    chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)

    # encode the new user input, add parameters and return a tensor in Pytorch
    lengthId = getLengthParam(input_user, tokenizer)
    new_user_input_ids = tokenizer.encode(f"|0|{lengthId}|" \
                                          + input_user + tokenizer.eos_token, return_tensors="pt")
    # append the new user input tokens to the chat history
    chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # Длину ожидаемой фразы мы рассчитаем на основании последнего инпута
    # Например, я не люблю когда на мой длинный ответ отвечают короткой фразой
    # Но пойдем через вероятности: 
    # при длинном инпуте 60% что будет длинный ответ (3), 30% что средний (2), 10% что короткий (1)
    # при среднем инпуте 50% что ответ будет средний (2), и по 25% на оба остальных случая
    # при коротком инпуте 50% что ответ будет короткий (1), 30% что средний (2) и 20% что длинный (3)
    # см. функцию calcAnswerLengthByProbability()

    next_len = calcAnswerLengthByProbability(lengthId)

    # encode the new user input, add parameters and return a tensor in Pytorch
    new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")

    # append the new user input tokens to the chat history
    chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    chat_history_ids = cropContext(chat_history_ids, 10)

    print(tokenizer.decode(chat_history_ids[-1]))# uncomment for debug
    
    # save previous len
    input_len = chat_history_ids.shape[-1]
    # generated a response; PS you can read about the parameters at hf.co/blog/how-to-generate

    temperature = 0.6

    # Обрезаем контекст до нужной длины с конца

    # Создадим копию изначальных данных на случай если придется перегенерировать ответ
    chat_history_ids_initial = chat_history_ids

    while True:
      chat_history_ids = model.generate(
        chat_history_ids,
        num_return_sequences=1,
        min_length = 2,
        max_length=512,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature = temperature,
        mask_token_id=tokenizer.mask_token_id,
        eos_token_id=tokenizer.eos_token_id,
        unk_token_id=tokenizer.unk_token_id,
        pad_token_id=tokenizer.pad_token_id,
        device='cpu'
      )

      answer = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)

      if (len(answer) > 0 and answer[-1] != ',' and answer[-1] != ':'):
        break
      else:
        if (temperature <= 0.1):
          temperature -= 0.1

        # Случай когда надо перегенерировать ответ наступил, берем изначальный тензор
        chat_history_ids = chat_history_ids_initial

    history.append((input_user, answer, chat_history_ids.tolist()))        
    gr.set_state(history)
    html = "<div class='chatbot'>"
    for user_msg, resp_msg, _ in history:
        if user_msg != '-':
            html += f"<div class='user_msg'>{user_msg}</div>"
        if resp_msg != '-':
            html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html

# Download checkpoint:

checkpoint = "avorozhko/ruDialoGpt3-medium-finetuned-context"
tokenizer =  AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.eval()

# Gradio
title = "Чат-бот для поднятия настроения"
description = """
Данный бот постарается поднять вам настроение, так как он знает 26700 анекдотов.
Но чувство юмора у него весьма специфичное.
Бот не знает матерных слов и откровенных пошлостей, но кто такой Вовочка и Поручик Ржевский знает )
              """
article = "<p style='text-align: center'><a href='https://huggingface.co/avorozhko/ruDialoGpt3-medium-finetuned-context'>Бот на основе дообученной GPT-3</a></p>"

iface = gr.Interface(fn=chat_function,
                     inputs=gr.inputs.Textbox(lines=3, placeholder="Что вы хотите сказать боту..."),
                     outputs="html",
                     title=title, description=description, article=article,
                     theme='dark-grass',
                     css= """
                            .chatbox {display:flex;flex-direction:column}
                            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
                            .user_msg {background-color:#1e4282;color:white;align-self:start}
                            .resp_msg {background-color:#552a2a;align-self:self-end}
                            .panels.unaligned {flex-direction: column !important;align-items: initial!important;}
                            .panels.unaligned :last-child {order: -1 !important;}
                          """,
                     allow_screenshot=False,
                     allow_flagging=False
                    )

if __name__ == "__main__":
  iface.launch(debug=False)