kcarnold commited on
Commit
fe6cd36
Β·
1 Parent(s): 38826eb

Initial stab at typing the assistant response

Browse files
Files changed (1) hide show
  1. app.py +60 -1
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  import requests
3
 
 
 
4
  def landing():
5
  st.title("Writing Tools Prototypes")
6
  st.markdown("Click one of the links below to see a prototype in action.")
@@ -8,6 +10,7 @@ def landing():
8
  st.page_link(rewrite_page, label="Rewrite with predictions", icon="πŸ“")
9
  st.page_link(highlight_page, label="Highlight locations for possible edits", icon="πŸ–οΈ")
10
  st.page_link(generate_page, label="Generate revisions", icon="πŸ”„")
 
11
 
12
  st.markdown("*Note*: These services send data to a remote server for processing. The server logs requests. Don't use sensitive or identifiable information on this page.")
13
 
@@ -205,16 +208,72 @@ def generate_revisions():
205
  with tab:
206
  st.write(revised_docs[i]['doc_text'])
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="πŸ“")
210
  highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="πŸ–οΈ")
211
  generate_page = st.Page(generate_revisions, title="Generate revisions", icon="πŸ”„")
 
212
 
213
  # Manually specify the sidebar
214
  page = st.navigation([
215
  st.Page(landing, title="Home", icon="🏠"),
216
  highlight_page,
217
  rewrite_page,
218
- generate_page
 
219
  ])
220
  page.run()
 
1
  import streamlit as st
2
  import requests
3
 
4
+ API_SERVER = "https://tools.kenarnold.org/api"
5
+
6
  def landing():
7
  st.title("Writing Tools Prototypes")
8
  st.markdown("Click one of the links below to see a prototype in action.")
 
10
  st.page_link(rewrite_page, label="Rewrite with predictions", icon="πŸ“")
11
  st.page_link(highlight_page, label="Highlight locations for possible edits", icon="πŸ–οΈ")
12
  st.page_link(generate_page, label="Generate revisions", icon="πŸ”„")
13
+ st.page_link(type_assistant_response_page, label="Type Assistant Response", icon="πŸ”€")
14
 
15
  st.markdown("*Note*: These services send data to a remote server for processing. The server logs requests. Don't use sensitive or identifiable information on this page.")
16
 
 
208
  with tab:
209
  st.write(revised_docs[i]['doc_text'])
210
 
211
+ def type_assistant_response():
212
+ if 'messages' not in st.session_state:
213
+ st.session_state['messages'] = []
214
+ messages = st.session_state.messages
215
+
216
+ for message in st.session_state.messages[:-1]:
217
+ with st.chat_message(message["role"]):
218
+ st.markdown(message["content"])
219
+
220
+ if prompt := st.chat_input(""):
221
+ # Display user message in chat message container
222
+ with st.chat_message("user"):
223
+ st.markdown(prompt)
224
+ # Add user message to chat history
225
+ messages.append({"role": "user", "content": prompt})
226
+ messages.append({"role": "assistant", "content": ""})
227
+
228
+ if len(messages) == 0:
229
+ st.stop()
230
+
231
+ response = requests.post(
232
+ f"{API_SERVER}/continue_messages",
233
+ json={
234
+ "messages": messages,
235
+ "n_branch_tokens": 5,
236
+ "n_future_tokens": 2
237
+ }
238
+ )
239
+ if response.status_code != 200:
240
+ st.error("Error fetching response")
241
+ st.write(response.text)
242
+ st.stop()
243
+ response.raise_for_status()
244
+ response = response.json()
245
+
246
+ # Display assistant response in chat message container
247
+ with st.chat_message("assistant"):
248
+ st.write(messages[-1]['content'])
249
+ def append_token(word):
250
+ messages[-1]['content'] = (
251
+ messages[-1]['content'] + word
252
+ )
253
+
254
+ allow_multi_word = st.checkbox("Allow multi-word predictions", value=False)
255
+
256
+ continuations = response['continuations']
257
+ for i, (col, continuation) in enumerate(zip(st.columns(len(continuations)), continuations)):
258
+ token = continuation['doc_text']
259
+ with col:
260
+ if not allow_multi_word and ' ' in token[1:]:
261
+ token = token[0] + token[1:].split(' ', 1)[0]
262
+ token_display = show_token(token)
263
+ st.button(token_display, on_click=append_token, args=(token,), key=i, use_container_width=True)
264
+
265
 
266
  rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="πŸ“")
267
  highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="πŸ–οΈ")
268
  generate_page = st.Page(generate_revisions, title="Generate revisions", icon="πŸ”„")
269
+ type_assistant_response_page = st.Page(type_assistant_response, title="Type Assistant Response", icon="πŸ”€")
270
 
271
  # Manually specify the sidebar
272
  page = st.navigation([
273
  st.Page(landing, title="Home", icon="🏠"),
274
  highlight_page,
275
  rewrite_page,
276
+ generate_page,
277
+ type_assistant_response_page
278
  ])
279
  page.run()