|
import os |
|
import random |
|
from functools import cache |
|
from operator import itemgetter |
|
|
|
import langsmith |
|
from langchain.memory import ConversationBufferWindowMemory |
|
from langchain.retrievers import EnsembleRetriever |
|
from langchain_community.document_transformers import LongContextReorder |
|
from langchain_core.documents import Document |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnableLambda |
|
from langchain_openai.chat_models import ChatOpenAI |
|
|
|
from .prompt_template import generate_prompt_template |
|
from .retrievers_setup import (DenseRetrieverClient, SparseRetrieverClient, |
|
compression_retriever_setup) |
|
|
|
|
|
|
|
|
|
def reorder_documents(docs: list[Document]) -> list[Document]:
    """Reorder *docs* with LangChain's Long-Context Reorder transform.

    Retrieval quality degrades when many (10+) documents are stuffed into
    a single context window; this transform moves the most relevant
    documents to the beginning and end of the list, where models attend
    to them best.

    Args:
        docs (list[Document]): Retrieved Langchain documents.

    Returns:
        list[Document]: The same documents, reordered for long contexts.
    """
    return LongContextReorder().transform_documents(docs)
|
|
|
|
|
def randomize_documents(documents: list[Document]) -> list[Document]: |
|
"""Randomize the documents to vary the recommendations.""" |
|
random.shuffle(documents) |
|
return documents |
|
|
|
|
|
def format_practitioners_docs(docs: list[Document]) -> str: |
|
"""Format the practitioners_db Documents to markdown. |
|
Args: |
|
docs (list[Documents]): List of Langchain documents |
|
Returns: |
|
docs (str): |
|
""" |
|
return f"\n{'-' * 3}\n".join( |
|
[f"- Practitioner #{i+1}:\n\n\t" + |
|
d.page_content for i, d in enumerate(docs)] |
|
) |
|
|
|
|
|
def format_tall_tree_docs(docs: list[Document]) -> str: |
|
"""Format the tall_tree_db Documents to markdown. |
|
Args: |
|
docs (list[Documents]): List of Langchain documents |
|
Returns: |
|
docs (str): |
|
|
|
""" |
|
return f"\n{'-' * 3}\n".join( |
|
[f"- No. {i+1}:\n\n\t" + |
|
d.page_content for i, d in enumerate(docs)] |
|
) |
|
|
|
|
|
def create_langsmith_client():
    """Configure LangSmith tracing and return a Langsmith client.

    Sets the LANGCHAIN_TRACING_V2 / LANGCHAIN_PROJECT / LANGCHAIN_ENDPOINT
    environment variables consumed by LangChain's tracing machinery.

    Returns:
        langsmith.Client: A client authenticated via LANGCHAIN_API_KEY.

    Raises:
        EnvironmentError: If LANGCHAIN_API_KEY is not set.
    """
    # Validate the API key BEFORE mutating os.environ, so a failed call
    # does not leave the process with half-configured tracing variables.
    langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
    if not langsmith_api_key:
        raise EnvironmentError(
            "Missing environment variable: LANGCHAIN_API_KEY")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    return langsmith.Client()
|
|
|
|
|
|
|
|
|
|
|
@cache
def get_rag_chain(model_name: str = "gpt-4", temperature: float = 0.2) -> tuple:
    """Build the RAG runnable chain and its conversation memory.

    Cached with functools.cache, so each (model_name, temperature)
    combination is constructed only once per process.

    Args:
        model_name (str, optional): OpenAI chat model. Defaults to "gpt-4".
        temperature (float, optional): Model temperature. Defaults to 0.2.

    Returns:
        tuple: (chain, memory) — the composed Runnable pipeline and the
            ConversationBufferWindowMemory that feeds its "history" input.
    """
    # Side effect only: configures LangSmith tracing environment
    # variables. The returned client is not needed afterwards.
    create_langsmith_client()

    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    prompt = generate_prompt_template()

    # Single embeddings model shared by every dense retriever and
    # compression filter below (previously duplicated as a literal).
    embeddings_model = "text-embedding-ada-002"

    # --- practitioners_db retrievers -----------------------------------
    dense_retriever_client = DenseRetrieverClient(
        embeddings_model=embeddings_model,
        collection_name="practitioners_db")
    practitioners_db_dense_retriever = dense_retriever_client.get_dense_retriever(
        search_type="similarity", k=10)

    sparse_retriever_client = SparseRetrieverClient(
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=15)
    practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()

    # NOTE(review): this ensemble retriever is constructed but never used —
    # the compression retriever below wraps only the sparse retriever.
    # Kept to preserve existing behavior; confirm whether it should be
    # passed to compression_retriever_setup instead, or deleted.
    practitioners_ensemble_retriever = EnsembleRetriever(
        retrievers=[practitioners_db_dense_retriever,
                    practitioners_db_sparse_retriever],
        weights=[0.1, 0.9])

    # Similarity filter on top of the sparse retriever to drop weakly
    # related documents.
    practitioners_db_compression_retriever = compression_retriever_setup(
        practitioners_db_sparse_retriever,
        embeddings_model=embeddings_model,
        similarity_threshold=0.74)

    # --- tall_tree_db retriever ----------------------------------------
    dense_retriever_client = DenseRetrieverClient(
        embeddings_model=embeddings_model,
        collection_name="tall_tree_db")
    tall_tree_db_dense_retriever = dense_retriever_client.get_dense_retriever(
        search_type="similarity", k=5)
    tall_tree_db_compression_retriever = compression_retriever_setup(
        tall_tree_db_dense_retriever,
        embeddings_model=embeddings_model,
        similarity_threshold=0.5)

    # Window memory: keeps only the last k=5 exchanges under "history".
    memory = ConversationBufferWindowMemory(memory_key="history",
                                            return_messages=True,
                                            k=5)

    # Map the incoming {"message": ...} dict onto the prompt's inputs.
    setup_and_retrieval = {
        "practitioners_db": itemgetter("message")
        | practitioners_db_compression_retriever
        | randomize_documents
        | format_practitioners_docs,
        "tall_tree_db": itemgetter("message")
        | tall_tree_db_compression_retriever
        | format_tall_tree_docs,
        "history": RunnableLambda(memory.load_memory_variables)
        | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = (
        setup_and_retrieval
        | prompt
        | llm
        | StrOutputParser()
    )

    return chain, memory
|
|