# Based on the Chainlit Python streaming reference:
# https://docs.chainlit.io/concepts/streaming/python
import os

import chainlit as cl  # Chainlit powers the chat UI for this app

from typing import Annotated, List

from dotenv import load_dotenv
from typing_extensions import TypedDict

from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langgraph.graph import START, StateGraph, END
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages

from vectorstore import VectorStore

load_dotenv()  # loads OPENAI_API_KEY / COHERE_API_KEY from .env

# Manual environment wiring, kept for reference (load_dotenv already handles it):
"""
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
os.environ["COHERE_API_KEY"] = COHERE_API_KEY
"""

# ------- Models/Tools ------- #
embed_model = HuggingFaceEmbeddings(
    model_name="Snowflake/snowflake-arctic-embed-l",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

llm_sml = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)

# ------- Prompts ------- #
rag_prompt = ChatPromptTemplate.from_template("""\
You are a helpful assistant who answers questions based on the provided context. You must only use the provided context. Do NOT use your own knowledge. If you don't know the answer, say so.

### Question
{question}

### Context
{context}
""")

# Load pre-chunked documents and populate the vector store
vectorstore = VectorStore(
    collection_name="mg_alloy_collection_snowflake",
)
documents = VectorStore.load_chunks_as_documents("data/contextual_chunks")
vectorstore.add_documents(documents)
retriever = vectorstore.as_retriever(k=5)

# ------- State Schemas ------- #
class State(TypedDict):
    question: str
    context: List[Document]
    response: str

# ------- Functions ------- #
def generate(state: State):
    """Answer the question using only the retrieved context."""
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
    response = llm_sml.invoke(messages)
    return {"response": response.content}


def retrieve_adjusted(state: State):
    """Retrieve top-k candidates, then rerank them with Cohere."""
    compressor = CohereRerank(model="rerank-v3.5")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,  # already configured with k=5
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}


# ------- Runnables ------- #
# RAG subgraph: retrieve -> generate
graph_builder = StateGraph(State)
graph_builder.add_node("retrieve", retrieve_adjusted)
graph_builder.add_node("generate", generate)
graph_builder.add_edge(START, "retrieve")
graph_builder.add_edge("retrieve", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile()


@tool
def ai_rag_tool(question: str) -> dict:
    """Useful for when you need to answer questions about magnesium alloys.
    Input should be a fully formed question."""
    response = graph.invoke({"question": question})
    return {
        "messages": [HumanMessage(content=response["response"])],
        "context": response["context"],
    }
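# Quick smoke test for the RAG subgraph (hypothetical question; uncomment to
# run once the vector store above has been populated):
# result = graph.invoke({"question": "What are common alloying elements in magnesium alloys?"})
# print(result["response"])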
# ------------------------------------------------ #
tool_belt = [
    ai_rag_tool,
]


class AgentState(TypedDict):
    messages: Annotated[list, add_messages]
    context: List[Document]


tool_node = ToolNode(tool_belt)

uncompiled_graph = StateGraph(AgentState)

# Bind the tool belt so the model can emit tool calls; without this the agent
# would never route to the tool node.
llm_with_tools = llm_sml.bind_tools(tool_belt)


def call_model(state):
    messages = state["messages"]
    response = llm_with_tools.invoke(messages)
    return {
        "messages": [response],
        "context": state.get("context", []),
    }


uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")


def should_continue(state):
    """Route to the tool node while the model keeps requesting tool calls."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END


uncompiled_graph.add_conditional_edges(
    "agent",
    should_continue,
)
uncompiled_graph.add_edge("action", "agent")

compiled_graph = uncompiled_graph.compile()

# ------- Chainlit ------- #
@cl.on_chat_start
async def start():
    cl.user_session.set("graph", compiled_graph)


@cl.on_message
async def handle(message: cl.Message):
    graph = cl.user_session.get("graph")
    state = {"messages": [HumanMessage(content=message.content)]}
    response = await graph.ainvoke(state)
    await cl.Message(content=response["messages"][-1].content).send()
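# To launch the app (assuming this file is saved as app.py):
#   chainlit run app.py -w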