ai-virtual-assistant / rag /runnable.py
yrobel-lima's picture
Upload 4 files
e921012 verified
raw
history blame
4.01 kB
import os
import random
from datetime import datetime
from operator import itemgetter
from typing import Sequence
import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai import ChatOpenAI
from zoneinfo import ZoneInfo
from rag.retrievers import RetrieversConfig
from .prompt_template import generate_prompt_template
# Helpers
def get_datetime() -> str:
    """Return the current date and time in Vancouver as a formatted string.

    Returns:
        str: Timestamp in the America/Vancouver time zone, formatted as
        "<weekday name>, <year>-<abbreviated month>-<day> <HH:MM:SS>".
    """
    vancouver_tz = ZoneInfo("America/Vancouver")
    current_time = datetime.now(tz=vancouver_tz)
    # format() delegates to datetime.__format__, which applies strftime rules
    return format(current_time, "%A, %Y-%b-%d %H:%M:%S")
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to mitigate performance degradation with long contexts.

    Args:
        docs (list[Document]): Retrieved Langchain documents.

    Returns:
        Sequence[Document]: The documents in the order produced by
        LongContextReorder.
    """
    reorderer = LongContextReorder()
    reordered_docs = reorderer.transform_documents(docs)
    return reordered_docs
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Shuffle the documents so model recommendations vary between calls.

    The shuffle is done in place: the list passed in is reordered and that
    same list object is returned.

    Args:
        documents (list[Document]): Documents to shuffle (mutated).

    Returns:
        list[Document]: The same list, now in random order.
    """
    # random.shuffle reorders its argument in place and returns None,
    # so the input list itself is handed back to the caller.
    random.shuffle(documents)
    return documents
class DocumentFormatter:
    """Callable that renders a list of documents as one markdown-style string."""

    def __init__(self, prefix: str):
        # Label placed before each document's 1-based number
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents as a markdown list.

        Each document becomes "- <prefix> <number>:" followed by a blank line
        and its tab-indented page content; entries are separated by a "---"
        line.

        Args:
            docs (list[Document]): List of Langchain documents.

        Returns:
            str: The formatted entries joined into one string (empty when
            docs is empty).
        """
        separator = "\n---\n"
        return separator.join(
            f"- {self.prefix} {number}:\n\n\t" + doc.page_content
            for number, doc in enumerate(docs, start=1)
        )
def create_langsmith_client():
    """Configure Langsmith tracing through environment variables and return a client.

    The tracing variables are written to os.environ before the API key is
    checked, so they remain set even when the check fails.

    Returns:
        A langsmith.Client instance.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is unset or empty.
    """
    os.environ.update(
        {
            "LANGCHAIN_TRACING_V2": "true",
            "LANGCHAIN_PROJECT": "admin-ai-assistant",
            "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
        }
    )

    api_key = os.environ.get("LANGCHAIN_API_KEY")
    if api_key:
        return langsmith.Client()
    raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
# Set up Runnable and Memory
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Build the RAG chain together with its conversation memory.

    Args:
        model (str, optional): LLM model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The LCEL chain and
        the window memory it reads its history from. The chain takes a
        mapping with a "message" key and yields the model reply as a string.
        This function only wires the memory in for reading; nothing here
        saves new turns to it.
    """
    # Set up Langsmith to trace the chain
    create_langsmith_client()

    # LLM and prompt template
    chat_model = ChatOpenAI(model=model, temperature=temperature)
    prompt_template = generate_prompt_template()

    # Retrievers with hybrid search: practitioners data, and Tall Tree
    # documents with contact information for locations and services
    retrievers = RetrieversConfig()
    practitioners_retriever = retrievers.get_practitioners_retriever(k=10)
    tall_tree_retriever = retrievers.get_documents_retriever(k=10)

    # Conversation history window memory: only the last k interactions are used
    memory = ConversationBufferWindowMemory(
        memory_key="history", return_messages=True, k=6
    )

    # Each retrieval branch pulls "message" from the input, fetches documents
    # and formats them into a single string for the prompt.
    practitioners_context = (
        itemgetter("message")
        | practitioners_retriever
        | DocumentFormatter("Practitioner #")
    )
    tall_tree_context = (
        itemgetter("message") | tall_tree_retriever | DocumentFormatter("No.")
    )
    # load_memory_variables returns a dict; keep only its "history" entry
    chat_history = RunnableLambda(memory.load_memory_variables) | itemgetter(
        "history"
    )

    # Inputs for the prompt template, assembled with LCEL
    prompt_inputs = {
        "practitioners_db": practitioners_context,
        "tall_tree_db": tall_tree_context,
        "timestamp": lambda _: get_datetime(),
        "history": chat_history,
        "message": itemgetter("message"),
    }

    rag_chain = prompt_inputs | prompt_template | chat_model | StrOutputParser()
    return rag_chain, memory