import logging

import torch
from telegram import Update
from telegram.ext import CallbackContext
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

NAME = "Conversation"

DESCRIPTION = """
Useful for holding an open-ended conversation.
Input: a normal chat message
Output: a text reply
"""

GET_CON = 0  # single conversation state


class Conversation:
    # Load the distilled 400M Blenderbot checkpoint once, at class-definition time (module import).
    tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
    model = BlenderbotForConditionalGeneration.from_pretrained(
        "facebook/blenderbot-400M-distill", add_cross_attention=False
    )

    # async def talk(self, message: str):
    #     logging.info(f"{message}")
    #     chat_history_ids = torch.tensor([], dtype=torch.long)
    #     new_user_input_ids = self.tokenizer.encode(message + self.tokenizer.eos_token, return_tensors='pt')
    #     bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    #     chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
    #     return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
    def predict(self, message, history=None):
        # tokenize the new input sentence (with an end-of-sentence token appended)
        new_user_input_ids = self.tokenizer.encode(message + self.tokenizer.eos_token, return_tensors='pt')

        # append the new user input tokens to the chat history, if there is one
        if history:
            bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
        else:
            bot_input_ids = new_user_input_ids

        # generate a response
        history = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id).tolist()

        # convert the tokens back to text and strip the special tokens
        response = self.tokenizer.decode(history[0]).replace("<s>", "").split("</s>")
        # response = [(response[i], response[i+1]) for i in range(0, len(response), 2)]  # convert to tuples of list
        return response[0].strip()
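
    # Usage sketch (assumption, illustrative only): calling predict() directly.
    # Each call starts from a fresh context unless token-id history is passed explicitly;
    # process_conversation below does not persist history between messages.
    #   bot = Conversation()
    #   reply = bot.predict("Hello there!")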

    # def generate(self, instruction, knowledge, dialog):
    #     if knowledge != '':
    #         knowledge = '[KNOWLEDGE] ' + knowledge
    #     dialog = ' EOS '.join(dialog)
    #     query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
    #     input_ids = self.tokenizer(f"{query}", return_tensors="pt").input_ids
    #     outputs = self.model.generate(
    #         input_ids, max_length=128,
    #         min_length=8,
    #         top_p=0.9,
    #         do_sample=True,
    #     )
    #     output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    #     return output

    async def process_conversation(self, update: Update, context: CallbackContext) -> int:
        message = update.message.text
        # instruction = f'Instruction: given a dialog context, you need to respond empathetically.'
        # knowledge = ''
        # dialog = []
        # dialog.append(message)
        text = self.predict(message)
        await update.message.reply_text(text)
        return GET_CON
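

# --- Usage sketch (assumption, not part of the original module) ---
# A minimal example of wiring this handler into a python-telegram-bot v20+
# Application; "YOUR_BOT_TOKEN" is a placeholder, not a real token.
if __name__ == "__main__":
    from telegram.ext import Application, MessageHandler, filters

    conversation = Conversation()
    app = Application.builder().token("YOUR_BOT_TOKEN").build()
    # reply to every plain-text (non-command) message with the model's response
    app.add_handler(
        MessageHandler(filters.TEXT & ~filters.COMMAND, conversation.process_conversation)
    )
    app.run_polling()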