"""Chat session that routes user messages to pluggable data sources.

A :class:`Session` owns an LLM and a list of data sources. Each source is a
:class:`Database` subclass that knows how to build an agent for that LLM
(a LangGraph ReAct SQL agent, or a RAG pipeline over an in-memory vector
store) and how to pre/post-process messages for it.
"""

import json
from abc import ABC, abstractmethod

from langchain import hub
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import (
    SQLDatabase as LangchainSQLDatabase,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
# NOTE: the original file imported `create_react_agent` from BOTH
# `langchain.agents` and `langgraph.prebuilt`; the langgraph one (imported
# last) shadowed the other and is the one whose `prompt=` keyword the code
# actually uses, so only it is kept.
from langgraph.prebuilt import create_react_agent


class Database(ABC):
    """Abstract data source that can build a question-answering agent."""

    @abstractmethod
    def create_agent(self, llm):
        """Return a runnable agent over this source for the given LLM."""
        raise NotImplementedError


class Session:
    """Dispatches user messages to one of several data-source agents.

    Args:
        llm: Chat model shared by all agents built for this session.
        datasources: Optional iterable of :class:`Database` instances to
            register immediately.
    """

    def __init__(self, llm: BaseChatModel, datasources=None):
        self.llm = llm
        self.datasources = datasources
        self._datasources = []
        self._dataagents = []
        if self.datasources is not None:
            for datasource in self.datasources:
                self.add_datasource(datasource)

    def add_datasource(self, database: Database):
        """Register *database* and build its agent with the session LLM."""
        agent = database.create_agent(self.llm)
        self._datasources.append(database)
        self._dataagents.append(agent)

    def get_relevant_source(self, message, datasource=None):
        """Return the ``(database, agent)`` pair to handle *message*.

        BUG FIX: ``datasource`` previously had no default, but ``stream()``
        calls this method with only ``message``, which raised ``TypeError``
        on every call. It now defaults to ``None`` (first source),
        backward-compatible with existing positional callers.
        """
        if datasource is not None:
            return self._datasources[datasource], self._dataagents[datasource]
        # Fall back to the first registered source when no index is given.
        return self._datasources[0], self._dataagents[0]

    def invoke(self, message, datasource=None):
        """Run *message* through a source's agent synchronously.

        Returns:
            Tuple of (post-processed answer, raw agent response).
        """
        db, agent = self.get_relevant_source(message, datasource)
        processed_message = db.process_message(message)
        response = agent.invoke(processed_message)
        processed_response = db.postprocess(response)
        return processed_response, response

    def stream(self, message, stream_mode=None, datasource=None):
        """Stream the agent's response for *message*.

        ``datasource`` is a new optional index (default: first source),
        mirroring :meth:`invoke`; omitting it preserves prior intent.
        """
        db, agent = self.get_relevant_source(message, datasource)
        return agent.stream(
            {"messages": [("user", message)]},
            stream_mode=stream_mode,
        )


class SQLDatabase(Database):
    """SQL data source answered via a LangGraph ReAct SQL agent."""

    def __init__(self, db):
        # `db` is a langchain_community SQLDatabase wrapper.
        self.db = db

    def create_agent(self, llm):
        """Build a ReAct agent equipped with SQL tools for this database."""
        toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
        prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
        system_message = prompt_template.format(dialect="SQLite", top_k=5)
        return create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

    def process_message(self, message):
        """Wrap a plain string into the LangGraph messages payload."""
        return {"messages": [("user", message)]}

    def postprocess(self, response):
        """Extract the final assistant message's text from the agent state."""
        return response['messages'][-1].content

    @classmethod
    def from_uri(cls, database_uri, engine_args=None, **kwargs):
        """Alternate constructor from a SQLAlchemy-style database URI."""
        # Pass engine_args by keyword so extra positional params added
        # upstream cannot silently shift arguments.
        db = LangchainSQLDatabase.from_uri(
            database_uri, engine_args=engine_args, **kwargs
        )
        return cls(db)


class DocumentDatabase(Database):
    """Document store answered via retrieval-augmented generation (RAG).

    Args:
        path: JSON file holding a previously serialized
            ``InMemoryVectorStore.store`` mapping.
        model_name: HuggingFace sentence-embedding model.
        top_k: Number of documents to retrieve per query.
        model_kwargs: Passed to the embedding model (default CPU device).
        encode_kwargs: Passed to the encoder (default batch size 8).
    """

    def __init__(
        self,
        path: str,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        top_k: int = 3,
        model_kwargs=None,
        encode_kwargs=None,
    ):
        self.path = path
        self.model_name = model_name
        self.top_k = top_k
        # Avoid mutable-default pitfalls: build fresh dicts per instance.
        self.model_kwargs = {"device": "cpu"} if model_kwargs is None else model_kwargs
        self.encode_kwargs = {"batch_size": 8} if encode_kwargs is None else encode_kwargs
        embeddings = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            show_progress=False,
        )
        self.vector_store = InMemoryVectorStore(embeddings)
        # Rehydrate the store from a pre-computed JSON dump; `json.load`
        # accepts the binary handle directly.
        with open(path, 'rb') as f:
            self.vector_store.store = json.load(f)

    def create_agent(self, llm):
        """Build a retrieve → format-prompt → LLM pipeline for this store."""
        # Step 1: retrieve the top-k most similar documents for the query.
        retrieve_docs = RunnableLambda(
            lambda message: (
                message,
                self.vector_store.similarity_search(message, k=self.top_k),
            )
        )

        # Step 2: stuff the retrieved documents into a system prompt.
        def format_prompt(inputs):
            message, docs = inputs
            # FIX: renamed local `docs_in_promp` -> `docs_in_prompt` (typo).
            docs_in_prompt = '\n\n'.join(doc.page_content for doc in docs)
            prompt = [
                SystemMessage(
                    "You are an assistant for question-answering tasks. " +
                    "Use the following pieces of retrieved context to answer " +
                    "the question. If you don't know the answer, say that you " +
                    "don't know. Use three sentences maximum and keep the " +
                    "answer concise." + "\n\n" + docs_in_prompt
                ),
                HumanMessage(message),
            ]
            return prompt

        format_prompt_node = RunnableLambda(format_prompt)

        # Step 3: the LLM itself is the final runnable stage.
        invoke_llm = llm

        # Step 4: chain everything together.
        agent_pipeline = (
            RunnablePassthrough() | retrieve_docs | format_prompt_node | invoke_llm
        )
        return agent_pipeline

    def process_message(self, message):
        """RAG pipeline consumes the raw query string unchanged."""
        return message

    def postprocess(self, response):
        """Extract the text content of the LLM's chat message."""
        return response.content