import logging

import torch
from telegram import Update
from telegram.ext import CallbackContext
from transformers import AutoModelForCausalLM, AutoTokenizer
NAME = "Conversation" | |
DESCRIPTION = """ | |
Useful for building up conversation. | |
Input: A normal chat text | |
Output: A text | |
""" | |
GET_CON = range(1) | |
class Conversation:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    def __init__(self):
        # Running token history of the dialogue; None until the first turn
        self.chat_history_ids = None

    async def talk(self, message: str) -> str:
        logging.info(message)
        # Encode the incoming message followed by the end-of-sequence token
        new_user_input_ids = self.tokenizer.encode(message + self.tokenizer.eos_token, return_tensors="pt")
        # Append to the running history on later turns
        bot_input_ids = (torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1)
                         if self.chat_history_ids is not None else new_user_input_ids)
        self.chat_history_ids = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
        # Decode only the newly generated reply tokens (everything after the prompt)
        return self.tokenizer.decode(self.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    async def process_conversation(self, update: Update, context: CallbackContext) -> int:
        # ConversationHandler callback: reply to the user's message and stay in the same state
        message = update.message.text
        text = await self.talk(message)
        await update.message.reply_text(text)
        return GET_CON
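

# A minimal usage sketch (not part of the original Space), assuming
# python-telegram-bot v20+. The main() entry point and the placeholder
# bot token are illustrative assumptions only.
def main() -> None:
    from telegram.ext import ApplicationBuilder, ConversationHandler, MessageHandler, filters

    conversation = Conversation()
    application = ApplicationBuilder().token("YOUR_BOT_TOKEN").build()  # placeholder token

    # Route every plain-text message to process_conversation; the handler
    # returns GET_CON, so the dialogue keeps cycling through the same state.
    text_handler = MessageHandler(filters.TEXT & ~filters.COMMAND, conversation.process_conversation)
    application.add_handler(ConversationHandler(
        entry_points=[text_handler],
        states={GET_CON: [text_handler]},
        fallbacks=[],
    ))
    application.run_polling()


if __name__ == "__main__":
    main()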