Joshua Sundance Bailey committed
Commit f8e9121 · 1 Parent(s): 923e6fa
Files changed (1)
  1. langchain-streamlit-demo/app.py +42 -42
langchain-streamlit-demo/app.py CHANGED
@@ -211,14 +211,14 @@ with sidebar:
     )
 
     chunk_size = st.slider(
-        label="chunk_size",
+        label="Number of Tokens per Chunk",
         help="Size of each chunk of text",
         min_value=MIN_CHUNK_SIZE,
         max_value=MAX_CHUNK_SIZE,
         value=DEFAULT_CHUNK_SIZE,
     )
     chunk_overlap = st.slider(
-        label="chunk_overlap",
+        label="Chunk Overlap",
         help="Number of characters to overlap between chunks",
         min_value=MIN_CHUNK_OVERLAP,
         max_value=MAX_CHUNK_OVERLAP,
@@ -399,56 +399,56 @@ if st.session_state.llm:
                 ],
             )
 
-            full_response = None
-
-            try:
-                if not use_document_chat:
-                    message_placeholder = st.empty()
-                    stream_handler = StreamHandler(message_placeholder)
-                    callbacks.append(stream_handler)
-                    st.session_state.chain = LLMChain(
-                        prompt=chat_prompt,
+            full_response: Union[str, None] = None
+
+            message_placeholder = st.empty()
+            stream_handler = StreamHandler(message_placeholder)
+            callbacks.append(stream_handler)
+
+            def get_rag_runnable():
+                if document_chat_chain_type == "Q&A Generation":
+                    return get_rag_qa_gen_chain(
+                        st.session_state.retriever,
+                        st.session_state.llm,
+                    )
+                elif document_chat_chain_type == "Summarization":
+                    return get_rag_summarization_chain(
+                        prompt,
+                        st.session_state.retriever,
+                        st.session_state.llm,
+                    )
+                else:
+                    return RetrievalQA.from_chain_type(
                         llm=st.session_state.llm,
+                        chain_type=document_chat_chain_type,
+                        retriever=st.session_state.retriever,
                         memory=MEMORY,
-                    ) | (lambda output: output["text"])
-                    config = {"callbacks": callbacks, "tags": ["Streamlit Chat"]}
-                    full_response = st.session_state.chain.invoke(prompt, config)
-                    message_placeholder.markdown(full_response)
-
-                else:
+                        output_key="output_text",
+                    ) | (lambda output: output["output_text"])
+
+            st.session_state.chain = (
+                get_rag_runnable()
+                if use_document_chat
+                else LLMChain(
+                    prompt=chat_prompt,
+                    llm=st.session_state.llm,
+                    memory=MEMORY,
+                )
+                | (lambda output: output["text"])
+            )
 
-                    def get_rag_runnable():
-                        if document_chat_chain_type == "Q&A Generation":
-                            return get_rag_qa_gen_chain(
-                                st.session_state.retriever,
-                                st.session_state.llm,
-                            )
-                        elif document_chat_chain_type == "Summarization":
-                            return get_rag_summarization_chain(
-                                prompt,
-                                st.session_state.retriever,
-                                st.session_state.llm,
-                            )
-                        else:
-                            return RetrievalQA.from_chain_type(
-                                llm=st.session_state.llm,
-                                chain_type=document_chat_chain_type,
-                                retriever=st.session_state.retriever,
-                                memory=MEMORY,
-                                output_key="output_text",
-                            ) | (lambda output: output["output_text"])
-
-                    st.session_state.doc_chain = get_rag_runnable()
-
-                    full_response = st.session_state.doc_chain.invoke(prompt, config)
-                    st.markdown(full_response)
+            try:
+                full_response = st.session_state.chain.invoke(prompt, config)
 
             except (openai.error.AuthenticationError, anthropic.AuthenticationError):
                 st.error(
                     f"Please enter a valid {st.session_state.provider} API key.",
                     icon="❌",
                 )
+
             if full_response is not None:
+                message_placeholder.markdown(full_response)
+
                 # --- Tracing ---
                 if st.session_state.client:
                     st.session_state.run = RUN_COLLECTOR.traced_runs[0]
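
Note on the refactor: the new code collapses the two chat paths into a single st.session_state.chain by piping each chain into a plain lambda that unwraps the chain's output dict. Below is a minimal, self-contained sketch of that pipe pattern, assuming langchain_core is installed; fake_chain is a hypothetical stand-in for LLMChain / RetrievalQA, which return dicts keyed by their output_key.

    # Sketch only: fake_chain stands in for LLMChain / RetrievalQA.
    from langchain_core.runnables import RunnableLambda

    fake_chain = RunnableLambda(lambda prompt: {"text": f"echo: {prompt}"})

    # `|` coerces the bare lambda into a RunnableLambda and builds a
    # RunnableSequence, so invoke() returns the unwrapped string.
    chain = fake_chain | (lambda output: output["text"])

    print(chain.invoke("hello"))  # -> echo: hello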
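
The placeholder and handler setup also moves ahead of both branches: StreamHandler streams tokens into message_placeholder during the call, and the final markdown render happens only once full_response is not None. StreamHandler is defined elsewhere in this repo; the sketch below follows the common Streamlit streaming recipe and only approximates it (hypothetical, assuming the standard LangChain callback interface).

    # Hedged sketch of a token-streaming callback for Streamlit; this
    # approximates, but is not, the repo's own StreamHandler class.
    from langchain_core.callbacks import BaseCallbackHandler

    class StreamHandler(BaseCallbackHandler):
        def __init__(self, container, initial_text: str = ""):
            self.container = container  # e.g. an st.empty() placeholder
            self.text = initial_text

        def on_llm_new_token(self, token: str, **kwargs) -> None:
            # Append each streamed token and re-render the accumulated text.
            self.text += token
            self.container.markdown(self.text)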