Spaces:
Runtime error
Runtime error
import logging | |
import telegram | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from telegram.ext import ( | |
Updater, | |
CommandHandler, | |
ConversationHandler, | |
CallbackContext, | |
CallbackQueryHandler, | |
MessageHandler, | |
filters | |
) | |
from telegram import Update | |
# Identifier and human-readable description of this tool. Their consumers are
# not visible in this part of the file.
NAME = "ConversationSummary"
DESCRIPTION = """
Useful for summarizing conversation.
Input: A conversation text
Output: A summarised version of the conversation
"""

# Conversation state identifiers (the ints 0 and 1). GET_TEXT keys the state in
# which the bot waits for the text to summarize; SELECT_COMMAND is not
# referenced in the visible code.
SELECT_COMMAND, GET_TEXT = range(2)
class ConversationSummary(): | |
"""Tool used to summarize text from a conversational text.""" | |
# Single source of truth for the checkpoint so the tokenizer and the model can
# never be loaded from two different ids by accident.
_CHECKPOINT = "mrm8488/flan-t5-small-finetuned-samsum"
# Presumably class-level attributes (``summarize`` reads them through ``self``),
# loaded once when this code executes rather than per request.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)
async def summarize(self, input: str, words: int) -> str:
    """Summarize a conversation transcript with the shared seq2seq model.

    Args:
        input: Raw conversation text. The name shadows the builtin but is
            kept so existing keyword callers keep working.
        words: Forwarded to ``generate`` as ``max_length``. Despite the name,
            this caps the number of generated tokens, not words.

    Returns:
        The decoded summary with special tokens stripped.
    """
    import asyncio
    import functools

    # Lazy %-style args: the (possibly long) transcript is only rendered when
    # INFO logging is actually enabled.
    logging.info("%s %s", input, words)

    # Tokenizing and decoding stay on the calling thread; only the expensive
    # generation step is moved off it.
    input_ids = self.tokenizer(input, return_tensors="pt").input_ids

    # ``generate`` is synchronous and CPU-bound. Calling it directly inside a
    # coroutine would stall the event loop (and every other pending update)
    # for the whole inference, so run it in the default executor instead.
    loop = asyncio.get_running_loop()
    outputs = await loop.run_in_executor(
        None,
        functools.partial(self.model.generate, input_ids, max_length=words),
    )

    return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
def conversation_summary_handler(self) -> ConversationHandler:
    """Assemble the conversation flow started by the ``summary`` command.

    The ``summary`` command is the entry point, text messages that are not
    commands are handled by ``process_conversation_summary`` while in the
    ``GET_TEXT`` state, and the ``cancel`` command is the fallback.
    """
    start_on_summary = CommandHandler("summary", self.conversation_summary)

    plain_text_only = filters.TEXT & ~filters.COMMAND
    awaiting_text = MessageHandler(
        plain_text_only, self.process_conversation_summary)

    # NOTE(review): ``cancel`` is not defined in the visible part of the
    # class - confirm it exists, otherwise this lookup raises AttributeError.
    abort_on_cancel = CommandHandler("cancel", self.cancel)

    return ConversationHandler(
        entry_points=[start_on_summary],
        states={GET_TEXT: [awaiting_text]},
        fallbacks=[abort_on_cancel],
    )
async def process_conversation_summary(self, update: Update, context: CallbackContext) -> int:
    """Summarize the text of the incoming message and reply with the result.

    Args:
        update: Incoming update; its message text is the conversation to
            summarize.
        context: Handler callback context (unused here).

    Returns:
        ``ConversationHandler.END``.
    """
    message = update.message.text
    summary = await self.summarize(message, 100)
    # The summary was previously computed and then dropped, so the user never
    # got an answer. Telegram rejects empty message bodies, so fall back to a
    # short notice if the model produced nothing.
    await update.message.reply_text(
        summary or "Sorry, I could not produce a summary.")
    return ConversationHandler.END
async def conversation_summary(self, update: Update, context: CallbackContext) -> int:
    """Ask the user to send the conversation text.

    Args:
        update: Incoming update whose message is replied to.
        context: Handler callback context (unused here).

    Returns:
        ``GET_TEXT``, the integer state in which the next text message is
        expected (the annotation previously claimed ``str``).
    """
    # Plain literal: the original ``f`` prefix had no placeholders to format.
    await update.message.reply_text('Please enter your conversations...')
    return GET_TEXT