# NOTE: non-code page header from the original Hugging Face Space listing removed
# (space chrome, file size, commit hash, and line-number gutter were not Python).
import logging
import telegram
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from telegram.ext import (
Updater,
CommandHandler,
ConversationHandler,
CallbackContext,
CallbackQueryHandler,
MessageHandler,
filters
)
from telegram import Update
# Tool metadata consumed by the surrounding agent/tool registry.
NAME = "ConversationSummary"
DESCRIPTION = """
Useful for summarizing conversation.
Input: A conversation text
Output: A summarized version of the conversation
"""

# ConversationHandler state ids. SELECT_COMMAND is declared for parity but
# only GET_TEXT is referenced below.
SELECT_COMMAND, GET_TEXT = range(2)
class ConversationSummary():
    """Telegram tool that summarizes conversational text.

    Wires a ``/summary`` command into a ``ConversationHandler``: the bot
    prompts the user for a conversation transcript, then replies with a
    summary produced by a FLAN-T5 model fine-tuned on SAMSum.
    """

    # Loaded once at class-definition time and shared by all instances.
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")

    async def summarize(self, input: str, words: int) -> str:
        """Return a model-generated summary of *input*.

        ``input`` shadows the builtin but is kept for interface
        compatibility with existing callers. ``words`` is passed as
        ``max_length`` to ``generate`` — an upper bound on generated
        *tokens*, not literal words.
        """
        # Lazy %-style args avoid formatting when INFO is disabled.
        logging.info("%s %s", input, words)
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids
        outputs = self.model.generate(input_ids, max_length=words)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def conversation_summary_handler(self) -> ConversationHandler:
        """Build the /summary -> transcript -> summary conversation flow."""
        # NOTE(review): self.cancel is not defined in this class — presumably
        # provided by a subclass or mixin; confirm before registering.
        return ConversationHandler(
            entry_points=[CommandHandler(
                "summary", self.conversation_summary)],
            states={
                GET_TEXT: [MessageHandler(
                    filters.TEXT & ~filters.COMMAND,
                    self.process_conversation_summary)],
            },
            fallbacks=[CommandHandler("cancel", self.cancel)],
        )

    async def process_conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Summarize the user's message and reply with the result.

        Bug fix: the original awaited ``summarize`` but discarded the
        result, so the user never received the summary.
        """
        message = update.message.text
        summary = await self.summarize(message, 100)
        await update.message.reply_text(summary)
        return ConversationHandler.END

    async def conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Entry point for /summary: prompt the user for a transcript.

        Returns the GET_TEXT state (an int — the original ``-> str``
        annotation was wrong).
        """
        await update.message.reply_text('Please enter your conversations...')
        return GET_TEXT
# (end of scraped snippet)