Spaces:
Runtime error
Runtime error
File size: 1,307 Bytes
7c3e9f4 e43ab69 9d7b3a0 50715ca e43ab69 7c3e9f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import asyncio
import logging

import torch
from telegram import Update
from telegram.ext import (
    CallbackContext,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
# Name under which this conversational capability is exposed.
NAME: str = "Conversation"

# Human-readable summary of the capability and its input/output contract.
DESCRIPTION: str = """
Useful for building up conversation.
Input: A normal chat text
Output: A text
"""

# Conversation state identifier.
# NOTE(review): this is the range object range(0, 1), not the int 0 that the
# common `GET_CON, = range(1)` unpacking idiom would produce — confirm how
# callers use it before changing its type.
GET_CON: range = range(0, 1)
class Conversation:
    """Chat with the microsoft/DialoGPT-medium model, remembering earlier turns.

    The tokenizer and model are loaded via ``from_pretrained`` when this class
    body executes (i.e. when the module is imported) and are shared by every
    instance. The running token history starts as a shared empty placeholder
    and is rebound per instance after the first reply.

    NOTE(review): if a single instance serves every Telegram chat, all users
    share one dialogue history — confirm how this class is instantiated.
    """

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    # Empty placeholder; replaced by the 2-D output of model.generate().
    chat_history_ids = torch.tensor([], dtype=torch.long)
    # Total length (prompt + reply) passed to generate() as max_length.
    max_length = 1000
    # The prompt is cut to its most recent max_input_tokens tokens so the
    # history can never consume the whole max_length budget. This leaves room
    # for a reply of up to max_length - max_input_tokens tokens and keeps the
    # model inside its fixed context window (1024 positions for GPT-2-based
    # DialoGPT per its published config — confirm if the model changes).
    max_input_tokens = 600
    # Serialises access to chat_history_ids; created lazily in _get_lock().
    _lock = None

    def _get_lock(self) -> asyncio.Lock:
        """Return this instance's lock, creating it on first use."""
        # Built here rather than in the class body so it is created while the
        # bot's event loop is running (before Python 3.10 a Lock bound itself
        # to the current loop at construction time).
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    def _generate_reply(self, message: str) -> str:
        """Run the model on the stored history plus *message*; return the reply.

        Synchronous and compute-heavy, so talk() calls it from a worker
        thread. Rebinds self.chat_history_ids to the (truncated) prompt
        followed by the generated reply.
        """
        new_user_input_ids = self.tokenizer.encode(
            message + self.tokenizer.eos_token, return_tensors="pt"
        )
        if self.chat_history_ids.numel() == 0:
            # First turn: avoid concatenating the 1-D empty placeholder with
            # the 2-D encoded input.
            bot_input_ids = new_user_input_ids
        else:
            bot_input_ids = torch.cat(
                [self.chat_history_ids, new_user_input_ids], dim=-1
            )
        # Bound the prompt; otherwise the history grows every turn until
        # generate() has no room left for a reply.
        bot_input_ids = bot_input_ids[:, -self.max_input_tokens:]

        self.chat_history_ids = self.model.generate(
            bot_input_ids,
            max_length=self.max_length,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # generate() returns the prompt followed by the new tokens; decode
        # only the part after the prompt.
        prompt_length = bot_input_ids.shape[-1]
        return self.tokenizer.decode(
            self.chat_history_ids[0, prompt_length:], skip_special_tokens=True
        )

    async def talk(self, message: str) -> str:
        """Return the model's reply to *message*, updating the chat history.

        Generation runs in the event loop's default executor so other
        updates are not blocked while the model works. The lock prevents
        concurrent calls from interleaving their reads and writes of
        chat_history_ids.
        """
        logging.info("%s", message)
        loop = asyncio.get_running_loop()
        async with self._get_lock():
            return await loop.run_in_executor(None, self._generate_reply, message)

    async def process_conversation(
        self, update: Update, context: CallbackContext
    ) -> None:
        """Answer an incoming Telegram text message with a DialoGPT reply.

        Updates without a message, or messages without text (e.g. stickers
        or photos), are ignored. Previously they raised TypeError in talk().

        Returns None. NOTE(review): the old annotation claimed int. Under a
        python-telegram-bot ConversationHandler a None return leaves the
        conversation state unchanged — confirm that is intended where this
        callback is registered.
        """
        message = update.message
        if message is None or not message.text:
            return None
        reply = await self.talk(message.text)
        await message.reply_text(reply)
        return None
|