from langchain import hub
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


class RAG:
    """Retrieval-Augmented Generation pipeline.

    Wires together a document loader, text splitter, embedding model,
    Chroma vector store, and an LLM into a single LCEL chain:
    load -> split -> embed/index -> retrieve -> prompt -> llm -> str.
    """

    def __init__(self, llm, loader, text_splitter, embedding, prompt=None):
        """Build the full pipeline eagerly.

        Args:
            llm: Chat/completion model used to answer questions.
            loader: Document loader exposing ``load() -> list[Document]``.
            text_splitter: Splitter exposing ``split_documents(docs)``.
            embedding: Embedding factory; it is CALLED with no arguments
                to produce the embedding instance passed to Chroma.
                NOTE(review): if callers already pass an *instance*,
                this call will fail — confirm the expected contract.
            prompt: Optional prompt template; defaults to the community
                ``rlm/rag-prompt`` pulled from the LangChain hub.
        """
        self.llm = llm
        self.embedding = embedding
        self.loader = loader
        self.text_splitter = text_splitter
        # hub.pull performs a network fetch; only done when no prompt given.
        self.prompt = prompt if prompt else hub.pull("rlm/rag-prompt")
        self.docs = self.load_docs()
        self.splits = self.create_splits()
        self.vectorstore = Chroma.from_documents(
            documents=self.splits, embedding=self.embedding()
        )
        self.retriever = self.get_retreiver()
        self.rag_chain = self.generate_rag_chain()

    def create_splits(self):
        """Split the loaded documents into retrieval-sized chunks."""
        return self.text_splitter.split_documents(self.docs)

    def load_docs(self):
        """Load raw documents via the configured loader."""
        return self.loader.load()

    def get_retreiver(self):
        """Return a retriever over the vector store.

        Name kept (with typo) for backward compatibility; prefer
        :meth:`get_retriever`.
        """
        return self.vectorstore.as_retriever()

    # Correctly spelled alias; existing callers of get_retreiver still work.
    def get_retriever(self):
        """Return a retriever over the vector store."""
        return self.get_retreiver()

    def format_docs(self, docs):
        """Concatenate retrieved documents into one prompt context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_rag_chain(self):
        """Compose the LCEL chain: retrieve -> format -> prompt -> llm -> str."""
        return (
            {
                "context": self.retriever | self.format_docs,
                "question": RunnablePassthrough(),
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def invoke(self, question):
        """Answer *question* by running it through the RAG chain."""
        return self.rag_chain.invoke(question)