Andrey Vorozhko commited on
Commit
7e35601
·
1 Parent(s): 5bbc595

First version

Browse files
Files changed (4) hide show
  1. LICENSE +21 -0
  2. app.py +127 -0
  3. requirements.txt +3 -0
  4. util_funcs.py +51 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Kirill Gelvan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from util_funcs import getLengthParam, calcAnswerLengthByProbability, cropContext
6
+
7
+ def chat_function(Message): # model, tokenizer
8
+
9
+ input_user = Message
10
+
11
+ history = gr.get_state() or []
12
+
13
+ chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)
14
+
15
+ # encode the new user input, add parameters and return a tensor in Pytorch
16
+ lengthId = getLengthParam(input_user, tokenizer)
17
+ new_user_input_ids = tokenizer.encode(f"|0|{lengthId}|" \
18
+ + input_user + tokenizer.eos_token, return_tensors="pt")
19
+ # append the new user input tokens to the chat history
20
+ chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
21
+
22
+ # Длину ожидаемой фразы мы рассчитаем на основании последнего инпута
23
+ # Например, я не люблю когда на мой длинный ответ отвечают короткой фразой
24
+ # Но пойдем через вероятности:
25
+ # при длинном инпуте 60% что будет длинный ответ (3), 30% что средний (2), 10% что короткий (1)
26
+ # при среднем инпуте 50% что ответ будет средний (2), и по 25% на оба остальных случая
27
+ # при коротком инпуте 50% что ответ будет короткий (1), 30% что средний (2) и 20% что длинный (3)
28
+ # см. функцию calcAnswerLengthByProbability()
29
+
30
+ next_len = calcAnswerLengthByProbability(lengthId)
31
+
32
+ # encode the new user input, add parameters and return a tensor in Pytorch
33
+ new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
34
+
35
+ # append the new user input tokens to the chat history
36
+ chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
37
+
38
+ chat_history_ids = cropContext(chat_history_ids, 10)
39
+
40
+ print(tokenizer.decode(chat_history_ids[-1]))# uncomment for debug
41
+
42
+ # save previous len
43
+ input_len = chat_history_ids.shape[-1]
44
+ # generated a response; PS you can read about the parameters at hf.co/blog/how-to-generate
45
+
46
+ temperature = 0.6
47
+
48
+ # Обрезаем контекст до нужной длины с конца
49
+
50
+ # Создадим копию изначальных данных на случай если придется перегенерировать ответ
51
+ chat_history_ids_initial = chat_history_ids
52
+
53
+ while True:
54
+ chat_history_ids = model.generate(
55
+ chat_history_ids,
56
+ num_return_sequences=1,
57
+ min_length = 2,
58
+ max_length=512,
59
+ no_repeat_ngram_size=3,
60
+ do_sample=True,
61
+ top_k=50,
62
+ top_p=0.9,
63
+ temperature = temperature,
64
+ mask_token_id=tokenizer.mask_token_id,
65
+ eos_token_id=tokenizer.eos_token_id,
66
+ unk_token_id=tokenizer.unk_token_id,
67
+ pad_token_id=tokenizer.pad_token_id,
68
+ device='cpu'
69
+ )
70
+
71
+ answer = tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)
72
+
73
+ if (len(answer) > 0 and answer[-1] != ',' and answer[-1] != ':'):
74
+ break
75
+ else:
76
+ if (temperature <= 0.1):
77
+ temperature -= 0.1
78
+
79
+ # Случай когда надо перегенерировать ответ наступил, берем изначальный тензор
80
+ chat_history_ids = chat_history_ids_initial
81
+
82
+ history.append((input_user, answer, chat_history_ids.tolist()))
83
+ gr.set_state(history)
84
+ html = "<div class='chatbot'>"
85
+ for user_msg, resp_msg, _ in history:
86
+ if user_msg != '-':
87
+ html += f"<div class='user_msg'>{user_msg}</div>"
88
+ if resp_msg != '-':
89
+ html += f"<div class='resp_msg'>{resp_msg}</div>"
90
+ html += "</div>"
91
+ return html
92
+
93
+ # Download checkpoint:
94
+
95
+ checkpoint = "avorozhko/ruDialoGpt3-medium-finetuned-context"
96
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
97
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
98
+ model = model.eval()
99
+
100
+ # Gradio
101
+ title = "Чат-бот для поднятия настроения"
102
+ description = """
103
+ Данный бот постарается поднять вам настроение, так как он знает 26700 анекдотов.
104
+ Но чувство юмора у него весьма специфичное.
105
+ Бот не знает матерных слов и откровенных пошлостей, но кто такой Вовочка и Поручик Ржевский знает )
106
+ """
107
+ article = "<p style='text-align: center'><a href='https://huggingface.co/avorozhko/ruDialoGpt3-medium-finetuned-context'>Бот на основе дообученной GPT-3</a></p>"
108
+
109
+ iface = gr.Interface(fn=chat_function,
110
+ inputs=gr.inputs.Textbox(lines=3, placeholder="Что вы хотите сказать боту..."),
111
+ outputs="html",
112
+ title=title, description=description, article=article,
113
+ theme='dark-grass',
114
+ css= """
115
+ .chatbox {display:flex;flex-direction:column}
116
+ .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
117
+ .user_msg {background-color:#1e4282;color:white;align-self:start}
118
+ .resp_msg {background-color:#552a2a;align-self:self-end}
119
+ .panels.unaligned {flex-direction: column !important;align-items: initial!important;}
120
+ .panels.unaligned :last-child {order: -1 !important;}
121
+ """,
122
+ allow_screenshot=False,
123
+ allow_flagging=False
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ iface.launch(debug=False)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ random
util_funcs.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def getLengthParam(text: str, tokenizer) -> str:
2
+ tokens_count = len(tokenizer.encode(text))
3
+ if tokens_count <= 15:
4
+ len_param = '1'
5
+ elif tokens_count <= 50:
6
+ len_param = '2'
7
+ elif tokens_count <= 256:
8
+ len_param = '3'
9
+ else:
10
+ len_param = '-'
11
+ return len_param
12
+
13
+ # Эта функция вычисляет длину ожидаемого ответа на основе инпута
14
+ def calcAnswerLengthByProbability(lengthId):
15
+
16
+ # Вспомогательная функция, для работы с вероятностями
17
+ # На вход подаем список веротностей для длинного ответа (3), среднего(2), короткого 1
18
+ def getLenght(probList):
19
+ rndNum = random.randrange(start=0, stop=100, step=1)
20
+ if 0 <= rndNum <= probList[0]:
21
+ return 3
22
+ elif probList[0] < rndNum <= probList[1]:
23
+ return 2
24
+ else:
25
+ return 1
26
+
27
+ return {
28
+ lengthId == '3' or lengthId == '-': getLenght([60, 90]), # до 60 - 3, от 60 до 90 2, остальное - 1
29
+ lengthId == '2': getLenght([25, 75]), # до 25 - 3, от 25 до 75 - 2, остальное - 2
30
+ lengthId == '1': getLenght([20, 50]), # до 20 - 3, от 20 до 50 - 2, остальное - 1
31
+ }[True]
32
+
33
+ # Функция для обрезки контекста
34
+ # tensor - входной тензор
35
+ # size - сколько ПОСЛЕДНИХ ответов нужно оставить
36
+ def cropContext(tensor, size):
37
+ # переводим в размерность, удобную для работы
38
+ tensor = tensor[-1]
39
+ # Список, содержащий начала предложений
40
+ beginList = []
41
+
42
+ for i, item in enumerate(tensor):
43
+ if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96):
44
+ beginList.append(i)
45
+
46
+ if (len(beginList) < size):
47
+ return torch.unsqueeze(tensor, 0)
48
+
49
+ neededIndex = beginList[-size]
50
+ # Возвращаем в нужном нам формате (добавляем одну размерность)
51
+ return torch.unsqueeze(tensor[neededIndex:], 0)