|
from langchain import hub |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_community.vectorstores import Chroma |
|
|
|
|
|
class RAG:
    """Retrieval-Augmented Generation pipeline.

    Wires together a document loader, a text splitter, an embedding
    factory, a Chroma vector store, a retriever, and an LLM into a
    single runnable chain that answers questions over the loaded docs.
    """

    def __init__(
        self,
        llm,
        loader,
        text_splitter,
        embedding,
        prompt=None,
    ):
        """Build the full pipeline eagerly.

        Args:
            llm: A LangChain-compatible chat/completion model.
            loader: Document loader exposing ``load()``.
            text_splitter: Splitter exposing ``split_documents(docs)``.
            embedding: Embedding factory/class; it is *called* below to
                produce the embedding instance handed to Chroma.
            prompt: Optional prompt template. Defaults to the hub
                prompt ``rlm/rag-prompt`` when omitted.
        """
        self.llm = llm
        self.embedding = embedding
        self.loader = loader
        self.text_splitter = text_splitter
        # Use an explicit None check so a falsy-but-valid prompt object
        # is not silently replaced by the hub default.
        self.prompt = prompt if prompt is not None else hub.pull("rlm/rag-prompt")

        # Eager construction: loading, splitting, and indexing all happen
        # at instantiation time, so __init__ may be slow and may hit I/O.
        self.docs = self.load_docs()
        self.splits = self.create_splits()
        # NOTE(review): `self.embedding` is invoked here, so callers must
        # pass a factory/class (e.g. OpenAIEmbeddings), not an already
        # constructed instance — confirm against call sites.
        self.vectorstore = Chroma.from_documents(
            documents=self.splits, embedding=self.embedding()
        )
        self.retriever = self.get_retriever()
        self.rag_chain = self.generate_rag_chain()

    def create_splits(self):
        """Split the loaded documents into chunks via the text splitter."""
        return self.text_splitter.split_documents(self.docs)

    def load_docs(self):
        """Load raw documents from the configured loader."""
        return self.loader.load()

    def get_retriever(self):
        """Return a retriever view over the vector store."""
        return self.vectorstore.as_retriever()

    # Backward-compatible alias for the original misspelled method name.
    get_retreiver = get_retriever

    def format_docs(self, docs):
        """Join the page contents of *docs* into one prompt-ready string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_rag_chain(self):
        """Compose retriever -> prompt -> LLM -> string-output chain."""
        return (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def invoke(self, question):
        """Run the RAG chain on *question* and return the answer string."""
        return self.rag_chain.invoke(question)
|
|
|
|