import logging

import telegram
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from telegram.ext import (
    Updater,
    CommandHandler,
    ConversationHandler,
    CallbackContext,
    CallbackQueryHandler,
    MessageHandler,
    filters,
)
from telegram import Update

logger = logging.getLogger(__name__)

NAME = "ConversationSummary"
DESCRIPTION = """
Useful for summarizing conversation.
Input: A conversation text
Output: A summarized version of the conversation
"""

# Conversation states used by the ConversationHandler below.
SELECT_COMMAND, GET_TEXT = range(2)


class ConversationSummary():
    """Tool used to summarize text from a conversational text.

    Provides a ``/summary`` Telegram conversation: the bot asks the user for
    the conversation text, summarizes it with the
    ``mrm8488/flan-t5-small-finetuned-samsum`` model, and replies with the
    result. ``/cancel`` aborts the conversation.
    """

    # Loaded when the class body executes (i.e. when this module is imported)
    # and shared by every instance of the class.
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")

    # Value passed as ``words`` to summarize() from the Telegram flow. It ends
    # up as generate()'s ``max_length``, i.e. a token budget, not a word count.
    summary_max_length = 100

    async def summarize(self, input: str, words: int) -> str:
        """Return a model-generated summary of ``input``.

        Args:
            input: The conversation text to summarize.
            words: Forwarded to ``model.generate`` as ``max_length``; despite
                the name it bounds the number of generated tokens, not words.

        Returns:
            The decoded summary with special tokens stripped (may be empty).
        """
        # Log only the size: the raw text is private user conversation content.
        logger.debug(
            "Summarizing %d characters (max_length=%d)", len(input), words)
        # NOTE(review): tokenization and generate() are synchronous and run on
        # the event-loop thread, so the bot cannot serve other updates while a
        # summary is being produced. Consider offloading if that matters.
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids
        outputs = self.model.generate(input_ids, max_length=words)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def conversation_summary_handler(self) -> ConversationHandler:
        """Build the ConversationHandler wiring ``/summary`` and ``/cancel``.

        ``/summary`` enters the conversation; the next non-command text
        message is summarized; ``/cancel`` is the fallback that ends it.
        """
        handler = ConversationHandler(
            entry_points=[CommandHandler(
                "summary", self.conversation_summary)],
            states={
                GET_TEXT: [MessageHandler(filters.TEXT & ~filters.COMMAND,
                                          self.process_conversation_summary)],
            },
            fallbacks=[CommandHandler("cancel", self.cancel)],
        )
        return handler

    async def process_conversation_summary(
            self, update: Update, context: CallbackContext) -> int:
        """Summarize the received text, reply with it, and end the conversation."""
        # effective_message also covers edited messages, for which
        # update.message is None.
        message = update.effective_message
        if message is None or not message.text:
            return ConversationHandler.END
        summary = await self.summarize(message.text, self.summary_max_length)
        # Telegram rejects empty message text, so never send an empty reply.
        await message.reply_text(
            summary or "Sorry, I couldn't produce a summary for that text.")
        return ConversationHandler.END

    async def conversation_summary(
            self, update: Update, context: CallbackContext) -> int:
        """Entry point for ``/summary``: prompt for the text to summarize."""
        message = update.effective_message
        if message is not None:
            await message.reply_text('Please enter your conversations...')
        return GET_TEXT

    async def cancel(self, update: Update, context: CallbackContext) -> int:
        """Fallback for ``/cancel``: abort the summary conversation."""
        message = update.effective_message
        if message is not None:
            await message.reply_text("Summary cancelled.")
        return ConversationHandler.END