import json
from abc import ABC, abstractmethod

from langchain import hub
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.prebuilt import create_react_agent
class Database(ABC):
    """Interface used by Session: create_agent, process_message, postprocess."""

    @abstractmethod
    def create_agent(self, llm):
        raise NotImplementedError
class Session:
    def __init__(self, llm: BaseChatModel, datasources=None):
        self.llm = llm
        self._datasources = []
        self._dataagents = []
        if datasources is not None:
            for datasource in datasources:
                self.add_datasource(datasource)

    def add_datasource(self, database: Database):
        # Build the agent once per datasource and keep database/agent pairs aligned.
        agent = database.create_agent(self.llm)
        self._datasources.append(database)
        self._dataagents.append(agent)
    def get_relevant_source(self, message, datasource=None):
        # An explicit index selects that datasource; otherwise default to the first.
        if datasource is not None:
            return self._datasources[datasource], self._dataagents[datasource]
        return self._datasources[0], self._dataagents[0]

    def invoke(self, message, datasource=None):
        db, agent = self.get_relevant_source(message, datasource)
        processed_message = db.process_message(message)
        response = agent.invoke(processed_message)
        processed_response = db.postprocess(response)
        return processed_response, response
    def stream(self, message, datasource=None, stream_mode=None):
        db, agent = self.get_relevant_source(message, datasource)
        # stream_mode is a LangGraph option; only forward it when set, so plain
        # runnable pipelines (e.g. the document agent) are not passed an unknown kwarg.
        kwargs = {} if stream_mode is None else {"stream_mode": stream_mode}
        return agent.stream(db.process_message(message), **kwargs)
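# A usage sketch of Session (hedged: `llm` stands in for any BaseChatModel you
# construct yourself, and the URI below is hypothetical):
#
#     session = Session(llm, datasources=[SQLDatabase.from_uri("sqlite:///example.db")])
#     answer, raw = session.invoke("How many tables are there?")
#     for chunk in session.stream("Describe the schema.", stream_mode="values"):
#         print(chunk)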
class SQLDatabase(Database):
    def __init__(self, db):
        self.db = db

    def create_agent(self, llm):
        toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
        prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
        system_message = prompt_template.format(dialect=self.db.dialect, top_k=5)
        agent = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)
        return agent

    def process_message(self, message):
        # LangGraph react agents expect a dict with a "messages" list.
        return {"messages": [("user", message)]}

    def postprocess(self, response):
        # The agent returns the full message history; the answer is the last message.
        return response["messages"][-1].content

    @classmethod
    def from_uri(cls, database_uri, engine_args=None, **kwargs):
        db = LangchainSQLDatabase.from_uri(database_uri, engine_args=engine_args, **kwargs)
        return cls(db)
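# Example construction (hypothetical URI; engine_args, when given, is forwarded
# to SQLAlchemy's create_engine by the underlying LangChain SQLDatabase):
#
#     sql_source = SQLDatabase.from_uri(
#         "sqlite:///Chinook.db",
#         engine_args={"pool_pre_ping": True},
#     )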
class DocumentDatabase(Database):
    def __init__(
        self,
        path: str,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        top_k: int = 3,
        model_kwargs=None,
        encode_kwargs=None,
    ):
        self.path = path
        self.model_name = model_name
        self.top_k = top_k
        self.model_kwargs = {"device": "cpu"} if model_kwargs is None else model_kwargs
        self.encode_kwargs = {"batch_size": 8} if encode_kwargs is None else encode_kwargs
        embeddings = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            show_progress=False,
        )
        self.vector_store = InMemoryVectorStore(embeddings)
        # Restore a previously serialized store (the raw dict backing
        # InMemoryVectorStore, e.g. written with json.dump(vector_store.store, f)),
        # so precomputed embeddings are reused instead of re-embedding documents.
        with open(path, "r") as f:
            self.vector_store.store = json.load(f)
    def create_agent(self, llm):
        # Step 1: Retrieve relevant documents from the vector store.
        retrieve_docs = RunnableLambda(
            lambda message: (message, self.vector_store.similarity_search(message, k=self.top_k))
        )

        # Step 2: Format the retrieved docs into a prompt.
        def format_prompt(inputs):
            message, docs = inputs
            docs_in_prompt = "\n\n".join(doc.page_content for doc in docs)
            prompt = [
                SystemMessage(
                    "You are an assistant for question-answering tasks. "
                    "Use the following pieces of retrieved context to answer "
                    "the question. If you don't know the answer, say that you "
                    "don't know. Use three sentences maximum and keep the "
                    "answer concise."
                    "\n\n" + docs_in_prompt
                ),
                HumanMessage(message),
            ]
            return prompt

        format_prompt_node = RunnableLambda(format_prompt)

        # Step 3: Invoke the LLM with the formatted prompt.
        invoke_llm = llm

        # Step 4: Chain everything together.
        agent_pipeline = RunnablePassthrough() | retrieve_docs | format_prompt_node | invoke_llm
        return agent_pipeline
    def process_message(self, message):
        # The retrieval pipeline takes the raw question string as-is.
        return message

    def postprocess(self, response):
        # The pipeline ends at the chat model, so the response is a single AIMessage.
        return response.content
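

if __name__ == "__main__":
    # Minimal end-to-end sketch, not a definitive setup. Assumptions: langchain-openai
    # is installed and OPENAI_API_KEY is set (any BaseChatModel works in its place);
    # "Chinook.db" and "docs.json" are hypothetical paths, the latter a store
    # previously dumped from an InMemoryVectorStore as this module expects.
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini")
    session = Session(
        llm,
        datasources=[
            SQLDatabase.from_uri("sqlite:///Chinook.db"),  # hypothetical SQL source
            DocumentDatabase("docs.json"),                 # hypothetical document store
        ],
    )

    # Route one question to the SQL agent and one to the retrieval pipeline.
    answer, _ = session.invoke("Which artist has the most albums?", datasource=0)
    print(answer)
    answer, _ = session.invoke("What do the documents say about setup?", datasource=1)
    print(answer)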