File size: 4,011 Bytes
e921012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import random
from datetime import datetime
from operator import itemgetter
from typing import Sequence

import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai import ChatOpenAI
from zoneinfo import ZoneInfo

from rag.retrievers import RetrieversConfig

from .prompt_template import generate_prompt_template

# Helpers


def get_datetime() -> str:
    """Return the current Vancouver-local date and time as a formatted string."""
    now = datetime.now(ZoneInfo("America/Vancouver"))
    return now.strftime("%A, %Y-%b-%d %H:%M:%S")


def reorder_documents(docs: list[Document]) -> Sequence[Document]:
    """Reorder documents to counter "lost in the middle" degradation.

    Delegates to LangChain's LongContextReorder transformer, which moves
    the most relevant documents toward the start and end of the sequence.
    """
    reorderer = LongContextReorder()
    return reorderer.transform_documents(docs)


def randomize_documents(documents: list[Document]) -> list[Document]:
    """Randomize documents to vary model recommendations."""
    random.shuffle(documents)
    return documents


class DocumentFormatter:
    def __init__(self, prefix: str):
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the Documents to markdown.

        Args:

            docs (list[Documents]): List of Langchain documents

        Returns:

            docs (str):

        """
        return "\n---\n".join(
            [
                f"- {self.prefix} {i+1}:\n\n\t" + d.page_content
                for i, d in enumerate(docs)
            ]
        )


def create_langsmith_client():
    """Configure LangSmith tracing environment variables and return a client.

    Returns:
        langsmith.Client: Client pointed at the public LangSmith endpoint.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is not set.
    """
    # Validate the API key BEFORE mutating any environment state, so a
    # failed call does not leave os.environ half-configured for tracing.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "admin-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    return langsmith.Client()


# Set up Runnable and Memory


def get_runnable(
    model: str = "gpt-4o-mini", temperature: float = 0.1
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Build the LCEL chain and the conversation memory that backs it.

    Args:
        model (str, optional): OpenAI chat model name. Defaults to "gpt-4o-mini".
        temperature (float, optional): Sampling temperature. Defaults to 0.1.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The runnable chain and
        the window memory feeding its "history" input.
    """
    # Enable LangSmith tracing for this chain (raises if the API key is missing).
    create_langsmith_client()

    # LLM and prompt template.
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
    )
    prompt = generate_prompt_template()

    # Retrievers with hybrid search over the project data stores.
    retrievers_config = RetrieversConfig()

    # Practitioners data.
    practitioners_data_retriever = retrievers_config.get_practitioners_retriever(k=10)

    # Tall Tree documents with contact information for locations and services.
    documents_retriever = retrievers_config.get_documents_retriever(k=10)

    # Conversation history window memory; keeps only the last k interactions.
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Map each prompt variable to the runnable (or callable) that produces it.
    setup = {
        "practitioners_db": itemgetter("message")
        | practitioners_data_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | documents_retriever
        | DocumentFormatter("No."),
        "timestamp": lambda _: get_datetime(),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup | prompt | llm | StrOutputParser()

    return chain, memory