kcarnold commited on
Commit
05627f1
Β·
1 Parent(s): bfee54e

Add a quick and dirty "show internals" page

Browse files
Files changed (1) hide show
  1. app.py +103 -1
app.py CHANGED
@@ -296,11 +296,112 @@ def type_assistant_response():
296
  st.session_state['msg_in_progress'] = ""
297
  st.button("Send", on_click=send_message)
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="πŸ“")
301
  highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="πŸ–οΈ")
302
  generate_page = st.Page(generate_revisions, title="Generate revisions", icon="πŸ”„")
303
  type_assistant_response_page = st.Page(type_assistant_response, title="Type Assistant Response", icon="πŸ”€")
 
304
 
305
  # Manually specify the sidebar
306
  page = st.navigation([
@@ -308,6 +409,7 @@ page = st.navigation([
308
  highlight_page,
309
  rewrite_page,
310
  generate_page,
311
- type_assistant_response_page
 
312
  ])
313
  page.run()
 
296
  st.session_state['msg_in_progress'] = ""
297
  st.button("Send", on_click=send_message)
298
 
299
+ def show_internals():
300
+ if 'messages' not in st.session_state or st.button("Start a new conversation"):
301
+ st.session_state['messages'] = [{"role": "user", "content": ""}]
302
+ st.session_state['msg_in_progress'] = ""
303
+ messages = st.session_state.messages
304
+
305
+ def rewind_to(i):
306
+ st.session_state.messages = st.session_state.messages[:i+1]
307
+ st.session_state['msg_in_progress'] = st.session_state.messages[-1]['content']
308
+
309
+ for i, message in enumerate(st.session_state.messages[:-1]):
310
+ with st.chat_message(message["role"]):
311
+ st.markdown(message["content"])
312
+ st.button("Edit", on_click=rewind_to, args=(i,), key=f"rewind_to_{i}")
313
+
314
+ # Display message-in-progress in chat message container
315
+ last_role = messages[-1]["role"]
316
+ with st.chat_message(last_role):
317
+ label = "Your message" if last_role == "user" else "Assistant response"
318
+ msg_in_progress = st.text_area(label, placeholder="Clicking the buttons below will update this field. You can also edit it directly; press Ctrl+Enter to apply changes.", height=300, key="msg_in_progress")
319
+ if msg_in_progress is None:
320
+ msg_in_progress = ""
321
+
322
+ messages[-1]['content'] = msg_in_progress
323
+
324
+ def append_token(word):
325
+ messages[-1]['content'] = st.session_state['msg_in_progress'] = (
326
+ msg_in_progress + word
327
+ )
328
+
329
+ response = requests.post(
330
+ f"{API_SERVER}/logprobs",
331
+ json={
332
+ "messages": messages,
333
+ "n_branch_tokens": 5,
334
+ "n_future_tokens": 2
335
+ }
336
+ )
337
+ if response.status_code != 200:
338
+ st.error("Error fetching response")
339
+ st.write(response.text)
340
+ st.stop()
341
+ response.raise_for_status()
342
+ response = response.json()
343
+
344
+ logprobs = response['logprobs']
345
+ # logprobs is a list of tokens:
346
+ # {
347
+ # "token": "the",
348
+ # "logprobs": [{"the": -0.1, "a": -0.2, ...}]
349
+ # }
350
+ #st.write(logprobs)
351
+ logprobs_component(logprobs)
352
+
353
+ def send_message():
354
+ other_role = "assistant" if last_role == "user" else "user"
355
+ st.session_state['messages'].append({"role": other_role, "content": ""})
356
+ st.session_state['msg_in_progress'] = ""
357
+ st.button("Send", on_click=send_message)
358
+
359
+ def logprobs_component(logprobs):
360
+ import html, json
361
+ html_out = ''
362
+ for i, entry in enumerate(logprobs):
363
+ token = entry['token']
364
+ if token is not None:
365
+ token_to_show = html.escape(show_token(token, escape_markdown=False))
366
+ else:
367
+ token_to_show = html.escape("<empty>")
368
+ html_out += f'<span onclick="showLogprobs({i})" title="Click to show logprobs for this token">{token_to_show}</span>'
369
+ show_logprob_js = '''
370
+ function showLogprobs(i) {
371
+ const logprobs = allLogprobs[i].logprobs;
372
+ const logprobsHtml = Object.entries(logprobs).map(([token, logprob]) => `<li>${token}: ${Math.exp(logprob)}</li>`).join('');
373
+ const container = document.getElementById('logprobs-display');
374
+ container.innerHTML = `<ul>${logprobsHtml}</ul>`;
375
+ }
376
+ '''
377
+ html_out = f"""
378
+ <script>allLogprobs = {json.dumps(logprobs)};
379
+
380
+ {show_logprob_js}</script>
381
+ <style>
382
+ p.logprobs-container {{
383
+ background: white;
384
+ line-height: 2.5;
385
+ color: #2C3E50; /* Dark blue-grey for main text */
386
+ }}
387
+ p.logprobs-container > span {{
388
+ position: relative;
389
+ padding: 2px 1px;
390
+ border-radius: 3px;
391
+ }}
392
+ </style>
393
+ <p class="logprobs-container">{html_out}</p>
394
+ <div id="logprobs-display"></div>
395
+ """
396
+ #return st.html(html_out)
397
+ import streamlit.components.v1 as components
398
+ return components.html(html_out, height=200, scrolling=True)
399
 
400
  rewrite_page = st.Page(rewrite_with_predictions, title="Rewrite with predictions", icon="πŸ“")
401
  highlight_page = st.Page(highlight_edits, title="Highlight locations for possible edits", icon="πŸ–οΈ")
402
  generate_page = st.Page(generate_revisions, title="Generate revisions", icon="πŸ”„")
403
  type_assistant_response_page = st.Page(type_assistant_response, title="Type Assistant Response", icon="πŸ”€")
404
+ show_internals_page = st.Page(show_internals, title="Show Internals", icon="πŸ”§")
405
 
406
  # Manually specify the sidebar
407
  page = st.navigation([
 
409
  highlight_page,
410
  rewrite_page,
411
  generate_page,
412
+ type_assistant_response_page,
413
+ show_internals_page
414
  ])
415
  page.run()