# orator.py — from the Hugging Face Space "test-public" by geetu040
# (commit b07ca6f, "fix fstring")
import json
from abc import abstractmethod, ABC
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain import hub
from langchain.agents import create_react_agent
from langchain.schema import SystemMessage
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langgraph.prebuilt import create_react_agent
class Database(ABC):
    """Abstract interface for a queryable data source.

    Concrete implementations must build a LangChain-compatible agent via
    ``create_agent``.  Note that ``Session`` additionally calls
    ``process_message`` and ``postprocess`` on concrete databases, so
    implementations are expected to provide those as well.
    """

    @abstractmethod
    def create_agent(self, llm):
        """Return an agent that answers questions over this data source."""
        raise NotImplementedError
class Session:
    """Dispatches chat messages to agents built from registered data sources.

    Each added :class:`Database` contributes one agent (built with this
    session's LLM); ``invoke`` and ``stream`` route a message through the
    selected source's pre/post-processing and its agent.
    """

    def __init__(self, llm: BaseChatModel, datasources=None):
        self.llm = llm
        self.datasources = datasources
        self._datasources = []  # registered Database objects, in add order
        self._dataagents = []   # agent for _datasources[i] at the same index
        if self.datasources is not None:
            for datasource in self.datasources:
                self.add_datasource(datasource)

    def add_datasource(self, database: Database):
        """Register *database* and build its agent with this session's LLM."""
        agent = database.create_agent(self.llm)
        self._datasources.append(database)
        self._dataagents.append(agent)

    def get_relevant_source(self, message, datasource=None):
        """Return the (database, agent) pair that should handle *message*.

        *datasource* is an index into the registered sources; when None the
        first registered source is used.  (*message* is currently unused —
        presumably reserved for relevance-based routing; confirm with the
        original author.)

        Raises:
            IndexError: if no data source has been registered.
        """
        # Fail with an explanatory message instead of a bare IndexError
        # from indexing an empty list.  IndexError is kept so existing
        # callers catching it still work.
        if not self._datasources:
            raise IndexError("no datasources registered; call add_datasource() first")
        if datasource is not None:
            return self._datasources[datasource], self._dataagents[datasource]
        return self._datasources[0], self._dataagents[0]

    def invoke(self, message, datasource=None):
        """Run *message* through the chosen agent.

        Returns:
            (processed_response, raw_response) — the source's postprocessed
            answer and the agent's raw response object.
        """
        db, agent = self.get_relevant_source(message, datasource)
        processed_message = db.process_message(message)
        response = agent.invoke(processed_message)
        processed_response = db.postprocess(response)
        return processed_response, response

    def stream(self, message, stream_mode=None, datasource=None):
        """Stream the agent's response for *message*.

        ``datasource`` is accepted (after ``stream_mode`` for backward
        compatibility) so streaming can target a specific source, mirroring
        ``invoke``.
        """
        # Previously this called get_relevant_source(message) without the
        # then-required ``datasource`` argument (a TypeError at runtime) and
        # hard-wired the SQL-agent input shape; use the source's own
        # process_message so non-SQL sources stream correctly too.
        db, agent = self.get_relevant_source(message, datasource)
        return agent.stream(
            db.process_message(message),
            stream_mode=stream_mode,
        )
class SQLDatabase(Database):
    """SQL-backed data source answered by a ReAct SQL agent.

    The SQL dialect and the row limit injected into the system prompt were
    previously hard-coded; they are now constructor parameters whose
    defaults preserve the original behavior.
    """

    def __init__(self, db, dialect="SQLite", top_k=5):
        # db: a langchain_community SQLDatabase wrapper around the engine.
        self.db = db
        # Values substituted into the pulled system-prompt template.
        self.dialect = dialect
        self.top_k = top_k

    def create_agent(self, llm):
        """Build a ReAct agent wired to SQL tools for this database."""
        toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
        # Pulled from the LangChain hub on every call; requires network access.
        prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
        system_message = prompt_template.format(dialect=self.dialect, top_k=self.top_k)
        agent = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)
        return agent

    def process_message(self, message):
        """Wrap a raw user string in the chat-state shape the agent expects."""
        return {"messages": [("user", message)]}

    def postprocess(self, response):
        """Extract the final assistant message text from the agent state."""
        return response['messages'][-1].content

    @classmethod
    def from_uri(cls, database_uri, engine_args=None, **kwargs):
        """Alternate constructor: connect via a SQLAlchemy-style URI."""
        db = LangchainSQLDatabase.from_uri(database_uri, engine_args, **kwargs)
        return cls(db)
class DocumentDatabase(Database):
    """In-memory vector-store data source answered via a RAG pipeline.

    Loads a JSON-serialized store from *path* and embeds queries with a
    HuggingFace sentence-transformers model.
    """

    def __init__(
        self,
        path: str,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        top_k: int = 3,
        model_kwargs = None,
        encode_kwargs = None,
    ):
        self.path = path
        self.model_name = model_name
        self.top_k = top_k
        # Lightweight CPU defaults when the caller supplies nothing.
        self.model_kwargs = model_kwargs if model_kwargs is not None else {"device": "cpu"}
        self.encode_kwargs = encode_kwargs if encode_kwargs is not None else {"batch_size": 8}
        embeddings = HuggingFaceEmbeddings(
            model_name=self.model_name,
            model_kwargs=self.model_kwargs,
            encode_kwargs=self.encode_kwargs,
            show_progress=False,
        )
        self.vector_store = InMemoryVectorStore(embeddings)
        # The store was serialized as JSON; load it straight into the
        # vector store's backing dict.
        with open(path, 'rb') as f:
            self.vector_store.store = json.load(f)

    def create_agent(self, llm):
        """Build a retrieve -> format-prompt -> LLM runnable chain."""

        def _retrieve(message):
            # Pair the query with its top-k nearest documents.
            return message, self.vector_store.similarity_search(message, k=self.top_k)

        def _to_prompt(inputs):
            message, docs = inputs
            docs_in_prompt = '\n\n'.join(doc.page_content for doc in docs)
            system_text = (
                "You are an assistant for question-answering tasks. "
                "Use the following pieces of retrieved context to answer "
                "the question. If you don't know the answer, say that you "
                "don't know. Use three sentences maximum and keep the "
                "answer concise."
                "\n\n"
                + docs_in_prompt
            )
            return [SystemMessage(system_text), HumanMessage(message)]

        return (
            RunnablePassthrough()
            | RunnableLambda(_retrieve)
            | RunnableLambda(_to_prompt)
            | llm
        )

    def process_message(self, message):
        """The RAG pipeline consumes the raw user string unchanged."""
        return message

    def postprocess(self, response):
        """Return the text content of the LLM's chat message."""
        return response.content