# src/agent/tools/conversation.py
import logging

import torch
from telegram import Update
from telegram.ext import CallbackContext
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

NAME = "Conversation"
DESCRIPTION = """
Useful for holding an open-ended chat with the user.
Input: a plain chat message
Output: the model's reply text
"""

# Single state used by the ConversationHandler that owns this tool.
GET_CON = 0


class Conversation:
    """Small-talk tool backed by the distilled BlenderBot 400M checkpoint."""

    tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
    model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
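    # Loading these as class attributes fetches the checkpoint once, at import
    # time, and shares a single tokenizer/model pair across all instances.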
    def predict(self, message: str, history=None):
        # Tokenize the new user message, terminated by the EOS token so it can
        # be concatenated onto earlier turns.
        new_user_input_ids = self.tokenizer.encode(
            message + self.tokenizer.eos_token, return_tensors="pt"
        )
        # Append the new input tokens to the running chat history, if any.
        if history:
            bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
        else:
            bot_input_ids = new_user_input_ids
        # Generate a reply; BlenderBot's positional embeddings cap out at 128
        # tokens, so the old max_length=1000 could index past them.
        generated = self.model.generate(
            bot_input_ids, max_length=128, pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode to plain text instead of returning the repr of a token list.
        return self.tokenizer.decode(generated[0], skip_special_tokens=True).strip()
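
    # A minimal sketch of exercising predict() directly, outside Telegram, for
    # a quick local smoke test (the prompt string is illustrative only):
    #
    #     conv = Conversation()
    #     print(conv.predict("Hi! How has your day been?"))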
    async def process_conversation(self, update: Update, context: CallbackContext) -> int:
        # Reply to the incoming Telegram message and stay in the chat state.
        message = update.message.text
        logging.info("Conversation tool received: %s", message)
        text = self.predict(message)
        await update.message.reply_text(text)
        return GET_CON
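

# A minimal wiring sketch showing how this tool could be mounted in a bot. It
# assumes python-telegram-bot v20+ and a BOT_TOKEN environment variable; both
# are assumptions made for this example, not details taken from the repo.
if __name__ == "__main__":
    import os

    from telegram.ext import (
        ApplicationBuilder,
        CommandHandler,
        ConversationHandler,
        MessageHandler,
        filters,
    )

    tool = Conversation()
    chat_handler = ConversationHandler(
        # The /chat command name is hypothetical; use whatever fits the bot.
        entry_points=[CommandHandler("chat", tool.process_conversation)],
        # process_conversation returns GET_CON, so the dialogue keeps going
        # until the ConversationHandler is ended elsewhere.
        states={GET_CON: [MessageHandler(filters.TEXT & ~filters.COMMAND, tool.process_conversation)]},
        fallbacks=[],
    )
    app = ApplicationBuilder().token(os.environ["BOT_TOKEN"]).build()
    app.add_handler(chat_handler)
    app.run_polling()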