matthoffner commited on
Commit
3ff3725
1 Parent(s): 64a8115

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +9 -15
demo.py CHANGED
@@ -14,7 +14,7 @@ DEFAULT_INSTRUCTIONS = f"""The following is a conversation between a highly know
14
  """
15
  RETRY_COMMAND = "/retry"
16
  STOP_STR = f"\n{USER_NAME}:"
17
- STOP_SUSPECT_LIST = [":", "\n", "User"]
18
 
19
 
20
  def chat_accordion():
@@ -98,12 +98,12 @@ def chat():
98
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
99
  yield chat_history
100
  return
101
-
102
  if message == RETRY_COMMAND and chat_history:
103
  prev_turn = chat_history.pop(-1)
104
  user_message, _ = prev_turn
105
  message = user_message
106
-
107
  prompt = format_chat_prompt(message, chat_history, instructions)
108
  chat_history = chat_history + [[message, ""]]
109
  stream = llm(
@@ -117,26 +117,20 @@ def chat():
117
  acc_text = ""
118
  for idx, response in enumerate(stream):
119
  text_token = response
120
-
121
- if acc_text.endswith(STOP_STR):
122
- last_turn = list(chat_history.pop(-1))
123
- last_turn[-1] += acc_text[:-len(STOP_STR)]
124
- chat_history = chat_history + [last_turn]
125
- yield chat_history
126
- acc_text = text_token
127
  continue
128
-
129
  if idx == 0 and text_token.startswith(" "):
130
  text_token = text_token[1:]
131
-
132
  acc_text += text_token
133
-
134
- # If there's any remaining text after the loop
135
- if acc_text:
136
  last_turn = list(chat_history.pop(-1))
137
  last_turn[-1] += acc_text
138
  chat_history = chat_history + [last_turn]
139
  yield chat_history
 
140
 
141
 
142
  def delete_last_turn(chat_history):
 
14
  """
15
  RETRY_COMMAND = "/retry"
16
  STOP_STR = f"\n{USER_NAME}:"
17
+ STOP_SUSPECT_LIST = [":", "\n", "User", "\nUser"]
18
 
19
 
20
  def chat_accordion():
 
98
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
99
  yield chat_history
100
  return
101
+
102
  if message == RETRY_COMMAND and chat_history:
103
  prev_turn = chat_history.pop(-1)
104
  user_message, _ = prev_turn
105
  message = user_message
106
+
107
  prompt = format_chat_prompt(message, chat_history, instructions)
108
  chat_history = chat_history + [[message, ""]]
109
  stream = llm(
 
117
  acc_text = ""
118
  for idx, response in enumerate(stream):
119
  text_token = response
120
+
121
+ if text_token in STOP_SUSPECT_LIST:
122
+ acc_text += text_token
 
 
 
 
123
  continue
124
+
125
  if idx == 0 and text_token.startswith(" "):
126
  text_token = text_token[1:]
127
+
128
  acc_text += text_token
 
 
 
129
  last_turn = list(chat_history.pop(-1))
130
  last_turn[-1] += acc_text
131
  chat_history = chat_history + [last_turn]
132
  yield chat_history
133
+ acc_text = ""
134
 
135
 
136
  def delete_last_turn(chat_history):