File size: 1,355 Bytes
7c3e9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import logging
from telegram import Update
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from telegram.ext import (
    Updater,
    CommandHandler,
    ConversationHandler,
    CallbackContext,
    CallbackQueryHandler,
    MessageHandler,
    filters
)

# Identifier and human-readable description of this module. Nothing in this
# file reads them; presumably a plugin/tool loader elsewhere does — confirm.
NAME = "Conversation"

DESCRIPTION = """
Useful for building up conversation. 
Input: A normal chat text
Output: A text
"""

# ConversationHandler state id for "waiting for the next chat message".
# The original `GET_CON = range(1)` bound a range object instead of the
# integer state (the usual `A, B = range(2)` unpacking idiom lost its comma).
GET_CON = 0

class Conversation:
    """Chat with Microsoft's DialoGPT-medium model, remembering earlier turns.

    The tokenizer and model are class attributes, so they are loaded once,
    when this class body executes (i.e. at module import), and are shared by
    every instance. Each instance keeps its own conversation histories, keyed
    by an optional chat id, so a reply takes earlier turns of the same chat
    into account.
    """

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    # Maximum total length (prompt + reply, in tokens) passed to ``generate``.
    max_length = 1000
    # Before each turn the stored history is cut to this many trailing tokens.
    # Without a cap the history grows every turn until the prompt alone hits
    # ``max_length``, leaving ``generate`` no room to produce a reply.
    max_history_tokens = 500

    def __init__(self):
        # chat id -> token-id tensor holding the whole conversation so far
        # (prompt and reply of every previous turn, as returned by generate).
        self._histories = {}

    async def talk(self, input: str, chat_id: "int | None" = None) -> str:
        """Return the model's reply to *input* and remember the exchange.

        Args:
            input: The user's message text. The name shadows the builtin
                ``input``; it is kept so existing keyword callers still work.
            chat_id: Selects which stored conversation to continue. The
                default ``None`` uses a single shared history.

        Returns:
            The decoded reply with special tokens removed. May be empty.
        """
        logging.info("%s", input)
        # DialoGPT marks the end of each turn with the EOS token.
        new_user_input_ids = self.tokenizer.encode(
            input + self.tokenizer.eos_token, return_tensors="pt"
        )

        history = self._histories.get(chat_id)
        if history is None:
            # First turn for this chat: the prompt is just the new message.
            bot_input_ids = new_user_input_ids
        else:
            history = history[:, -self.max_history_tokens:]
            bot_input_ids = torch.cat([history, new_user_input_ids], dim=-1)

        # NOTE(review): generate() is compute-heavy and runs directly on the
        # event loop, blocking other updates until it finishes; consider
        # offloading it to an executor.
        chat_history_ids = self.model.generate(
            bot_input_ids,
            max_length=self.max_length,
            pad_token_id=self.tokenizer.eos_token_id,
            # No padding is used, so every prompt position is real. Passing
            # the all-ones mask explicitly spares transformers from inferring
            # it via pad_token_id, which here equals the EOS token that also
            # separates turns inside the prompt.
            attention_mask=torch.ones_like(bot_input_ids),
        )
        self._histories[chat_id] = chat_history_ids

        # generate() returns prompt + continuation; decode only the new part.
        reply_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True)

    async def process_conversation(self, update: Update, context: CallbackContext) -> None:
        """Reply to an incoming Telegram text message with a DialoGPT answer.

        Always returns ``None``; per python-telegram-bot's ConversationHandler
        semantics that keeps the conversation in its current state.

        Args:
            update: The incoming Telegram update.
            context: Handler context (unused, required by the callback API).
        """
        message = update.message
        if message is None or message.text is None:
            # Not a new text message (e.g. an edited or non-text message):
            # there is nothing to feed the model.
            return None

        chat_id = update.effective_chat.id if update.effective_chat else None
        text = await self.talk(message.text, chat_id=chat_id)
        if not text.strip():
            # Telegram rejects messages with empty text, so skip the reply.
            logging.info("Model produced an empty reply; nothing sent.")
            return None

        await message.reply_text(text)
        return None