---
language: ru
tags:
- conversational
---
|
### Description

### Inference

```python
|
|
|
def get_length_param(text: str, tokenizer) -> str:
    """Map the token count of *text* to a length-bucket label.

    Returns '1' for <=15 tokens, '2' for <=50, '3' for <=256,
    and '-' for anything longer.
    """
    n_tokens = len(tokenizer.encode(text))
    for limit, label in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= limit:
            return label
    return '-'
|
|
|
|
|
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Return the speaker tag for a message: '1' for the bot, '0' for a human.

    *text* is a message dict whose 'from' field names the sender.
    """
    is_machine = text['from'] == machine_name_in_chat
    return '1' if is_machine else '0'
|
|
|
|
|
# Interactive chat loop: alternates human / GPT turns, accumulating the whole
# dialogue as a growing row of token ids in `chat_history_ids`.
# FIX: token ids must be int64 (torch.long) — tokenizer.encode(..., return_tensors="pt")
# yields LongTensors and model embeddings expect them; seeding the history with
# torch.int (int32) can break torch.cat / the embedding lookup.
chat_history_ids = torch.zeros((1, 0), dtype=torch.long)

while True:

    next_who = input("Who's phrase?\t")  # "H"/"Human" or "G"/"GPT"

    # In case Human
    if next_who == "H" or next_who == "Human":
        input_user = input("===> Human: ")
        # Encode the new user input with control prefix |0|<len_bucket>|
        # (|0| = human speaker) and return a PyTorch tensor.
        new_user_input_ids = tokenizer.encode(f"|0|{get_length_param(input_user, tokenizer)}|" \
                                              + input_user + tokenizer.eos_token, return_tensors="pt")
        # Append the new user input tokens to the chat history.
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # In case GPT (branches are mutually exclusive, so elif)
    elif next_who == "G" or next_who == "GPT":
        next_len = input("Phrase len? 1/2/3/-\t")  # desired reply length bucket
        # Prompt the model to speak: control prefix |1|<len_bucket>| (|1| = machine).
        new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
        # Append the prompt tokens to the chat history.
        chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # print(tokenizer.decode(chat_history_ids[-1]))  # uncomment to see full gpt input

    # Remember the prompt length so only newly generated tokens are printed below.
    input_len = chat_history_ids.shape[-1]
    # Generate a response; parameters explained at hf.co/blog/how-to-generate
    chat_history_ids = model.generate(
        chat_history_ids,
        num_return_sequences=1,   # raise for more variants, then print [i]
        max_length=512,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.6,          # 0 for greedy
        mask_token_id=tokenizer.mask_token_id,
        eos_token_id=tokenizer.eos_token_id,
        unk_token_id=tokenizer.unk_token_id,
        pad_token_id=tokenizer.pad_token_id,
        device='cpu',  # NOTE(review): not a standard generate() kwarg — newer transformers may reject it; verify
    )

    # Pretty-print only the tokens the bot just generated.
    print(f"===> GPT-3: {tokenizer.decode(chat_history_ids[:, input_len:][0], skip_special_tokens=True)}")
|
```