Joshua Sundance Bailey commited on
Commit
5825ff9
·
1 Parent(s): 547d578

callbacks (still not working 100%

Browse files
.pre-commit-config.yaml CHANGED
@@ -40,24 +40,20 @@ repos:
40
  - id: trailing-whitespace
41
  - id: mixed-line-ending
42
  - id: requirements-txt-fixer
43
- - repo: https://github.com/pre-commit/mirrors-mypy
44
- rev: v1.5.1
45
  hooks:
46
- - id: mypy
47
- additional_dependencies:
48
- - types-requests
49
  - repo: https://github.com/asottile/add-trailing-comma
50
  rev: v3.1.0
51
  hooks:
52
  - id: add-trailing-comma
53
- #- repo: https://github.com/dannysepler/rm_unneeded_f_str
54
- # rev: v0.2.0
55
- # hooks:
56
- # - id: rm-unneeded-f-str
57
- - repo: https://github.com/psf/black
58
- rev: 23.9.1
59
  hooks:
60
- - id: black
 
 
61
  - repo: https://github.com/PyCQA/bandit
62
  rev: 1.7.5
63
  hooks:
 
40
  - id: trailing-whitespace
41
  - id: mixed-line-ending
42
  - id: requirements-txt-fixer
43
+ - repo: https://github.com/psf/black
44
+ rev: 23.9.1
45
  hooks:
46
+ - id: black
 
 
47
  - repo: https://github.com/asottile/add-trailing-comma
48
  rev: v3.1.0
49
  hooks:
50
  - id: add-trailing-comma
51
+ - repo: https://github.com/pre-commit/mirrors-mypy
52
+ rev: v1.5.1
 
 
 
 
53
  hooks:
54
+ - id: mypy
55
+ additional_dependencies:
56
+ - types-requests
57
  - repo: https://github.com/PyCQA/bandit
58
  rev: 1.7.5
59
  hooks:
langchain-streamlit-demo/app.py CHANGED
@@ -6,6 +6,7 @@ import langsmith.utils
6
  import openai
7
  import streamlit as st
8
  from langchain.callbacks import StreamlitCallbackHandler
 
9
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
10
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
11
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -20,6 +21,7 @@ from defaults import default_values
20
  from llm_resources import (
21
  get_agent,
22
  get_llm,
 
23
  get_texts_and_multiretriever,
24
  )
25
  from research_assistant.chain import chain as research_assistant_chain
@@ -379,15 +381,6 @@ st.session_state.llm = get_llm(
379
  },
380
  )
381
 
382
- research_assistant_tool = Tool.from_function(
383
- func=lambda s: research_assistant_chain.invoke({"question": s}),
384
- name="web-research-assistant",
385
- description="this assistant returns a report based on web research",
386
- )
387
-
388
- TOOLS = [research_assistant_tool]
389
- st.session_state.agent = get_agent(TOOLS, STMEMORY, st.session_state.llm)
390
-
391
  # --- Chat History ---
392
  for msg in STMEMORY.messages:
393
  st.chat_message(
@@ -424,12 +417,16 @@ if st.session_state.llm:
424
  if st.session_state.ls_tracer:
425
  callbacks.append(st.session_state.ls_tracer)
426
 
427
- config: Dict[str, Any] = dict(
428
- callbacks=callbacks,
429
- tags=["Streamlit Chat"],
430
- )
431
- if st.session_state.provider == "Anthropic":
432
- config["max_concurrency"] = 5
 
 
 
 
433
 
434
  use_document_chat = all(
435
  [
@@ -439,32 +436,66 @@ if st.session_state.llm:
439
  )
440
 
441
  full_response: Union[str, None] = None
442
-
443
  # stream_handler = StreamHandler(message_placeholder)
444
  # callbacks.append(stream_handler)
 
445
 
446
- st_callback = StreamlitCallbackHandler(st.container())
447
- callbacks.append(st_callback)
 
 
 
 
 
 
 
 
 
448
 
449
- message_placeholder = st.empty()
450
- # TODO use agent if openai or azure openai
451
- # otherwise use runnable
452
- # for agent + runnable, add to tools
453
-
454
- # st.session_state.chain = get_runnable(
455
- # use_document_chat,
456
- # document_chat_chain_type,
457
- # st.session_state.llm,
458
- # st.session_state.retriever,
459
- # MEMORY,
460
- # chat_prompt,
461
- # prompt,
462
- # STMEMORY,
463
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
  # --- LLM call ---
466
  try:
467
- full_response = st.session_state.agent.invoke(prompt, config)
 
 
 
468
 
469
  except (openai.AuthenticationError, anthropic.AuthenticationError):
470
  st.error(
 
6
  import openai
7
  import streamlit as st
8
  from langchain.callbacks import StreamlitCallbackHandler
9
+ from langchain.callbacks.base import BaseCallbackHandler
10
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
11
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
12
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 
21
  from llm_resources import (
22
  get_agent,
23
  get_llm,
24
+ get_runnable,
25
  get_texts_and_multiretriever,
26
  )
27
  from research_assistant.chain import chain as research_assistant_chain
 
381
  },
382
  )
383
 
 
 
 
 
 
 
 
 
 
384
  # --- Chat History ---
385
  for msg in STMEMORY.messages:
386
  st.chat_message(
 
417
  if st.session_state.ls_tracer:
418
  callbacks.append(st.session_state.ls_tracer)
419
 
420
+ def get_config(callbacks: list[BaseCallbackHandler]) -> dict[str, Any]:
421
+ config: Dict[str, Any] = dict(
422
+ callbacks=callbacks,
423
+ tags=["Streamlit Chat"],
424
+ verbose=True,
425
+ return_intermediate_steps=True,
426
+ )
427
+ if st.session_state.provider == "Anthropic":
428
+ config["max_concurrency"] = 5
429
+ return config
430
 
431
  use_document_chat = all(
432
  [
 
436
  )
437
 
438
  full_response: Union[str, None] = None
 
439
  # stream_handler = StreamHandler(message_placeholder)
440
  # callbacks.append(stream_handler)
441
+ message_placeholder = st.empty()
442
 
443
+ if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
444
+ st_callback = StreamlitCallbackHandler(st.container())
445
+ callbacks.append(st_callback)
446
+ research_assistant_tool = Tool.from_function(
447
+ func=lambda s: research_assistant_chain.invoke(
448
+ {"question": s},
449
+ config=get_config(callbacks),
450
+ ),
451
+ name="web-research-assistant",
452
+ description="this assistant returns a report based on web research",
453
+ )
454
 
455
+ TOOLS = [research_assistant_tool]
456
+ if use_document_chat:
457
+ st.session_state.doc_chain = get_runnable(
458
+ use_document_chat,
459
+ document_chat_chain_type,
460
+ st.session_state.llm,
461
+ st.session_state.retriever,
462
+ MEMORY,
463
+ chat_prompt,
464
+ prompt,
465
+ )
466
+ doc_chain_tool = Tool.from_function(
467
+ func=lambda s: st.session_state.doc_chain.invoke(
468
+ s,
469
+ config=get_config(callbacks),
470
+ ),
471
+ name="user-document-chat",
472
+ description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
473
+ )
474
+ TOOLS = [doc_chain_tool, research_assistant_tool]
475
+
476
+ st.session_state.chain = get_agent(
477
+ TOOLS,
478
+ STMEMORY,
479
+ st.session_state.llm,
480
+ callbacks,
481
+ )
482
+ else:
483
+ st.session_state.chain = get_runnable(
484
+ use_document_chat,
485
+ document_chat_chain_type,
486
+ st.session_state.llm,
487
+ st.session_state.retriever,
488
+ MEMORY,
489
+ chat_prompt,
490
+ prompt,
491
+ )
492
 
493
  # --- LLM call ---
494
  try:
495
+ full_response = st.session_state.chain.invoke(
496
+ prompt,
497
+ config=get_config(callbacks),
498
+ )
499
 
500
  except (openai.AuthenticationError, anthropic.AuthenticationError):
501
  st.error(
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -3,13 +3,13 @@ from tempfile import NamedTemporaryFile
3
  from typing import Tuple, List, Optional, Dict
4
 
5
  from langchain.agents import AgentExecutor
6
- from langchain.agents.agent_toolkits import create_retriever_tool
7
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
8
  AgentTokenBufferMemory,
9
  )
10
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
11
  from langchain.callbacks.base import BaseCallbackHandler
12
  from langchain.chains import LLMChain
 
13
  from langchain.chat_models import (
14
  AzureChatOpenAI,
15
  ChatOpenAI,
@@ -18,29 +18,30 @@ from langchain.chat_models import (
18
  )
19
  from langchain.document_loaders import PyPDFLoader
20
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
 
21
  from langchain.prompts import MessagesPlaceholder
22
  from langchain.retrievers import EnsembleRetriever
23
  from langchain.retrievers.multi_query import MultiQueryRetriever
24
  from langchain.retrievers.multi_vector import MultiVectorRetriever
25
  from langchain.schema import Document, BaseRetriever
 
26
  from langchain.schema.runnable import RunnablePassthrough
27
  from langchain.storage import InMemoryStore
28
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
29
  from langchain.vectorstores import FAISS
30
  from langchain_core.messages import SystemMessage
31
 
32
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
33
  from qagen import get_rag_qa_gen_chain
34
  from summarize import get_rag_summarization_chain
35
- from langchain.tools.base import BaseTool
36
- from langchain.schema.chat_history import BaseChatMessageHistory
37
- from langchain.llms.base import BaseLLM
38
 
39
 
40
  def get_agent(
41
  tools: list[BaseTool],
42
  chat_history: BaseChatMessageHistory,
43
  llm: BaseLLM,
 
44
  ):
45
  memory_key = "agent_history"
46
  system_message = SystemMessage(
@@ -68,6 +69,7 @@ def get_agent(
68
  memory=agent_memory,
69
  verbose=True,
70
  return_intermediate_steps=True,
 
71
  )
72
  return (
73
  {"input": RunnablePassthrough()}
@@ -84,7 +86,6 @@ def get_runnable(
84
  memory,
85
  chat_prompt,
86
  summarization_prompt,
87
- chat_history,
88
  ):
89
  if not use_document_chat:
90
  return LLMChain(
@@ -105,14 +106,12 @@ def get_runnable(
105
  llm,
106
  )
107
  else:
108
- tool = create_retriever_tool(
109
- retriever,
110
- "search_user_document",
111
- "Retrieves custom context provided by the user for this conversation. Use this if you cannot answer immediately and confidently.",
112
- )
113
- tools = [tool]
114
-
115
- return get_agent(tools, chat_history, llm)
116
 
117
 
118
  def get_llm(
 
3
  from typing import Tuple, List, Optional, Dict
4
 
5
  from langchain.agents import AgentExecutor
 
6
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
7
  AgentTokenBufferMemory,
8
  )
9
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
10
  from langchain.callbacks.base import BaseCallbackHandler
11
  from langchain.chains import LLMChain
12
+ from langchain.chains import RetrievalQA
13
  from langchain.chat_models import (
14
  AzureChatOpenAI,
15
  ChatOpenAI,
 
18
  )
19
  from langchain.document_loaders import PyPDFLoader
20
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
21
+ from langchain.llms.base import BaseLLM
22
  from langchain.prompts import MessagesPlaceholder
23
  from langchain.retrievers import EnsembleRetriever
24
  from langchain.retrievers.multi_query import MultiQueryRetriever
25
  from langchain.retrievers.multi_vector import MultiVectorRetriever
26
  from langchain.schema import Document, BaseRetriever
27
+ from langchain.schema.chat_history import BaseChatMessageHistory
28
  from langchain.schema.runnable import RunnablePassthrough
29
  from langchain.storage import InMemoryStore
30
  from langchain.text_splitter import RecursiveCharacterTextSplitter
31
+ from langchain.tools.base import BaseTool
32
  from langchain.vectorstores import FAISS
33
  from langchain_core.messages import SystemMessage
34
 
35
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
36
  from qagen import get_rag_qa_gen_chain
37
  from summarize import get_rag_summarization_chain
 
 
 
38
 
39
 
40
  def get_agent(
41
  tools: list[BaseTool],
42
  chat_history: BaseChatMessageHistory,
43
  llm: BaseLLM,
44
+ callbacks,
45
  ):
46
  memory_key = "agent_history"
47
  system_message = SystemMessage(
 
69
  memory=agent_memory,
70
  verbose=True,
71
  return_intermediate_steps=True,
72
+ callbacks=callbacks,
73
  )
74
  return (
75
  {"input": RunnablePassthrough()}
 
86
  memory,
87
  chat_prompt,
88
  summarization_prompt,
 
89
  ):
90
  if not use_document_chat:
91
  return LLMChain(
 
106
  llm,
107
  )
108
  else:
109
+ return RetrievalQA.from_chain_type(
110
+ llm=llm,
111
+ chain_type=document_chat_chain_type,
112
+ retriever=retriever,
113
+ output_key="output_text",
114
+ ) | (lambda output: output["output_text"])
 
 
115
 
116
 
117
  def get_llm(