Spaces:
Running
Running
File size: 18,164 Bytes
f85c983 7a6a62b f85c983 7a6a62b f85c983 5fe3eac f85c983 7a6a62b f85c983 7a6a62b f85c983 7a6a62b f85c983 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 |
import json
import operator
import uuid
from operator import itemgetter
from typing import Annotated, Sequence, TypedDict
import chainlit as cl
from dotenv import load_dotenv
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
from langchain.storage import InMemoryStore
# from langchain_core.output_parsers import StrOutputParser
from langchain.tools import tool
from langchain_community.document_loaders import ArxivLoader
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun
# from langgraph.graph.message import add_messages
from langchain_core.messages import (
BaseMessage,
FunctionMessage,
SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
# from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain_core.messages import HumanMessage
# GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
# ---- ENV VARIABLES ---- #
"""
This function will load our environment file (.env) if it is present.
NOTE: Make sure that .env is in your .gitignore file - it is by default, but please ensure it remains there.
"""
load_dotenv()
"""
We will load our environment variables here.
"""
# ---- GLOBAL DECLARATIONS ---- #
# -- RETRIEVAL -- #
"""
1. Load Documents from Text File
2. Split Documents into Chunks
3. Load HuggingFace Embeddings (remember to use the URL we set above)
4. Index Files if they do not exist, otherwise load the vectorstore
"""
### 1. CREATE TEXT LOADER AND LOAD DOCUMENTS
### NOTE: PAY ATTENTION TO THE PATH THEY ARE IN.
docs = ArxivLoader(
query='"mental health counseling" AND (data OR analytics OR "machine learning")',
load_max_docs=10,
sort_by="submittedDate",
sort_order="descending",
).load()
### 2. CREATE QDRANT CLIENT VECTORE STORE
client = QdrantClient(":memory:")
client.create_collection(
collection_name="split_parents",
vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
)
vectorstore = Qdrant(
client,
collection_name="split_parents",
embeddings=OpenAIEmbeddings(model="text-embedding-3-small"),
)
store = InMemoryStore()
### 3. CREATE PARENT DOCUMENT TEXT SPLITTER AND RETRIEVER INITIATED
parent_document_retriever = ParentDocumentRetriever(
vectorstore=vectorstore,
docstore=store,
child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
)
parent_document_retriever.add_documents(docs)
### 4. CREATE PROMPT OBJECT
RAG_PROMPT = """\
Your are a professional mental helth advisor. Use the following context to answer the user's query. If you cannot answer the question, please respond with 'I don't know'.
Question:
{question}
Context:
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
### 5. CREATE CHAIN PIPLINE RETRIVER
openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo", streaming=True)
def create_qa_chain(retriever):
mentahealth_qa_llm = openai_chat_model
created_qa_chain = (
{
"context": itemgetter("question") | retriever,
"question": itemgetter("question"),
}
| RunnablePassthrough.assign(context=itemgetter("context"))
| {
"response": rag_prompt | mentahealth_qa_llm | StrOutputParser(),
"context": itemgetter("context"),
}
)
return created_qa_chain
### 6. DEFINE LIST OF TOOLS AVAILABLE FOR AND TOOL EXECUTOR WRAPPED AROUND THEM
@tool
async def rag_tool(question: str) -> str:
"""Only use this tool to retrieve research relevant information from the knowledge base."""
# advanced_rag_prompt=ChatPromptTemplate.from_template(INSTRUCTION_PROMPT_TEMPLATE.format(user_query=question))
parent_document_retriever_qa_chain = create_qa_chain(parent_document_retriever)
response = await parent_document_retriever_qa_chain.ainvoke({"question": question})
return response["response"]
tool_belt = [
rag_tool,
PubmedQueryRun(),
ArxivQueryRun(),
DuckDuckGoSearchRun(),
]
tool_executor = ToolExecutor(tool_belt)
### 7. CONVERT TOOLS INTO THE FORMAT COMAPTIBLE WITH OPENAI'S FUNCTION CALLING API THEN BINDING THEM TO MODEL TO BE USED WHEN GENERATION
model = ChatOpenAI(temperature=0, streaming=True)
functions = [convert_to_openai_function(t) for t in tool_belt]
model = model.bind_functions(functions)
model = model.with_config(tags=["final_node"])
### 8. USING the TypedDict FROM THE typing module AND THE langchain_core.messages module, A CUSTOM TYPE NAMED AgentState CREATED.
# THE AgentState type HAS A FIELD NAMED <messages> THAT IS OF TYPE Annotated[Sequence[BaseMessage], operator.add].
# Sequence[BaseMessage]: INDICATES THAT MESSAGES ARE A SEQUENCE OF BaseMessage OBJECTS.
# Annotated: USED TO ATTACH MEATADATA TO THE TYPE, THEN THE MESSAGE FIELD TREATED AS CONCATENABLE SEQUENCE OF BASEMASSAGES TO OPERATOR.ADD FUNCTION.
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
### 9. TWO FUNCTIONS DEFINED: 1. call_model AND 2. call_tool FUNCTIONS
# 1. INVOKES THE MODEL BY THE MESSAGES EXTRACTED FROM THE STATE RETURNING A DICT CONTAINING THE RESPONSE MESSAGE,
# 2.1 ToolInvocation OBJECT CREATED USING THE NAME AND ARGUMENTS EXTRACTED FROM THE LAST MASSAGE EXTRACTED FROM THE STATE,
# 2.2. tool_executor IS INVOKED BY THE CREATED toolInvocation OBJECT
# 2.3 FunctionMessage OBJECT IS CREATED WITH THE tool_executor RESPONSE AND THE NAME OF THAT TOOL
# 2.4 RETURN IS A DICT CONTAINING FunctionMessage OBJECT.
async def call_model(state):
messages = state["messages"]
response = await model.ainvoke(messages)
return {"messages": [response]}
async def call_tool(state):
last_message = state["messages"][-1]
action = ToolInvocation(
tool=last_message.additional_kwargs["function_call"]["name"],
tool_input=json.loads(
last_message.additional_kwargs["function_call"]["arguments"]
),
)
print()
print(last_message.additional_kwargs["function_call"]["name"])
print()
response = await tool_executor.ainvoke(action)
function_message = FunctionMessage(content=str(response), name=action.tool)
return {"messages": [function_message]}
###10. GRAPG CREATION WITH HELPFULNESS EVALUATION
# should_continue CHECKS IF THE LAST MASSAGE IN THE STATE IS TO CONTINUE (additional_kwargs EXISTS) OR END.
# THE add_conditional_edges() method IS ORIGINATED FROM THIS REPONSE, EITHER TRANSITION TO ACTION NODE OR END.
def should_continue(state):
last_message = state["messages"][-1]
if "function_call" not in last_message.additional_kwargs:
return "end"
return "continue"
async def check_helpfulness(state):
initial_query = state["messages"][0]
final_response = state["messages"][-1]
# adding artificial_loop
if len(state["messages"]) > 20:
return "end"
prompt_template = """\
Given an initial query and a final response, determine if the final response is extremely helpful or not. Please indicate helpfulness with a 'Y'\
and unhelpfulness as an 'N'.
Initial Query:
{initial_query}
Final Response:
{final_response}"""
prompt_template = PromptTemplate.from_template(prompt_template)
helpfulness_check_model = ChatOpenAI(model="gpt-4")
helpfulness_check_chain = (
prompt_template | helpfulness_check_model | StrOutputParser()
)
helpfulness_response = await helpfulness_check_chain.ainvoke(
{"initial_query": initial_query, "final_response": final_response}
)
if "Y" in helpfulness_response:
print("helpful!")
return "end"
else:
print(" Not helpful!!")
return "continue"
def dummy_node(state):
return
### 11. SETTING THE GRAPH WORKFLOW:
# 1. AN INSTANCE OF THE STATEGRAPH CREATED OF THE TYPE AgentState. THREE NODES ADDED TO THE GRAPH USING add_node() method:
# 1.1 THE "agent" NODE IS ASSOCIATED WITH THE call_model FUNCTION.
# 1.2 THE "action" NODE IS ASSOCIATED WITH THE call_tool FUNCTION.
# 1.3 THE "passthrough" NODE IS A CUSTOM NODE THAT IS ASSOCIATED WITH CHECKING HELPFULNESS.
# 1.5 THE CONDITIONAL EDGES
# 1.5.1 BETWEEN agent NODE AND THE OTHER TWO NODES TO EITHER action NODE OR passthrough NODE
# 1.5.2 BETWEEN passthrough NODE AND agen NODE OR END NODE.
# 1.5.3 BETWEEN agent AND action NODES AS MODEL HAS ACCESS TO TOOLS FOR RESPONSE GENERATION.
def get_state_update_bot():
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model) # agent node has access to llm
workflow.add_node("action", call_tool) # action node has access to tools
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
{
"continue": "action", # tools
"end": END,
},
)
workflow.add_edge("action", "agent") # tools
state_update_bot = workflow.compile()
return state_update_bot
# --------------------------------------------------
from langgraph.checkpoint.memory import MemorySaver
def get_state_update_bot_with_helpfullness_node():
# memory = MemorySaver()
graph_with_helpfulness_check = StateGraph(AgentState)
graph_with_helpfulness_check.add_node("agent", call_model)
graph_with_helpfulness_check.add_node("action", call_tool)
graph_with_helpfulness_check.add_node("passthrough", dummy_node)
graph_with_helpfulness_check.set_entry_point("agent")
graph_with_helpfulness_check.add_conditional_edges(
"agent", should_continue, {"continue": "action", "end": "passthrough"}
)
graph_with_helpfulness_check.add_conditional_edges(
"passthrough", check_helpfulness, {"continue": "agent", "end": END}
)
graph_with_helpfulness_check.add_edge("action", "agent")
memory=AsyncSqliteSaver.from_conn_string(":memory:")
return graph_with_helpfulness_check.compile(checkpointer=memory)
### 12.
# def convert_inputs(input_object):
# system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses.
# Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution.
# First, look up the RAG (retrieval-augmented generation) and then arxiv research or use InternetSearch:
# You will operate in a loop of Thought, Action, PAUSE, and Observation. At the end of the loop, you will provide an Answer.
# Instructions:
# Thought: Describe your thoughts about the user's question.
# Action: Choose one of the available actions to gather information or provide insights.
# PAUSE: Pause to allow the action to complete.
# Observation: Review the results of the action.
# Available Actions:
# Use the tools at your disposal to look up information or resolve the consultancy. You are allowed to make multiple calls (either together or in sequence).:
# 1. rag_tool: RAG (Retrieval-Augmented Generation) to access relevant mental health information.
# 2. DuckDuckGoSearchRun: Perform an online search: InternetSearch to find up-to-date resources and recommendations.
# 3. ArxivQueryRun: Find relevant research or content.
# 3. PubMedQuerRun: Find a specific coping strategies or management techniques by doing research paper
# You may make multiple calls to these tools as needed to provide comprehensive advice.
# Present your final response in a clear, structured format, including a chart of recommended actions if appropriate.
# User's question: {input_object["messages"]}
# Response: Your task is When responding to users' personal issues or concerns:
# 1. With a brief empathetic acknowledgment of the user's situation, continue
# 2. Provide practical, actionable advice that often includes
# 3. Suggesting professional help (e.g., therapists, counselors) when appropriate
# 4. Encouraging open communication and dialogue with involved parties and
# 5. Recommending self-reflection or exploration of emotions and values and
# 6. Offering specific coping strategies or management techniques
# """
# return {"messages": [SystemMessage(content=system_prompt)]}
def convert_inputs(input_object):
system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses.
Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution.
You must Use the tools at your dsiposal.
You must consult pubmed, then ragtool, then duckduckgo_results_json.
You must make multiple calls to these tools as needed to provide comprehensive advice.
User's question: {input_object["messages"]}
"""
return {"messages": [SystemMessage(content=system_prompt)]}
# Define the function to parse the output
def parse_output(input_state):
return input_state
# bot_with_helpfulness_check=get_state_update_bot_with_helpfullness_node() # type:
# bot=get_state_update_bot()
# Create the agent chain
# agent_chain = convert_inputs | bot_with_helpfulness_check# | StrOutputParser()#| parse_output
# Run the agent chain with the input
# messages=agent_chain.invoke({"question": mental_health_counseling_data['test'][14]['Context']})
# ---------------------------------------------------------------------------------------------------------
# DEPLOYMENT
# ---------------------------------------------------------------------------------------------------------
@cl.author_rename
def rename(original_author: str):
"""
This function can be used to rename the 'author' of a message.
In this case, we're overriding the 'Assistant' author to be 'Paul Graham Essay Bot'.
"""
rename_dict = {"Assistant": "Mental Health Advisor Bot"}
return rename_dict.get(original_author, original_author)
@cl.on_chat_start
async def start_chat():
"""
This function will be called at the start of every user session.
We will build our LCEL RAG chain here, and store it in the user session.
The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
"""
### BUILD LCEL RAG CHAIN THAT ONLY RETURNS TEXT
# lcel_rag_chain = ( {"context": itemgetter("query") | hf_retriever, "query": itemgetter("query")}
# | rag_prompt | hf_llm
# )
memory=MemorySaver
bot_with_helpfulness_check = get_state_update_bot_with_helpfullness_node()#(checkpointer=memory)
# type: ignore
lcel_agent_langgraph_chain = (
convert_inputs | bot_with_helpfulness_check) #| StrOutputParser())
# bot=get_state_update_bot()
# lcel_agent_chain = convert_inputs | bot| parse_output# StrOutputParser()
cl.user_session.set("langgraph_agent_chain", lcel_agent_langgraph_chain)
# Create a thread id and pass it as configuration
# to be able to use Langgraph's MemorySaver
conversation_id = str(uuid.uuid4())
config = {"configurable": {"thread_id": conversation_id}}
cl.user_session.set("config", config)
@cl.on_message
async def main(message: cl.Message):
"""
This function will be called every time a message is recieved from a session.
"""
# msg is the human message, could be mixed with system message.
# agent_message is the agent's response.
graph = cl.user_session.get("langgraph_agent_chain")
config = cl.user_session.get("config")
final_output=""
# inputs = {"messages": [("user", message.content)]}
inputs={"messages": [HumanMessage(message.content)]}
agent_message = cl.Message(content="")
await agent_message.send()
# final_output=""
async for event in graph.astream_events(
inputs,
config=config,#=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
version="v2",
):
kind = event["event"]
tags = event.get("tags", [])
name=event.get("name", "")
print()
print(f"Received event: {event}") # Debugging statement
print()
if kind == "on_chain_start":
if (
event["name"] == "Agent"
): # Was assigned when creating the agent with `.with_config({"run_name": "Agent"})`
print(
f"Starting agent: {event['name']} with input: {event['data'].get('input')}"
)
# await agent_message.send()
elif kind == "on_chain_end" and name=="RunnableSequence":#"tool_end" in tags:
if 'output' in event['data'] and "agent" in event["data"]['output']:
agent_output=event["data"]["output"]["agent"]
if "messages" in agent_output and agent_output["messages"]:
final_output=agent_output["messages"][0].content
await agent_message.stream_token(final_output)
# elif kind=="on_chain_stream":
# data=event['data']
# if data["chunk"].content:
# print(f"Streaming content: {data['chunk'].content}")
# await agent_message.stream_token(data["chunk"].content)
await agent_message.send()
#docker build -t llm-app-langgraph-react-chainlit-mentalmindbt .
#docker run -it -p 7860:7860 llm-app-langgraph-react-chainlit-mentalmindbt:latest |