import os

from typing import Annotated, List

import chainlit as cl
from dotenv import load_dotenv
from typing_extensions import TypedDict

from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from vectorstore import VectorStore

load_dotenv()
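
# Fail fast if either key is missing; the OpenAI and Cohere clients read them
# from os.environ.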
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
if not OPENAI_API_KEY or not COHERE_API_KEY:
    raise RuntimeError("OPENAI_API_KEY and COHERE_API_KEY must be set (e.g. in .env).")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["COHERE_API_KEY"] = COHERE_API_KEY
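
# Embedding model for dense retrieval; normalized embeddings make inner-product
# search behave like cosine similarity.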
embed_model = HuggingFaceEmbeddings(
    model_name="Snowflake/snowflake-arctic-embed-l",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
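
# Deterministic small model shared by the RAG chain and the agent.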
llm_sml = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)
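
# Prompt that restricts the model to the retrieved context.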
rag_prompt = ChatPromptTemplate.from_template("""\
You are a helpful assistant who answers questions based on the provided context. You must only use the provided context. Do NOT use your own knowledge.
If you don't know the answer, say so.

### Question
{question}

### Context
{context}
""")
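
# Build the vector store from pre-chunked documents and expose a retriever.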
vectorstore = VectorStore(
    collection_name="mg_alloy_collection_snowflake",
)
documents = VectorStore.load_chunks_as_documents("data/contextual_chunks")
vectorstore.add_documents(documents)
retriever = vectorstore.as_retriever(k=5)
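
# State flowing through the linear RAG subgraph.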
class State(TypedDict):
    question: str
    context: List[Document]
    response: str
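
# Generation node: join the retrieved chunks and answer from them alone.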
def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
    response = llm_sml.invoke(messages)
    return {"response": response.content}
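
# Retrieval node: fetch k=5 candidates from the vector store, then rerank them
# with Cohere so the strongest matches reach the generator first.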
def retrieve_adjusted(state: State):
    compressor = CohereRerank(model="rerank-v3.5")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever,  # base retriever is already configured with k=5
    )
    retrieved_docs = compression_retriever.invoke(state["question"])
    return {"context": retrieved_docs}
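
# Linear RAG subgraph: START -> retrieve_adjusted -> generate.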
graph_builder = StateGraph(State).add_sequence([retrieve_adjusted, generate])
graph_builder.add_edge(START, "retrieve_adjusted")
graph = graph_builder.compile()
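
# Expose the RAG subgraph to the agent as a tool. Note: ToolNode flattens the
# returned dict into the ToolMessage content; the "context" entry is not merged
# back into AgentState automatically.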
@tool
def ai_rag_tool(question: str) -> dict:
    """Useful for when you need to answer questions about magnesium alloys. Input should be a fully formed question."""
    response = graph.invoke({"question": question})
    return {
        "messages": [HumanMessage(content=response["response"])],
        "context": response["context"],
    }
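
# Tools available to the agent.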
tool_belt = [
    ai_rag_tool,
]
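
# Agent state: the running message list (merged via add_messages) plus the most
# recently retrieved context.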
class AgentState(TypedDict):
    messages: Annotated[list, add_messages]
    context: List[Document]

tool_node = ToolNode(tool_belt)

uncompiled_graph = StateGraph(AgentState)
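
# Bind the tool belt to the model; without bind_tools the model could never emit
# tool_calls and should_continue would always route to END.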
llm_with_tools = llm_sml.bind_tools(tool_belt)

def call_model(state):
    messages = state["messages"]
    response = llm_with_tools.invoke(messages)
    return {
        "messages": [response],
        "context": state.get("context", []),
    }
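
# Agent loop: the agent node decides, the action node runs tools, and control
# returns to the agent until it stops calling tools.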
uncompiled_graph.add_node("agent", call_model)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")

def should_continue(state):
    # Route to the tool node while the last AI message requests a tool call.
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END

uncompiled_graph.add_conditional_edges(
    "agent",
    should_continue,
)
uncompiled_graph.add_edge("action", "agent")

compiled_graph = uncompiled_graph.compile()
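
# Chainlit hooks: stash the compiled graph at session start and run it per message.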
@cl.on_chat_start
async def start():
    cl.user_session.set("graph", compiled_graph)

@cl.on_message
async def handle(message: cl.Message):
    graph = cl.user_session.get("graph")
    state = {"messages": [HumanMessage(content=message.content)]}
    response = await graph.ainvoke(state)
    await cl.Message(content=response["messages"][-1].content).send()