Kirill Gelvan
commited on
Commit
•
26ff0b0
1
Parent(s):
67f50e8
add inference code to readme
Browse files
README.md
CHANGED
@@ -3,5 +3,78 @@ language: ru
|
|
3 |
tags:
|
4 |
- conversational
|
5 |
---
|
|
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
tags:
|
4 |
- conversational
|
5 |
---
|
6 |
+
### Description
|
7 |
|
8 |
|
9 |
+
### Inference
|
10 |
+
|
11 |
+
```python
|
12 |
+
|
13 |
+
def get_length_param(text: str, tokenizer) -> str:
|
14 |
+
tokens_count = len(tokenizer.encode(text))
|
15 |
+
if tokens_count <= 15:
|
16 |
+
len_param = '1'
|
17 |
+
elif tokens_count <= 50:
|
18 |
+
len_param = '2'
|
19 |
+
elif tokens_count <= 256:
|
20 |
+
len_param = '3'
|
21 |
+
else:
|
22 |
+
len_param = '-'
|
23 |
+
return len_param
|
24 |
+
|
25 |
+
|
26 |
+
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
|
27 |
+
if text['from'] == machine_name_in_chat:
|
28 |
+
return '1' # machine
|
29 |
+
else:
|
30 |
+
return '0' # human
|
31 |
+
|
32 |
+
|
33 |
+
chat_history_ids = torch.zeros((1, 0), dtype=torch.int)
|
34 |
+
|
35 |
+
while True:
|
36 |
+
|
37 |
+
next_who = input("Who's phrase?\t") #input("H / G?") # Human or GPT
|
38 |
+
|
39 |
+
# In case Human
|
40 |
+
if next_who == "H" or next_who == "Human":
|
41 |
+
input_user = input("===> Human: ")
|
42 |
+
|
43 |
+
# encode the new user input, add parameters and return a tensor in Pytorch
|
44 |
+
new_user_input_ids = tokenizer.encode(f"|0|{get_length_param(input_user, tokenizer)}|" + input_user + tokenizer.eos_token, return_tensors="pt")
|
45 |
+
# append the new user input tokens to the chat history
|
46 |
+
chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
|
47 |
+
|
48 |
+
if next_who == "G" or next_who == "GPT":
|
49 |
+
|
50 |
+
next_len = input("Phrase len? 1/2/3/-\t") #input("Exp. len?(-/1/2/3): ")
|
51 |
+
# encode the new user input, add parameters and return a tensor in Pytorch
|
52 |
+
new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
|
53 |
+
# append the new user input tokens to the chat history
|
54 |
+
chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
|
55 |
+
|
56 |
+
# print(tokenizer.decode(chat_history_ids[-1])) # uncomment to see full gpt input
|
57 |
+
|
58 |
+
# save previous len
|
59 |
+
input_len = chat_history_ids.shape[-1]
|
60 |
+
# generated a response; PS you can read about the parameters at hf.co/blog/how-to-generate
|
61 |
+
chat_history_ids = model.generate(
|
62 |
+
chat_history_ids,
|
63 |
+
num_return_sequences=1, # use for more variants, but have to print [i]
|
64 |
+
max_length=512,
|
65 |
+
no_repeat_ngram_size=3,
|
66 |
+
do_sample=True,
|
67 |
+
top_k=50,
|
68 |
+
top_p=0.9,
|
69 |
+
temperature = 0.6, # 0 for greedy
|
70 |
+
mask_token_id=tokenizer.mask_token_id,
|
71 |
+
eos_token_id=tokenizer.eos_token_id,
|
72 |
+
unk_token_id=tokenizer.unk_token_id,
|
73 |
+
pad_token_id=tokenizer.pad_token_id,
|
74 |
+
device='cpu'
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
# pretty print last ouput tokens from bot
|
79 |
+
print(f"===> GPT-3: {tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)}")
|
80 |
+
```
|