# NOTE(review): removed non-Python scraper artifacts (file-size line, commit
# hash, and a line-number gutter) that had been pasted above the imports.
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Chroma
# Retrieval Augmented Generation (RAG) pipeline.
class RAG:
    """Wire a document loader, text splitter, embedding, LLM, and prompt
    into a runnable RAG chain backed by a Chroma vector store.

    Parameters
    ----------
    llm : a LangChain-compatible chat/completion model.
    loader : object exposing ``load() -> list[Document]``.
    text_splitter : object exposing ``split_documents(docs) -> list[Document]``.
    embedding : an embedding instance, or an embedding class to instantiate.
    prompt : optional prompt template; defaults to the hub's ``rlm/rag-prompt``.
    """

    def __init__(
        self,
        llm,
        loader,
        text_splitter,
        embedding,
        prompt=None,
    ):
        self.llm = llm
        self.embedding = embedding
        self.loader = loader
        self.text_splitter = text_splitter
        # Fall back to the community RAG prompt only when none was supplied.
        self.prompt = prompt if prompt is not None else hub.pull("rlm/rag-prompt")
        self.docs = self.load_docs()
        self.splits = self.create_splits()
        # BUGFIX: the original always called ``self.embedding()``, which breaks
        # when callers pass an already-constructed embedding object (the usual
        # LangChain pattern). Instantiate only when a *class* was passed.
        embedding_obj = self.embedding() if isinstance(self.embedding, type) else self.embedding
        self.vectorstore = Chroma.from_documents(documents=self.splits, embedding=embedding_obj)
        self.retriever = self.get_retreiver()
        self.rag_chain = self.generate_rag_chain()

    def load_docs(self):
        """Load the raw documents from the configured loader."""
        return self.loader.load()

    def create_splits(self):
        """Split ``self.docs`` into chunks with the configured splitter."""
        return self.text_splitter.split_documents(self.docs)

    # NOTE(review): "retreiver" is a typo, but the name is part of the public
    # interface and is kept for backward compatibility.
    def get_retreiver(self):
        """Return a retriever view over the vector store."""
        return self.vectorstore.as_retriever()

    def format_docs(self, docs):
        """Join document contents into one blank-line-separated string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_rag_chain(self):
        """Compose retriever -> prompt -> LLM -> string-output chain."""
        return (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def invoke(self, question):
        """Answer *question* by running it through the RAG chain."""
        return self.rag_chain.invoke(question)