import os
import mlflow
import datetime
import streamlit as st
from functools import partial
from operator import itemgetter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
# from langchain_community.chat_models import ChatDatabricks # let's be consistent with the packages we're using
from langchain_databricks import ChatDatabricks
# from langchain_community.vectorstores import DatabricksVectorSearch # is this causing an issue?
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableBranch
from langchain_core.messages import HumanMessage, AIMessage

# ## Enable MLflow Tracing
# mlflow.langchain.autolog()

class ChainBuilder:

    def __init__(self):
        # Load the chain's configuration from yaml
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")
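
    # An assumed sketch of the chain_config.yaml layout, inferred from the keys read in this file
    # (the actual values live in the repo's chain_config.yaml):
    #   databricks_resources:
    #     vector_search_endpoint_name: ...
    #     llm_endpoint_name: ...
    #   llm_config:
    #     llm_prompt_template: ...
    #     llm_parameters: ...
    #   retriever_config:
    #     embedding_model: ...
    #     vector_search_index: ...
    #     parameters: ...
    #     chunk_template: ...
    #     schema:
    #       description: ...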

    # Return the string contents of the most recent message from the user
    def extract_user_query_string(self, chat_messages_array):
        return chat_messages_array[-1]["content"]

    # Return the chat history, which is everything before the last question
    def extract_chat_history(self, chat_messages_array):
        return chat_messages_array[:-1]
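
    # An assumed example of the OpenAI-style chat_messages_array both helpers above consume:
    #   [{"role": "user", "content": "what is spark?"},
    #    {"role": "assistant", "content": "Spark is ..."},
    #    {"role": "user", "content": "how does it work?"}]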

    def load_embedding_model(self):
        model_name = self.retriever_config.get("embedding_model")

        # Cache this so the model isn't re-downloaded on every rerun, which would slow Space start-up after sleeping.
        # The Streamlit caching decorator ensures the embeddings class is cached after the full model download.
        # @st.cache_resource cannot be applied directly to a method (a function with a self argument),
        # so we wrap the work in an inner function and cache that instead.
        # Does this cache to the given folder though? It does appear to populate the folder as expected after being run.
        @st.cache_resource # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_embedding_model(model_name):
            # TODO: this cache folder isn't effective because we're inside the Docker container;
            # update this to read from a pre-saved cache of bge-large
            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")
            return embeddings # return directly?
        
        return load_and_cache_embedding_model(model_name)

    def get_retriever(self):
        endpoint=self.databricks_resources.get("vector_search_endpoint_name")
        index_name=self.retriever_config.get("vector_search_index")
        embeddings = self.load_embedding_model()
        search_kwargs=self.retriever_config.get("parameters")
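        # search_kwargs comes from chain_config.yaml; an assumed example would be {"k": 5}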
        
        # You cannot directly use @st.cache_resource on a method (a function with a self argument),
        # because Streamlit's caching relies on hashing the function's code and input parameters,
        # and self (the class instance) is not hashable by default. Streamlit also raises:
        # 'Cannot hash argument 'embeddings' (of type `langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings`) in 'get_and_cache_retriever''
        # That is fine: the embedding loader above is itself cached, so re-calling it is fast.
        # We name the parameter _embeddings so Streamlit skips hashing that argument.
        @st.cache_resource # cache the Databricks vector store retriever
        def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
            vector_search_as_retriever = DatabricksVectorSearch(
                endpoint=endpoint,
                index_name=index_name,
                embedding=_embeddings,
                text_column="name",
                columns=["name", "description"],
            ).as_retriever(search_kwargs=search_kwargs)

            return vector_search_as_retriever # return directly?
        
        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)

    # # *** TODO Evaluate this block as it relates to "RAG Studio Review App" ***
    # # Enable the RAG Studio Review App to properly display retrieved chunks and evaluation suite to measure the retriever
    # mlflow.models.set_retriever_schema(
    #     primary_key=self.vector_search_schema.get("primary_key"),
    #     text_column=vector_search_schema.get("chunked_terms"),
    #     # doc_uri=vector_search_schema.get("definition")
    #     other_columns=[vector_search_schema.get("definition")],
    #     # Review App uses `doc_uri` to display chunks from the same document in a single view
    # )

    # Method to format the terms and definitions returned by the retriever into the prompt
    def format_context(self, retrieved_terms):
        chunk_template = self.retriever_config.get("chunk_template")
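        # An illustrative chunk_template (assumed; the real one lives in chain_config.yaml) might be:
        #   "Term: {name}\nDefinition: {description}\n\n"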
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)

    def get_prompt(self):
        # Prompt template for generation; the current date is bound to the {date_str}
        # placeholder in the system prompt via .partial() below.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.llm_config.get("llm_prompt_template")),
                # *** Note: This chain does not compress the history, so very long conversations can overflow the context window.
                # TODO: at some point chop this history down to a fixed number of recent messages.
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # User's most recent question
                ("user", "{question}"),
            ]
        ).partial(date_str=datetime.datetime.now().strftime("%B %d, %Y")) # add current date to the date_str var in system prompt
        return prompt # return directly?

    # Format the conversation history to fit the prompt template above.
    # **** TODO: after only a few exchanges this will likely overflow the context window
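    # A possible mitigation (an assumed sketch, not wired in yet): keep only the most recent
    # messages before formatting, e.g. history = history[-MAX_HISTORY_MESSAGES:] for some small
    # MAX_HISTORY_MESSAGES such as 10.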
    def format_chat_history_for_prompt(self, chat_messages_array):
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history

    def get_query_rewrite_prompt(self):
        # Prompt template for query rewriting from chat history. This will translate a query such as "how does it work?" after a question like "what is spark?" to "how does spark work?"
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so 
        that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant 
        information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.

        Chat history: {chat_history}

        Question: {question}"""

        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt

    def get_model(self):
        endpoint = self.databricks_resources.get("llm_endpoint_name")
        extra_params=self.llm_config.get("llm_parameters")

        @st.cache_resource # cache the DBRX Instruct model we are loading for repeated use in our chain for chat completion
        def get_and_cache_model(endpoint, extra_params):
            model = ChatDatabricks(
                endpoint=endpoint,
                extra_params=extra_params,
            )
            return model # return directly?
        
        return get_and_cache_model(endpoint, extra_params)

    def build_chain(self):
        # model = self.get_model()
        # prompt = self.get_prompt()
        # format_context = self.format_context()
        # vector_search_as_retriever = self.get_retriever()
        # query_rewrite_prompt = self.get_query_rewrite_prompt()

        def get_date():
            return datetime.datetime.now().strftime("%B %d, %Y")

        # RAG Chain
        chain = (
            {
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string), # set 'question' to the result of: grabbing the ["messages"] component of the dict we 'invoke()' or 'stream()', then passing to extract_user_query_string()
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough() # allows one to pass elements unchanged through the chain to the next link in the chain
            | {
                "context": RunnableBranch(  # Only re-write the question if there is a chat history - RunnableBranch() is essentially a LCEL if statement
                    (
                        lambda x: len(x["chat_history"]) > 0, #https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.branch.RunnableBranch.html
                        self.get_query_rewrite_prompt() | self.get_model() | StrOutputParser(), # rewrite question with context
                    ),
                    itemgetter("question"), # else, just ask the question
                )
                | self.get_retriever() # set 'context' to the result of passing either the base question, or the reformatted question to the retriever for semantic search
                | RunnableLambda(self.format_context), 
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            | self.get_prompt() # 'context', 'formatted_chat_history', and 'question' passed to prompt
            | self.get_model() # prompt passed to model
            | StrOutputParser()
        )
        return chain

    # ## Tell MLflow logging where to find your chain.
    # mlflow.models.set_model(model=chain)
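
# Example usage: a minimal sketch, assuming chain_config.yaml is present and the Databricks
# endpoints it references are reachable from this environment.
if __name__ == "__main__":
    builder = ChainBuilder()
    chain = builder.build_chain()
    # The chain expects an OpenAI-style messages payload and returns a plain string answer
    response = chain.invoke(
        {"messages": [{"role": "user", "content": "What does this term mean?"}]}
    )
    print(response)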