# GenAI/src/rag.py — "rag initial pipeline setup" (commit 09ae247, author: Ubuntu)
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Chroma
# class for `Retrieval Augmented Generation Pipeline`
class RAG:
    """Retrieval Augmented Generation pipeline.

    Wires a document loader, text splitter, embedding model, Chroma vector
    store, and an LLM into a single runnable chain:
    load -> split -> embed/index -> retrieve -> prompt -> LLM -> string.
    All indexing work happens eagerly in ``__init__``.
    """

    def __init__(
        self,
        llm,
        loader,
        text_splitter,
        embedding,
        prompt=None,
    ):
        """Build the full pipeline.

        Args:
            llm: Chat/completion model used for answer generation.
            loader: Document loader exposing ``.load()``.
            text_splitter: Splitter exposing ``.split_documents(docs)``.
            embedding: Embedding model (or factory — see NOTE below) used
                to index the document splits.
            prompt: Optional prompt template. Defaults to the LangChain hub
                prompt ``"rlm/rag-prompt"`` when omitted.
        """
        self.llm = llm
        self.embedding = embedding
        self.loader = loader
        self.text_splitter = text_splitter
        # `is None` (not truthiness) so a caller-supplied prompt object that
        # happens to be falsy is still honored instead of silently replaced.
        self.prompt = prompt if prompt is not None else hub.pull("rlm/rag-prompt")
        self.docs = self.load_docs()
        self.splits = self.create_splits()
        # NOTE(review): `self.embedding()` *calls* the argument, which assumes
        # the caller passes an embedding class/factory rather than an instance.
        # Most LangChain code passes an instance (no call) — confirm with callers.
        self.vectorstore = Chroma.from_documents(
            documents=self.splits, embedding=self.embedding()
        )
        self.retriever = self.get_retreiver()
        self.rag_chain = self.generate_rag_chain()

    def load_docs(self):
        """Load raw documents from the configured loader."""
        return self.loader.load()

    def create_splits(self):
        """Split the loaded documents into retrieval-sized chunks."""
        return self.text_splitter.split_documents(self.docs)

    def get_retriever(self):
        """Return a retriever view over the vector store."""
        return self.vectorstore.as_retriever()

    def get_retreiver(self):
        """Deprecated, misspelled alias for :meth:`get_retriever` (kept for
        backward compatibility with existing callers)."""
        return self.get_retriever()

    def format_docs(self, docs):
        """Join retrieved documents' text with blank lines for the prompt."""
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_rag_chain(self):
        """Compose retriever, prompt, LLM, and string parser into one chain."""
        return (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def invoke(self, question):
        """Answer ``question`` by running the full RAG chain."""
        return self.rag_chain.invoke(question)