# File size: 2,027 Bytes
# 9f99fe2

import asyncio
import logging

import telegram
from telegram import Update
from telegram.ext import (
    Updater,
    CommandHandler,
    ConversationHandler,
    CallbackContext,
    CallbackQueryHandler,
    MessageHandler,
    filters,
)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tool identifier; presumably used by whatever registry loads this module.
NAME = "ConversationSummary"

# Human-readable description of the tool's input/output contract.
DESCRIPTION = """
Useful for summarizing a conversation.
Input: A conversation text
Output: A summarized version of the conversation
"""

# ConversationHandler state ids. GET_TEXT is the state that waits for the
# user's conversation text; SELECT_COMMAND is not used by the handler in
# this module.
SELECT_COMMAND, GET_TEXT = range(2)


class ConversationSummary():
    """Telegram tool that summarizes a conversation text sent by the user.

    ``/summary`` prompts the user for a conversation, the next text message
    is run through the ``mrm8488/flan-t5-small-finetuned-samsum`` seq2seq
    model, and the generated summary is sent back. ``/cancel`` aborts.
    """

    # Loaded once, at class-definition (i.e. import) time, and shared by all
    # instances of this class.
    tokenizer = AutoTokenizer.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")

    model = AutoModelForSeq2SeqLM.from_pretrained(
        "mrm8488/flan-t5-small-finetuned-samsum")

    def _generate_summary(self, input: str, words: int) -> str:
        """Synchronously tokenize, generate and decode a summary of *input*.

        This blocks for the duration of model inference; async callers
        should go through :meth:`summarize` instead.
        """
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids
        outputs = self.model.generate(input_ids, max_length=words)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    async def summarize(self, input: str, words: int) -> str:
        """Return a model-generated summary of *input*.

        Args:
            input: The conversation text to summarize.
            words: Forwarded to ``model.generate`` as ``max_length``. Per the
                transformers generation docs this bounds generated tokens,
                not words.

        Returns:
            The decoded summary (special tokens stripped); may be empty.
        """
        # Log sizes only: the raw conversation may contain private content.
        logging.info("Summarizing %d characters (max_length=%d)",
                     len(input), words)
        # model.generate is blocking; run it in the loop's default executor
        # so the bot's event loop can keep serving other updates meanwhile.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._generate_summary, input, words)

    def conversation_summary_handler(self) -> ConversationHandler:
        """Build the ConversationHandler for the summary flow.

        ``/summary`` enters the flow; in state ``GET_TEXT`` any non-command
        text message is summarized; ``/cancel`` is the fallback exit.
        """
        return ConversationHandler(
            entry_points=[CommandHandler(
                "summary", self.conversation_summary)],
            states={
                GET_TEXT: [MessageHandler(
                    filters.TEXT & ~filters.COMMAND,
                    self.process_conversation_summary)],
            },
            fallbacks=[CommandHandler("cancel", self.cancel)],
        )

    async def process_conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Summarize the received text, reply with the summary, end the flow."""
        # effective_message also covers edited messages, where
        # ``update.message`` is None.
        message = update.effective_message
        if message is None or not message.text:
            return ConversationHandler.END
        summary = await self.summarize(message.text, 100)
        # Telegram rejects empty message text, so never send a blank reply.
        await message.reply_text(
            summary or "Sorry, I could not produce a summary for that text.")
        return ConversationHandler.END

    async def conversation_summary(self, update: Update, context: CallbackContext) -> int:
        """Entry point for ``/summary``: prompt for the text and await it."""
        await update.effective_message.reply_text('Please enter your conversations...')
        return GET_TEXT

    async def cancel(self, update: Update, context: CallbackContext) -> int:
        """Fallback for ``/cancel``: acknowledge and end the conversation.

        Required by the ``fallbacks`` list of the conversation handler.
        """
        message = update.effective_message
        if message is not None:
            await message.reply_text("Summary cancelled.")
        return ConversationHandler.END