Joshua Sundance Bailey
commited on
Commit
·
5825ff9
1
Parent(s):
547d578
callbacks (still not working 100%
Browse files- .pre-commit-config.yaml +8 -12
- langchain-streamlit-demo/app.py +65 -34
- langchain-streamlit-demo/llm_resources.py +12 -13
.pre-commit-config.yaml
CHANGED
@@ -40,24 +40,20 @@ repos:
|
|
40 |
- id: trailing-whitespace
|
41 |
- id: mixed-line-ending
|
42 |
- id: requirements-txt-fixer
|
43 |
-
- repo: https://github.com/
|
44 |
-
rev:
|
45 |
hooks:
|
46 |
-
- id:
|
47 |
-
additional_dependencies:
|
48 |
-
- types-requests
|
49 |
- repo: https://github.com/asottile/add-trailing-comma
|
50 |
rev: v3.1.0
|
51 |
hooks:
|
52 |
- id: add-trailing-comma
|
53 |
-
|
54 |
-
|
55 |
-
# hooks:
|
56 |
-
# - id: rm-unneeded-f-str
|
57 |
-
- repo: https://github.com/psf/black
|
58 |
-
rev: 23.9.1
|
59 |
hooks:
|
60 |
-
- id:
|
|
|
|
|
61 |
- repo: https://github.com/PyCQA/bandit
|
62 |
rev: 1.7.5
|
63 |
hooks:
|
|
|
40 |
- id: trailing-whitespace
|
41 |
- id: mixed-line-ending
|
42 |
- id: requirements-txt-fixer
|
43 |
+
- repo: https://github.com/psf/black
|
44 |
+
rev: 23.9.1
|
45 |
hooks:
|
46 |
+
- id: black
|
|
|
|
|
47 |
- repo: https://github.com/asottile/add-trailing-comma
|
48 |
rev: v3.1.0
|
49 |
hooks:
|
50 |
- id: add-trailing-comma
|
51 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
52 |
+
rev: v1.5.1
|
|
|
|
|
|
|
|
|
53 |
hooks:
|
54 |
+
- id: mypy
|
55 |
+
additional_dependencies:
|
56 |
+
- types-requests
|
57 |
- repo: https://github.com/PyCQA/bandit
|
58 |
rev: 1.7.5
|
59 |
hooks:
|
langchain-streamlit-demo/app.py
CHANGED
@@ -6,6 +6,7 @@ import langsmith.utils
|
|
6 |
import openai
|
7 |
import streamlit as st
|
8 |
from langchain.callbacks import StreamlitCallbackHandler
|
|
|
9 |
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
10 |
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
|
11 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
@@ -20,6 +21,7 @@ from defaults import default_values
|
|
20 |
from llm_resources import (
|
21 |
get_agent,
|
22 |
get_llm,
|
|
|
23 |
get_texts_and_multiretriever,
|
24 |
)
|
25 |
from research_assistant.chain import chain as research_assistant_chain
|
@@ -379,15 +381,6 @@ st.session_state.llm = get_llm(
|
|
379 |
},
|
380 |
)
|
381 |
|
382 |
-
research_assistant_tool = Tool.from_function(
|
383 |
-
func=lambda s: research_assistant_chain.invoke({"question": s}),
|
384 |
-
name="web-research-assistant",
|
385 |
-
description="this assistant returns a report based on web research",
|
386 |
-
)
|
387 |
-
|
388 |
-
TOOLS = [research_assistant_tool]
|
389 |
-
st.session_state.agent = get_agent(TOOLS, STMEMORY, st.session_state.llm)
|
390 |
-
|
391 |
# --- Chat History ---
|
392 |
for msg in STMEMORY.messages:
|
393 |
st.chat_message(
|
@@ -424,12 +417,16 @@ if st.session_state.llm:
|
|
424 |
if st.session_state.ls_tracer:
|
425 |
callbacks.append(st.session_state.ls_tracer)
|
426 |
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
|
|
|
|
|
|
|
|
433 |
|
434 |
use_document_chat = all(
|
435 |
[
|
@@ -439,32 +436,66 @@ if st.session_state.llm:
|
|
439 |
)
|
440 |
|
441 |
full_response: Union[str, None] = None
|
442 |
-
|
443 |
# stream_handler = StreamHandler(message_placeholder)
|
444 |
# callbacks.append(stream_handler)
|
|
|
445 |
|
446 |
-
|
447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
464 |
|
465 |
# --- LLM call ---
|
466 |
try:
|
467 |
-
full_response = st.session_state.
|
|
|
|
|
|
|
468 |
|
469 |
except (openai.AuthenticationError, anthropic.AuthenticationError):
|
470 |
st.error(
|
|
|
6 |
import openai
|
7 |
import streamlit as st
|
8 |
from langchain.callbacks import StreamlitCallbackHandler
|
9 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
10 |
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
11 |
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
|
12 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
|
|
21 |
from llm_resources import (
|
22 |
get_agent,
|
23 |
get_llm,
|
24 |
+
get_runnable,
|
25 |
get_texts_and_multiretriever,
|
26 |
)
|
27 |
from research_assistant.chain import chain as research_assistant_chain
|
|
|
381 |
},
|
382 |
)
|
383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
# --- Chat History ---
|
385 |
for msg in STMEMORY.messages:
|
386 |
st.chat_message(
|
|
|
417 |
if st.session_state.ls_tracer:
|
418 |
callbacks.append(st.session_state.ls_tracer)
|
419 |
|
420 |
+
def get_config(callbacks: list[BaseCallbackHandler]) -> dict[str, Any]:
|
421 |
+
config: Dict[str, Any] = dict(
|
422 |
+
callbacks=callbacks,
|
423 |
+
tags=["Streamlit Chat"],
|
424 |
+
verbose=True,
|
425 |
+
return_intermediate_steps=True,
|
426 |
+
)
|
427 |
+
if st.session_state.provider == "Anthropic":
|
428 |
+
config["max_concurrency"] = 5
|
429 |
+
return config
|
430 |
|
431 |
use_document_chat = all(
|
432 |
[
|
|
|
436 |
)
|
437 |
|
438 |
full_response: Union[str, None] = None
|
|
|
439 |
# stream_handler = StreamHandler(message_placeholder)
|
440 |
# callbacks.append(stream_handler)
|
441 |
+
message_placeholder = st.empty()
|
442 |
|
443 |
+
if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
|
444 |
+
st_callback = StreamlitCallbackHandler(st.container())
|
445 |
+
callbacks.append(st_callback)
|
446 |
+
research_assistant_tool = Tool.from_function(
|
447 |
+
func=lambda s: research_assistant_chain.invoke(
|
448 |
+
{"question": s},
|
449 |
+
config=get_config(callbacks),
|
450 |
+
),
|
451 |
+
name="web-research-assistant",
|
452 |
+
description="this assistant returns a report based on web research",
|
453 |
+
)
|
454 |
|
455 |
+
TOOLS = [research_assistant_tool]
|
456 |
+
if use_document_chat:
|
457 |
+
st.session_state.doc_chain = get_runnable(
|
458 |
+
use_document_chat,
|
459 |
+
document_chat_chain_type,
|
460 |
+
st.session_state.llm,
|
461 |
+
st.session_state.retriever,
|
462 |
+
MEMORY,
|
463 |
+
chat_prompt,
|
464 |
+
prompt,
|
465 |
+
)
|
466 |
+
doc_chain_tool = Tool.from_function(
|
467 |
+
func=lambda s: st.session_state.doc_chain.invoke(
|
468 |
+
s,
|
469 |
+
config=get_config(callbacks),
|
470 |
+
),
|
471 |
+
name="user-document-chat",
|
472 |
+
description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
|
473 |
+
)
|
474 |
+
TOOLS = [doc_chain_tool, research_assistant_tool]
|
475 |
+
|
476 |
+
st.session_state.chain = get_agent(
|
477 |
+
TOOLS,
|
478 |
+
STMEMORY,
|
479 |
+
st.session_state.llm,
|
480 |
+
callbacks,
|
481 |
+
)
|
482 |
+
else:
|
483 |
+
st.session_state.chain = get_runnable(
|
484 |
+
use_document_chat,
|
485 |
+
document_chat_chain_type,
|
486 |
+
st.session_state.llm,
|
487 |
+
st.session_state.retriever,
|
488 |
+
MEMORY,
|
489 |
+
chat_prompt,
|
490 |
+
prompt,
|
491 |
+
)
|
492 |
|
493 |
# --- LLM call ---
|
494 |
try:
|
495 |
+
full_response = st.session_state.chain.invoke(
|
496 |
+
prompt,
|
497 |
+
config=get_config(callbacks),
|
498 |
+
)
|
499 |
|
500 |
except (openai.AuthenticationError, anthropic.AuthenticationError):
|
501 |
st.error(
|
langchain-streamlit-demo/llm_resources.py
CHANGED
@@ -3,13 +3,13 @@ from tempfile import NamedTemporaryFile
|
|
3 |
from typing import Tuple, List, Optional, Dict
|
4 |
|
5 |
from langchain.agents import AgentExecutor
|
6 |
-
from langchain.agents.agent_toolkits import create_retriever_tool
|
7 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
|
8 |
AgentTokenBufferMemory,
|
9 |
)
|
10 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
11 |
from langchain.callbacks.base import BaseCallbackHandler
|
12 |
from langchain.chains import LLMChain
|
|
|
13 |
from langchain.chat_models import (
|
14 |
AzureChatOpenAI,
|
15 |
ChatOpenAI,
|
@@ -18,29 +18,30 @@ from langchain.chat_models import (
|
|
18 |
)
|
19 |
from langchain.document_loaders import PyPDFLoader
|
20 |
from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
|
|
21 |
from langchain.prompts import MessagesPlaceholder
|
22 |
from langchain.retrievers import EnsembleRetriever
|
23 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
24 |
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
25 |
from langchain.schema import Document, BaseRetriever
|
|
|
26 |
from langchain.schema.runnable import RunnablePassthrough
|
27 |
from langchain.storage import InMemoryStore
|
28 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
29 |
from langchain.vectorstores import FAISS
|
30 |
from langchain_core.messages import SystemMessage
|
31 |
|
32 |
from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
|
33 |
from qagen import get_rag_qa_gen_chain
|
34 |
from summarize import get_rag_summarization_chain
|
35 |
-
from langchain.tools.base import BaseTool
|
36 |
-
from langchain.schema.chat_history import BaseChatMessageHistory
|
37 |
-
from langchain.llms.base import BaseLLM
|
38 |
|
39 |
|
40 |
def get_agent(
|
41 |
tools: list[BaseTool],
|
42 |
chat_history: BaseChatMessageHistory,
|
43 |
llm: BaseLLM,
|
|
|
44 |
):
|
45 |
memory_key = "agent_history"
|
46 |
system_message = SystemMessage(
|
@@ -68,6 +69,7 @@ def get_agent(
|
|
68 |
memory=agent_memory,
|
69 |
verbose=True,
|
70 |
return_intermediate_steps=True,
|
|
|
71 |
)
|
72 |
return (
|
73 |
{"input": RunnablePassthrough()}
|
@@ -84,7 +86,6 @@ def get_runnable(
|
|
84 |
memory,
|
85 |
chat_prompt,
|
86 |
summarization_prompt,
|
87 |
-
chat_history,
|
88 |
):
|
89 |
if not use_document_chat:
|
90 |
return LLMChain(
|
@@ -105,14 +106,12 @@ def get_runnable(
|
|
105 |
llm,
|
106 |
)
|
107 |
else:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
return get_agent(tools, chat_history, llm)
|
116 |
|
117 |
|
118 |
def get_llm(
|
|
|
3 |
from typing import Tuple, List, Optional, Dict
|
4 |
|
5 |
from langchain.agents import AgentExecutor
|
|
|
6 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
|
7 |
AgentTokenBufferMemory,
|
8 |
)
|
9 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
10 |
from langchain.callbacks.base import BaseCallbackHandler
|
11 |
from langchain.chains import LLMChain
|
12 |
+
from langchain.chains import RetrievalQA
|
13 |
from langchain.chat_models import (
|
14 |
AzureChatOpenAI,
|
15 |
ChatOpenAI,
|
|
|
18 |
)
|
19 |
from langchain.document_loaders import PyPDFLoader
|
20 |
from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
21 |
+
from langchain.llms.base import BaseLLM
|
22 |
from langchain.prompts import MessagesPlaceholder
|
23 |
from langchain.retrievers import EnsembleRetriever
|
24 |
from langchain.retrievers.multi_query import MultiQueryRetriever
|
25 |
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
26 |
from langchain.schema import Document, BaseRetriever
|
27 |
+
from langchain.schema.chat_history import BaseChatMessageHistory
|
28 |
from langchain.schema.runnable import RunnablePassthrough
|
29 |
from langchain.storage import InMemoryStore
|
30 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
31 |
+
from langchain.tools.base import BaseTool
|
32 |
from langchain.vectorstores import FAISS
|
33 |
from langchain_core.messages import SystemMessage
|
34 |
|
35 |
from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
|
36 |
from qagen import get_rag_qa_gen_chain
|
37 |
from summarize import get_rag_summarization_chain
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
def get_agent(
|
41 |
tools: list[BaseTool],
|
42 |
chat_history: BaseChatMessageHistory,
|
43 |
llm: BaseLLM,
|
44 |
+
callbacks,
|
45 |
):
|
46 |
memory_key = "agent_history"
|
47 |
system_message = SystemMessage(
|
|
|
69 |
memory=agent_memory,
|
70 |
verbose=True,
|
71 |
return_intermediate_steps=True,
|
72 |
+
callbacks=callbacks,
|
73 |
)
|
74 |
return (
|
75 |
{"input": RunnablePassthrough()}
|
|
|
86 |
memory,
|
87 |
chat_prompt,
|
88 |
summarization_prompt,
|
|
|
89 |
):
|
90 |
if not use_document_chat:
|
91 |
return LLMChain(
|
|
|
106 |
llm,
|
107 |
)
|
108 |
else:
|
109 |
+
return RetrievalQA.from_chain_type(
|
110 |
+
llm=llm,
|
111 |
+
chain_type=document_chat_chain_type,
|
112 |
+
retriever=retriever,
|
113 |
+
output_key="output_text",
|
114 |
+
) | (lambda output: output["output_text"])
|
|
|
|
|
115 |
|
116 |
|
117 |
def get_llm(
|