Spaces:
Running
Running
extend type-assistant-response to predictive-text for type user messages too
Browse files
app.py
CHANGED
@@ -209,35 +209,33 @@ def generate_revisions():
|
|
209 |
st.write(revised_docs[i]['doc_text'])
|
210 |
|
211 |
def type_assistant_response():
|
212 |
-
if 'messages' not in st.session_state:
|
213 |
-
st.session_state['messages'] = []
|
|
|
214 |
messages = st.session_state.messages
|
215 |
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
218 |
with st.chat_message(message["role"]):
|
219 |
st.markdown(message["content"])
|
|
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
# Display assistant response in chat message container
|
233 |
-
with st.chat_message("assistant"):
|
234 |
-
#st.write(messages[-1]['content'])
|
235 |
-
msg_in_progress = st.text_area("Assistant response", value=messages[-1]['content'], placeholder="Clicking the buttons below will update this field. You can also edit it directly; press Ctrl+Enter to apply changes.", height=300)
|
236 |
-
# strip spaces (but not newlines) to avoid a tokenization issue
|
237 |
-
msg_in_progress = msg_in_progress.rstrip(' ')
|
238 |
|
239 |
def append_token(word):
|
240 |
-
messages[-1]['content'] = (
|
241 |
msg_in_progress + word
|
242 |
)
|
243 |
|
@@ -246,9 +244,7 @@ def type_assistant_response():
|
|
246 |
response = requests.post(
|
247 |
f"{API_SERVER}/continue_messages",
|
248 |
json={
|
249 |
-
"messages": messages
|
250 |
-
{"role": "assistant", "content": msg_in_progress},
|
251 |
-
],
|
252 |
"n_branch_tokens": 5,
|
253 |
"n_future_tokens": 2
|
254 |
}
|
@@ -278,6 +274,12 @@ def type_assistant_response():
|
|
278 |
token_display = show_token(token)
|
279 |
st.button(token_display, on_click=append_token, args=(token,), key=i, use_container_width=True)
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="📝")
|
283 |
highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="🖍️")
|
|
|
209 |
st.write(revised_docs[i]['doc_text'])
|
210 |
|
211 |
def type_assistant_response():
|
212 |
+
if 'messages' not in st.session_state or st.button("Start a new conversation"):
|
213 |
+
st.session_state['messages'] = [{"role": "user", "content": ""}]
|
214 |
+
st.session_state['msg_in_progress'] = ""
|
215 |
messages = st.session_state.messages
|
216 |
|
217 |
+
def rewind_to(i):
|
218 |
+
st.session_state.messages = st.session_state.messages[:i+1]
|
219 |
+
st.session_state['msg_in_progress'] = st.session_state.messages[-1]['content']
|
220 |
+
|
221 |
+
for i, message in enumerate(st.session_state.messages[:-1]):
|
222 |
with st.chat_message(message["role"]):
|
223 |
st.markdown(message["content"])
|
224 |
+
st.button("Edit", on_click=rewind_to, args=(i,))
|
225 |
|
226 |
+
# Display message-in-progress in chat message container
|
227 |
+
last_role = messages[-1]["role"]
|
228 |
+
with st.chat_message(last_role):
|
229 |
+
label = "Your message" if last_role == "user" else "Assistant response"
|
230 |
+
msg_in_progress = st.text_area(label, placeholder="Clicking the buttons below will update this field. You can also edit it directly; press Ctrl+Enter to apply changes.", height=300, key="msg_in_progress")
|
231 |
+
if msg_in_progress is None:
|
232 |
+
msg_in_progress = ""
|
233 |
|
234 |
+
messages[-1]['content'] = msg_in_progress
|
235 |
+
print(messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
def append_token(word):
|
238 |
+
messages[-1]['content'] = st.session_state['msg_in_progress'] = (
|
239 |
msg_in_progress + word
|
240 |
)
|
241 |
|
|
|
244 |
response = requests.post(
|
245 |
f"{API_SERVER}/continue_messages",
|
246 |
json={
|
247 |
+
"messages": messages,
|
|
|
|
|
248 |
"n_branch_tokens": 5,
|
249 |
"n_future_tokens": 2
|
250 |
}
|
|
|
274 |
token_display = show_token(token)
|
275 |
st.button(token_display, on_click=append_token, args=(token,), key=i, use_container_width=True)
|
276 |
|
277 |
+
if st.button("Send"):
|
278 |
+
other_role = "assistant" if last_role == "user" else "user"
|
279 |
+
messages.append({"role": other_role, "content": ""})
|
280 |
+
st.session_state['msg_in_progress'] = ""
|
281 |
+
st.session_state.messages = messages
|
282 |
+
|
283 |
|
284 |
rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="📝")
|
285 |
highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="🖍️")
|