# BuddyPanda / src/agent/tools/text_summary.py
# (Hugging Face Hub page residue, kept as a comment: author "rexthecoder",
#  commit "chore: update" 9f99fe2, raw / history / blame links, 2.03 kB)
import logging
import telegram
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from telegram.ext import (
Updater,
CommandHandler,
ConversationHandler,
CallbackContext,
CallbackQueryHandler,
MessageHandler,
filters
)
from telegram import Update
# Tool metadata consumed by the agent framework when selecting this tool.
NAME = "ConversationSummary"
DESCRIPTION = """
Useful for summarizing conversation.
Input: A conversation text
Output: A summarized version of the conversation
"""
# States for the Telegram ConversationHandler state machine below.
SELECT_COMMAND, GET_TEXT = range(2)
class ConversationSummary():
    """Tool used to summarize text from a conversational text.

    Exposes a Telegram ``/summary`` conversation: the bot prompts for text,
    runs it through a fine-tuned FLAN-T5 summarization model, and replies
    with the summary.
    """

    # Tokenizer and model are loaded once at class-definition time and shared
    # by all instances. NOTE(review): this downloads/loads the model at import
    # time — confirm that startup cost is acceptable.
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")

    async def summarize(self, input: str, words: int) -> str:
        """Summarize *input* with the seq2seq model.

        Args:
            input: The conversation text to summarize.
            words: Maximum length passed to ``generate`` (``max_length``,
                measured in tokens despite the name).

        Returns:
            The decoded summary string.
        """
        logging.info(f"{input} {words}")
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids
        outputs = self.model.generate(input_ids, max_length=words)
        decoded_output = self.tokenizer.decode(
            outputs[0], skip_special_tokens=True)
        return f"{decoded_output}"

    def conversation_summary_handler(self) -> ConversationHandler:
        """Build the ConversationHandler implementing the /summary flow."""
        handler = ConversationHandler(
            entry_points=[CommandHandler(
                "summary", self.conversation_summary)],
            states={
                # Accept any non-command text message as the text to summarize.
                GET_TEXT: [MessageHandler(filters.TEXT & ~filters.COMMAND, self.process_conversation_summary)],
            },
            fallbacks=[CommandHandler("cancel", self.cancel)],
        )
        return handler

    async def process_conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Summarize the user's message and reply with the result.

        Bug fix: the original awaited ``summarize`` but discarded its return
        value, so the user never received the summary.
        """
        message = update.message.text
        summary = await self.summarize(message, 100)
        await update.message.reply_text(summary)
        return ConversationHandler.END

    async def conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Entry point for /summary: prompt the user, then await their text."""
        await update.message.reply_text('Please enter your conversations...')
        return GET_TEXT

    async def cancel(self, update: Update, context: CallbackContext) -> int:
        """Fallback for /cancel: end the conversation.

        Bug fix: ``self.cancel`` was referenced in ``fallbacks`` but never
        defined, which raised AttributeError when building the handler.
        """
        await update.message.reply_text('Summary cancelled.')
        return ConversationHandler.END