rexthecoder commited on
Commit
4f624b9
·
1 Parent(s): af7a462

chore: fix data

Browse files
Files changed (2) hide show
  1. main.py +1 -1
  2. src/agent/tools/conversation.py +17 -16
main.py CHANGED
@@ -43,7 +43,7 @@ class LoggingDisabled:
43
 
44
  def main():
45
  app = Application.builder().token(
46
- '5998527257:AAGtWduI4IPlHu2bb2WQC1TNEJ6XutaZTko',).build()
47
 
48
  run_agent(
49
  agent=GirlfriendGPT(
 
43
 
44
  def main():
45
  app = Application.builder().token(
46
+ '5998527257:AAH9cWNMsakaRJNSDW2OucR_Qb1J2noL0Ak',).build()
47
 
48
  run_agent(
49
  agent=GirlfriendGPT(
src/agent/tools/conversation.py CHANGED
@@ -1,7 +1,8 @@
1
  import logging
2
  from telegram import Update
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torch
 
 
5
  from telegram.ext import (
6
  CallbackContext,
7
  )
@@ -18,10 +19,8 @@ GET_CON = range(1)
18
 
19
 
20
  class Conversation():
21
- tokenizer = AutoTokenizer.from_pretrained(
22
- "microsoft/GODEL-v1_1-large-seq2seq")
23
- model = AutoModelForSeq2SeqLM.from_pretrained(
24
- "microsoft/GODEL-v1_1-large-seq2seq")
25
 
26
  # async def talk(self, message: str):
27
  # logging.info(f"{message}")
@@ -31,18 +30,20 @@ class Conversation():
31
  # chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
32
  # return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
33
  def predict(self, input, history=[]):
34
- instruction = "Instruction: given a dialog context and related knowledge, you need to answer the question based on the knowledge."
35
- knowledge = '[KNOWLEDGE] ' + ''
36
- s = list(sum(history, ()))
37
- s.append(input)
38
- dialog = ' EOS '.join(s)
39
- query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
40
- input_ids = self.tokenizer.encode(f"{query}", return_tensors='pt')
41
- print(input, s)
42
- output = self.model.generate(input_ids, max_length=128, min_length=8, top_p=0.9, do_sample=True).tolist()
43
- response = self.tokenizer.decode(output[0], skip_special_tokens=True)
44
- history.append((input, response))
 
45
  return response
 
46
  # def generate(self, instruction, knowledge, dialog):
47
  # if knowledge != '':
48
  # knowledge = '[KNOWLEDGE] ' + knowledge
 
1
  import logging
2
  from telegram import Update
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BlenderbotForConditionalGeneration, BlenderbotForCausalLM, BlenderbotTokenizer
5
+
6
  from telegram.ext import (
7
  CallbackContext,
8
  )
 
19
 
20
 
21
  class Conversation():
22
+ tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
23
+ model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill",add_cross_attention=False)
 
 
24
 
25
  # async def talk(self, message: str):
26
  # logging.info(f"{message}")
 
30
  # chat_history_ids =self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)
31
  # return "{}".format(self.tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
32
  def predict(self, input, history=[]):
33
+ # tokenize the new input sentence
34
+ new_user_input_ids = self.tokenizer.encode(input + self.tokenizer.eos_token, return_tensors='pt')
35
+
36
+ # append the new user input tokens to the chat history
37
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
38
+
39
+ # generate a response
40
+ history = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id).tolist()
41
+
42
+ # convert the tokens to text, and then split the responses into the right format
43
+ response = self.tokenizer.decode(history[0]).replace("<s>","").split("</s>")
44
+ response = [(response[i], response[i+1]) for i in range(0, len(response), 2)] # convert to tuples of list
45
  return response
46
+
47
  # def generate(self, instruction, knowledge, dialog):
48
  # if knowledge != '':
49
  # knowledge = '[KNOWLEDGE] ' + knowledge