rexthecoder committed on
Commit
9d7b3a0
·
1 Parent(s): 7c3e9f4
Files changed (1) hide show
  1. src/agent/tools/conversation.py +3 -9
src/agent/tools/conversation.py CHANGED
@@ -3,13 +3,7 @@ from telegram import Update
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  from telegram.ext import (
6
- Updater,
7
- CommandHandler,
8
- ConversationHandler,
9
  CallbackContext,
10
- CallbackQueryHandler,
11
- MessageHandler,
12
- filters
13
  )
14
 
15
  NAME = "Conversation"
@@ -26,9 +20,9 @@ class Conversation():
26
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
27
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
28
 
29
- async def talk(self, input: str):
30
- logging.info(f"{input}")
31
- new_user_input_ids = self.tokenizer.encode(input(f"{input}") + self.tokenizer.eos_token, return_tensors='pt')
32
  bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
33
  chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
34
  return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  from telegram.ext import (
 
 
 
6
  CallbackContext,
 
 
 
7
  )
8
 
9
  NAME = "Conversation"
 
20
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
21
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
22
 
23
async def talk(self, message: str) -> str:
    """Generate a DialoGPT reply to *message* and return it as text.

    Encodes the user's message, appends it to the running chat history,
    generates a continuation with the model, and returns only the newly
    generated portion (the reply), with special tokens stripped.

    Args:
        message: The user's utterance to respond to.

    Returns:
        The model's decoded reply string.
    """
    # Lazy %-style args instead of an f-string so formatting is skipped
    # when INFO is disabled.
    logging.info("%s", message)
    # Bug fix: the original passed the text through the builtin
    # input(f'{message}'), which blocks waiting on stdin and encodes
    # whatever the operator types instead of the argument. Encode the
    # message itself, terminated by the EOS token as DialoGPT expects.
    new_user_input_ids = self.tokenizer.encode(
        message + self.tokenizer.eos_token, return_tensors='pt'
    )
    # Bug fix: chat_history_ids was referenced before assignment
    # (NameError on every call). Keep the running history on the
    # instance so multi-turn context accumulates across calls; first
    # turn starts from the new input alone.
    history = getattr(self, "chat_history_ids", None)
    if history is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([history, new_user_input_ids], dim=-1)
    self.chat_history_ids = self.model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=self.tokenizer.eos_token_id,
    )
    # Decode only the tokens generated after the prompt; decode()
    # already returns a str, so no extra "{}".format wrapper is needed.
    return self.tokenizer.decode(
        self.chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )