Niki Zhang commited on
Commit
711583c
·
verified ·
1 Parent(s): e16096a

Update chatbox.py

Browse files
Files changed (1) hide show
  1. chatbox.py +10 -4
chatbox.py CHANGED
@@ -12,7 +12,7 @@ import inspect
12
 
13
  from langchain.agents.initialize import initialize_agent
14
  from langchain.agents.tools import Tool
15
- from langchain.memory import ConversationBufferMemory
16
  from langchain_community.chat_models import ChatOpenAI
17
  import torch
18
  from PIL import Image, ImageDraw, ImageOps
@@ -141,7 +141,7 @@ class ConversationBot:
141
  def __init__(self, tools, api_key=""):
142
  # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
143
  print("chatbot api",api_key)
144
- llm = ChatOpenAI(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key, model_kwargs={"api_version": "2020-11-07"})
145
  self.llm = llm
146
  self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
147
  self.tools = tools
@@ -172,11 +172,17 @@ class ConversationBot:
172
  return ans
173
 
174
  def run_text(self, text, state, aux_state):
175
- self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
 
 
 
 
 
176
  if self.point_prompt != "":
177
  Human_prompt = f'\nHuman: {self.point_prompt}\n'
178
  AI_prompt = 'Ok'
179
- self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
 
180
  self.point_prompt = ""
181
  res = self.agent({"input": text})
182
  res['output'] = res['output'].replace("\\", "/")
 
12
 
13
  from langchain.agents.initialize import initialize_agent
14
  from langchain.agents.tools import Tool
15
+ from langchain.chains.conversation.memory import ConversationBufferMemory
16
  from langchain_community.chat_models import ChatOpenAI
17
  import torch
18
  from PIL import Image, ImageDraw, ImageOps
 
141
  def __init__(self, tools, api_key=""):
142
  # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
143
  print("chatbot api",api_key)
144
+ llm = ChatOpenAI(model_name="gpt-4o", temperature=0.7, openai_api_key=api_key)
145
  self.llm = llm
146
  self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
147
  self.tools = tools
 
172
  return ans
173
 
174
  def run_text(self, text, state, aux_state):
175
+ memory_str = self.agent.memory.buffer_as_str
176
+ trimmed_memory_str = cut_dialogue_history(memory_str, keep_last_n_words=500)
177
+ trimmed_messages = self.memory.buffer_as_messages[:len(trimmed_memory_str.split())]
178
+ # self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
179
+ self.memory.chat_memory.messages = trimmed_messages
180
+ print("done")
181
  if self.point_prompt != "":
182
  Human_prompt = f'\nHuman: {self.point_prompt}\n'
183
  AI_prompt = 'Ok'
184
+ # self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
185
+ self.agent.memory.save_context({'input': Human_prompt}, {'output': AI_prompt})
186
  self.point_prompt = ""
187
  res = self.agent({"input": text})
188
  res['output'] = res['output'].replace("\\", "/")