Spaces:
Runtime error
Runtime error
matthoffner
commited on
Commit
•
ab83be1
1
Parent(s):
8a15493
Try to fix repeating on User
Browse files
demo.py
CHANGED
@@ -98,18 +98,18 @@ def chat():
|
|
98 |
if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
|
99 |
yield chat_history
|
100 |
return
|
101 |
-
|
102 |
if message == RETRY_COMMAND and chat_history:
|
103 |
prev_turn = chat_history.pop(-1)
|
104 |
user_message, _ = prev_turn
|
105 |
message = user_message
|
106 |
-
|
107 |
prompt = format_chat_prompt(message, chat_history, instructions)
|
108 |
chat_history = chat_history + [[message, ""]]
|
109 |
stream = llm(
|
110 |
prompt,
|
111 |
max_new_tokens=1024,
|
112 |
-
stop=[STOP_STR, "
|
113 |
temperature=temperature,
|
114 |
top_p=top_p,
|
115 |
stream=True
|
@@ -117,20 +117,27 @@ def chat():
|
|
117 |
acc_text = ""
|
118 |
for idx, response in enumerate(stream):
|
119 |
text_token = response
|
120 |
-
|
121 |
-
if
|
122 |
-
|
|
|
|
|
|
|
|
|
123 |
continue
|
124 |
-
|
125 |
if idx == 0 and text_token.startswith(" "):
|
126 |
text_token = text_token[1:]
|
127 |
-
|
128 |
acc_text += text_token
|
|
|
|
|
|
|
129 |
last_turn = list(chat_history.pop(-1))
|
130 |
last_turn[-1] += acc_text
|
131 |
chat_history = chat_history + [last_turn]
|
132 |
yield chat_history
|
133 |
-
|
134 |
|
135 |
def delete_last_turn(chat_history):
|
136 |
if chat_history:
|
|
|
98 |
if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
|
99 |
yield chat_history
|
100 |
return
|
101 |
+
|
102 |
if message == RETRY_COMMAND and chat_history:
|
103 |
prev_turn = chat_history.pop(-1)
|
104 |
user_message, _ = prev_turn
|
105 |
message = user_message
|
106 |
+
|
107 |
prompt = format_chat_prompt(message, chat_history, instructions)
|
108 |
chat_history = chat_history + [[message, ""]]
|
109 |
stream = llm(
|
110 |
prompt,
|
111 |
max_new_tokens=1024,
|
112 |
+
stop=[STOP_STR, ""],
|
113 |
temperature=temperature,
|
114 |
top_p=top_p,
|
115 |
stream=True
|
|
|
117 |
acc_text = ""
|
118 |
for idx, response in enumerate(stream):
|
119 |
text_token = response
|
120 |
+
|
121 |
+
if acc_text.endswith(STOP_STR):
|
122 |
+
last_turn = list(chat_history.pop(-1))
|
123 |
+
last_turn[-1] += acc_text[:-len(STOP_STR)]
|
124 |
+
chat_history = chat_history + [last_turn]
|
125 |
+
yield chat_history
|
126 |
+
acc_text = text_token
|
127 |
continue
|
128 |
+
|
129 |
if idx == 0 and text_token.startswith(" "):
|
130 |
text_token = text_token[1:]
|
131 |
+
|
132 |
acc_text += text_token
|
133 |
+
|
134 |
+
# If there's any remaining text after the loop
|
135 |
+
if acc_text:
|
136 |
last_turn = list(chat_history.pop(-1))
|
137 |
last_turn[-1] += acc_text
|
138 |
chat_history = chat_history + [last_turn]
|
139 |
yield chat_history
|
140 |
+
|
141 |
|
142 |
def delete_last_turn(chat_history):
|
143 |
if chat_history:
|