Joshua Sundance Bailey
commited on
Commit
·
f8e9121
1
Parent(s):
923e6fa
refactor
Browse files- langchain-streamlit-demo/app.py +42 -42
langchain-streamlit-demo/app.py
CHANGED
@@ -211,14 +211,14 @@ with sidebar:
|
|
211 |
)
|
212 |
|
213 |
chunk_size = st.slider(
|
214 |
-
label="
|
215 |
help="Size of each chunk of text",
|
216 |
min_value=MIN_CHUNK_SIZE,
|
217 |
max_value=MAX_CHUNK_SIZE,
|
218 |
value=DEFAULT_CHUNK_SIZE,
|
219 |
)
|
220 |
chunk_overlap = st.slider(
|
221 |
-
label="
|
222 |
help="Number of characters to overlap between chunks",
|
223 |
min_value=MIN_CHUNK_OVERLAP,
|
224 |
max_value=MAX_CHUNK_OVERLAP,
|
@@ -399,56 +399,56 @@ if st.session_state.llm:
|
|
399 |
],
|
400 |
)
|
401 |
|
402 |
-
full_response = None
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
llm=st.session_state.llm,
|
|
|
|
|
412 |
memory=MEMORY,
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
|
420 |
-
|
421 |
-
|
422 |
-
return get_rag_qa_gen_chain(
|
423 |
-
st.session_state.retriever,
|
424 |
-
st.session_state.llm,
|
425 |
-
)
|
426 |
-
elif document_chat_chain_type == "Summarization":
|
427 |
-
return get_rag_summarization_chain(
|
428 |
-
prompt,
|
429 |
-
st.session_state.retriever,
|
430 |
-
st.session_state.llm,
|
431 |
-
)
|
432 |
-
else:
|
433 |
-
return RetrievalQA.from_chain_type(
|
434 |
-
llm=st.session_state.llm,
|
435 |
-
chain_type=document_chat_chain_type,
|
436 |
-
retriever=st.session_state.retriever,
|
437 |
-
memory=MEMORY,
|
438 |
-
output_key="output_text",
|
439 |
-
) | (lambda output: output["output_text"])
|
440 |
-
|
441 |
-
st.session_state.doc_chain = get_rag_runnable()
|
442 |
-
|
443 |
-
full_response = st.session_state.doc_chain.invoke(prompt, config)
|
444 |
-
st.markdown(full_response)
|
445 |
|
446 |
except (openai.error.AuthenticationError, anthropic.AuthenticationError):
|
447 |
st.error(
|
448 |
f"Please enter a valid {st.session_state.provider} API key.",
|
449 |
icon="❌",
|
450 |
)
|
|
|
451 |
if full_response is not None:
|
|
|
|
|
452 |
# --- Tracing ---
|
453 |
if st.session_state.client:
|
454 |
st.session_state.run = RUN_COLLECTOR.traced_runs[0]
|
|
|
211 |
)
|
212 |
|
213 |
chunk_size = st.slider(
|
214 |
+
label="Number of Tokens per Chunk",
|
215 |
help="Size of each chunk of text",
|
216 |
min_value=MIN_CHUNK_SIZE,
|
217 |
max_value=MAX_CHUNK_SIZE,
|
218 |
value=DEFAULT_CHUNK_SIZE,
|
219 |
)
|
220 |
chunk_overlap = st.slider(
|
221 |
+
label="Chunk Overlap",
|
222 |
help="Number of characters to overlap between chunks",
|
223 |
min_value=MIN_CHUNK_OVERLAP,
|
224 |
max_value=MAX_CHUNK_OVERLAP,
|
|
|
399 |
],
|
400 |
)
|
401 |
|
402 |
+
full_response: Union[str, None] = None
|
403 |
+
|
404 |
+
message_placeholder = st.empty()
|
405 |
+
stream_handler = StreamHandler(message_placeholder)
|
406 |
+
callbacks.append(stream_handler)
|
407 |
+
|
408 |
+
def get_rag_runnable():
|
409 |
+
if document_chat_chain_type == "Q&A Generation":
|
410 |
+
return get_rag_qa_gen_chain(
|
411 |
+
st.session_state.retriever,
|
412 |
+
st.session_state.llm,
|
413 |
+
)
|
414 |
+
elif document_chat_chain_type == "Summarization":
|
415 |
+
return get_rag_summarization_chain(
|
416 |
+
prompt,
|
417 |
+
st.session_state.retriever,
|
418 |
+
st.session_state.llm,
|
419 |
+
)
|
420 |
+
else:
|
421 |
+
return RetrievalQA.from_chain_type(
|
422 |
llm=st.session_state.llm,
|
423 |
+
chain_type=document_chat_chain_type,
|
424 |
+
retriever=st.session_state.retriever,
|
425 |
memory=MEMORY,
|
426 |
+
output_key="output_text",
|
427 |
+
) | (lambda output: output["output_text"])
|
428 |
+
|
429 |
+
st.session_state.chain = (
|
430 |
+
get_rag_runnable()
|
431 |
+
if use_document_chat
|
432 |
+
else LLMChain(
|
433 |
+
prompt=chat_prompt,
|
434 |
+
llm=st.session_state.llm,
|
435 |
+
memory=MEMORY,
|
436 |
+
)
|
437 |
+
| (lambda output: output["text"])
|
438 |
+
)
|
439 |
|
440 |
+
try:
|
441 |
+
full_response = st.session_state.chain.invoke(prompt, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
442 |
|
443 |
except (openai.error.AuthenticationError, anthropic.AuthenticationError):
|
444 |
st.error(
|
445 |
f"Please enter a valid {st.session_state.provider} API key.",
|
446 |
icon="❌",
|
447 |
)
|
448 |
+
|
449 |
if full_response is not None:
|
450 |
+
message_placeholder.markdown(full_response)
|
451 |
+
|
452 |
# --- Tracing ---
|
453 |
if st.session_state.client:
|
454 |
st.session_state.run = RUN_COLLECTOR.traced_runs[0]
|