import logging

import torch
from telegram import Update
from telegram.ext import (
    Updater,
    CommandHandler,
    ConversationHandler,
    CallbackContext,
    CallbackQueryHandler,
    MessageHandler,
    filters,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

NAME = "Conversation"
DESCRIPTION = """
Useful for building up conversation.
Input: A normal chat text
Output: A text
"""

# Single conversation state. A bare ``range(1)`` binds a range object, not an
# int; unpacking it makes GET_CON the integer state key 0.
(GET_CON,) = range(1)


class Conversation:
    """Chat with users through the DialoGPT-medium causal language model.

    Token history is stored per chat key. Replies therefore take earlier turns
    of the same chat into account without mixing in other users' messages.
    """

    # Loaded once when the class is defined and shared by all instances.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    # Upper bound on the total sequence length (prompt + reply) passed to generate().
    max_length = 1000
    # Only the most recent prompt tokens are fed to the model. This keeps room
    # under max_length for the reply, so long chats don't stop producing output.
    max_history_tokens = 800

    def __init__(self):
        # chat key -> tensor of every token id exchanged so far in that chat.
        self.chat_histories = {}

    async def talk(self, input: str, chat_id=None) -> str:
        """Generate the bot's reply to ``input`` and extend that chat's history.

        Args:
            input: The user's message text. The name shadows the builtin
                ``input``; it is kept for caller compatibility, so the builtin
                must not be used inside this method.
            chat_id: Key selecting which conversation history to extend.
                Callers that omit it share one default history.

        Returns:
            The decoded reply text, with special tokens stripped.
        """
        logger.info("%s", input)
        new_user_input_ids = self.tokenizer.encode(
            input + self.tokenizer.eos_token, return_tensors="pt"
        )

        history = self.chat_histories.get(chat_id)
        if history is None:
            # First turn for this chat: there is no history to prepend yet.
            bot_input_ids = new_user_input_ids
        else:
            bot_input_ids = torch.cat([history, new_user_input_ids], dim=-1)
        # Keep only the most recent tokens (a no-op for shorter prompts).
        bot_input_ids = bot_input_ids[:, -self.max_history_tokens:]

        # NOTE(review): generate() is synchronous and blocks the event loop
        # while it runs; consider offloading it (e.g. loop.run_in_executor)
        # together with a per-chat lock around the history update.
        chat_history_ids = self.model.generate(
            bot_input_ids,
            # pad == eos here, so the mask cannot be inferred from the ids;
            # every prompt token is real input.
            attention_mask=torch.ones_like(bot_input_ids),
            max_length=self.max_length,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        self.chat_histories[chat_id] = chat_history_ids

        # generate() returns prompt + reply; decode only the newly generated part.
        reply_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True)

    async def process_conversation(self, update: Update, context: CallbackContext) -> None:
        """Answer an incoming text message with a model-generated reply.

        Returns None, which leaves the ConversationHandler in its current state.
        """
        message = update.message
        # Non-message updates (e.g. edits) and non-text messages carry no text.
        if message is None or not message.text:
            return None

        text = await self.talk(message.text, chat_id=message.chat_id)
        if not text:
            # Telegram rejects empty messages, so don't try to send one.
            logger.info("Model produced an empty reply; nothing sent.")
            return None
        await message.reply_text(text)
        return None