# NOTE(review): the lines below were Hugging Face Spaces page chrome captured by
# the scraper (UI labels, line-number gutter); kept as a comment so the module parses.
# Commit: e921012 | File size: 4,011 Bytes | 130 lines
import os
import random
from datetime import datetime
from operator import itemgetter
from typing import Sequence
import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai import ChatOpenAI
from zoneinfo import ZoneInfo
from rag.retrievers import RetrieversConfig
from .prompt_template import generate_prompt_template
# Helpers
def get_datetime() -> str:
    """Return the current Vancouver-local date and time as a formatted string."""
    now = datetime.now(ZoneInfo("America/Vancouver"))
    return now.strftime("%A, %Y-%b-%d %H:%M:%S")
def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to mitigate the "lost in the middle" effect of long contexts."""
    reorderer = LongContextReorder()
    return reorderer.transform_documents(docs)
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Return a shuffled copy of *documents* to vary model recommendations.

    Args:
        documents (list[Document]): Documents to randomize.

    Returns:
        list[Document]: A new list with the same documents in random order.
        The caller's list is left untouched (the previous implementation
        shuffled the input in place, a surprising side effect).
    """
    # sample with k == len(...) yields a full random permutation as a new list.
    return random.sample(documents, k=len(documents))
class DocumentFormatter:
    """Callable that renders retrieved documents as a numbered markdown list."""

    def __init__(self, prefix: str):
        # Label placed before each entry's 1-based ordinal (e.g. "Practitioner #").
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the documents to markdown.

        Args:
            docs: Langchain documents to render.

        Returns:
            str: One prefixed, tab-indented entry per document, separated
            by a horizontal-rule line.
        """
        entries = (
            f"- {self.prefix} {idx}:\n\n\t{doc.page_content}"
            for idx, doc in enumerate(docs, start=1)
        )
        return "\n---\n".join(entries)
def create_langsmith_client():
    """Enable LangSmith tracing via environment variables and return a client.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is not set in the environment.
    """
    tracing_settings = {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "admin-ai-assistant",
        "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    }
    os.environ.update(tracing_settings)
    # Fail fast before constructing the client if the API key is absent.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
    return langsmith.Client()
# Set up Runnable and Memory
def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Build the LCEL chain and the windowed chat memory it reads from.

    Args:
        model (str, optional): OpenAI chat model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Model temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The runnable chain and
        the conversation memory; the caller is responsible for saving turns
        into the memory so `load_memory_variables` sees them.
    """
    # Set up Langsmith to trace the chain (raises if LANGCHAIN_API_KEY is missing)
    create_langsmith_client()
    # LLM and prompt template
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )
    prompt = generate_prompt_template()
    # Set retrievers with Hybrid search
    retrievers_config = RetrieversConfig()
    # Practitioners data
    practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)
    # Tall Tree documents with contact information for locations and services
    documents_retriever = retrievers_config.get_documents_retriever(k=10)
    # Set conversation history window memory. It only uses the last k interactions
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )
    # Set up runnable using LCEL. This mapping is coerced into a RunnableParallel,
    # so every key below is computed from the same input dict and fed to `prompt`
    # as a template variable.
    setup = {
        # "message" -> retriever -> markdown-formatted practitioner context
        "practitioners_db": itemgetter("message")
        | practitioners_data_retriever
        | DocumentFormatter("Practitioner #"),
        # "message" -> retriever -> markdown-formatted location/service context
        "tall_tree_db": itemgetter("message")
        | documents_retriever
        | DocumentFormatter("No."),
        # Fresh timestamp on every invocation (input is ignored)
        "timestamp": lambda _: get_datetime(),
        # Pull the recent turns out of window memory for the prompt
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        # Pass the raw user message through unchanged
        "message": itemgetter("message"),
    }
    chain = setup | prompt | llm | StrOutputParser()
    return chain, memory
# NOTE(review): trailing "|" was a scraper artifact; commented out so the module parses.