baohuynhbk14 commited on
Commit
7d6b74d
·
1 Parent(s): 4170f81

Refactor add_text and get_history functions to improve message handling and formatting

Browse files
Files changed (2) hide show
  1. app.py +17 -3
  2. conversation.py +7 -14
app.py CHANGED
@@ -141,7 +141,6 @@ def clear_history(request: gr.Request):
141
 
142
 
143
  def add_text(state, message, system_prompt, request: gr.Request):
144
- print(f"state: {state}")
145
  if not state:
146
  state = init_state()
147
  images = message.get("files", [])
@@ -162,11 +161,22 @@ def add_text(state, message, system_prompt, request: gr.Request):
162
  return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
163
  images = [Image.open(path).convert("RGB") for path in images]
164
 
 
165
  if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
166
  state = init_state(state)
167
-
 
168
  if len(images) > 0 and len(state.get_images(source=state.USER)) == 0:
169
- text = DEFAULT_IMAGE_TOKEN + "\n" + text
 
 
 
 
 
 
 
 
 
170
 
171
  state.set_system_message(system_prompt)
172
  state.append_message(Conversation.USER, text, images)
@@ -215,6 +225,10 @@ def predict(state,
215
  state.update_message(state.USER, DEFAULT_IMAGE_TOKEN + "\n" + first_user_message, None, index)
216
 
217
 
 
 
 
 
218
  history = state.get_history()
219
  logger.info(f"==== History ====\n{history}")
220
  message = history[-1][0] if len(history) > 0 else ""
 
141
 
142
 
143
  def add_text(state, message, system_prompt, request: gr.Request):
 
144
  if not state:
145
  state = init_state()
146
  images = message.get("files", [])
 
161
  return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5
162
  images = [Image.open(path).convert("RGB") for path in images]
163
 
164
+ # Init again if send the second image
165
  if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
166
  state = init_state(state)
167
+
168
+ # Upload the first image
169
  if len(images) > 0 and len(state.get_images(source=state.USER)) == 0:
170
+ if len(state.messages) == 0: ## In case the first message is an image
171
+ text = DEFAULT_IMAGE_TOKEN + "\n" + system_prompt + "\n" + text
172
+ else: ## In case the image is uploaded after some text messages
173
+ first_user_message = state.messages[0]['content']
174
+ state.update_message(Conversation.USER, DEFAULT_IMAGE_TOKEN + "\n" + first_user_message, None, 0)
175
+
176
+ # If the first message is text
177
+ if len(images) == 0 and len(state.get_images(source=state.USER)) == 0 and len(state.messages) == 0:
178
+ text = system_prompt + "\n" + text
179
+
180
 
181
  state.set_system_message(system_prompt)
182
  state.append_message(Conversation.USER, text, images)
 
225
  state.update_message(state.USER, DEFAULT_IMAGE_TOKEN + "\n" + first_user_message, None, index)
226
 
227
 
228
+ idx_last_user_mess, message = state.get_user_message(source=state.USER, position='last')
229
+ if idx_last_user_mess == 0 and DEFAULT_IMAGE_TOKEN not in first_user_message:
230
+ message = state.system_message + "\n" + message
231
+
232
  history = state.get_history()
233
  logger.info(f"==== History ====\n{history}")
234
  message = history[-1][0] if len(history) > 0 else ""
conversation.py CHANGED
@@ -137,35 +137,28 @@ class Conversation:
137
  return send_messages
138
 
139
  def get_history(self):
140
- results = []
141
  system_message = self.system_message
142
- messages = self.messages[:-2]
143
-
144
  logger.info(f"=== Raw messages ===\n{messages}")
145
 
 
 
 
146
  for i in range(len(messages)):
147
  if messages[i]['role'] == 'user':
148
  # Create the question by combining system and user messages
149
  user_content = messages[i]['content']
150
 
151
- # Check if it's the first user message and contains <image>
152
- if i == 0 and '<image>' in user_content:
153
- user_content = user_content.replace('<image>', '')
154
- question = f"<image>\n{system_message}\n{user_content}"
155
- elif i == 0 and '<image>' not in user_content:
156
- question = f"{system_message}\n{user_content}"
157
- else:
158
- question = f"{user_content}"
159
-
160
  # Check for the corresponding assistant response
161
  answer = ""
162
  if i + 1 < len(messages) and messages[i + 1]['role'] == 'assistant':
163
  answer = messages[i + 1]['content']
164
 
165
  # Add the question-answer pair to results
166
- results.append((question, answer))
167
 
168
- return results
169
 
170
 
171
  def append_message(
 
137
  return send_messages
138
 
139
  def get_history(self):
140
+ history = []
141
  system_message = self.system_message
142
+ messages = self.messages[:-1]
 
143
  logger.info(f"=== Raw messages ===\n{messages}")
144
 
145
+ if len(messages) < 2:
146
+ return None
147
+
148
  for i in range(len(messages)):
149
  if messages[i]['role'] == 'user':
150
  # Create the question by combining system and user messages
151
  user_content = messages[i]['content']
152
 
 
 
 
 
 
 
 
 
 
153
  # Check for the corresponding assistant response
154
  answer = ""
155
  if i + 1 < len(messages) and messages[i + 1]['role'] == 'assistant':
156
  answer = messages[i + 1]['content']
157
 
158
  # Add the question-answer pair to results
159
+ history.append((user_content, answer))
160
 
161
+ return history
162
 
163
 
164
  def append_message(