File size: 1,547 Bytes
09ae247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Chroma

# Retrieval Augmented Generation (RAG) pipeline.
class RAG:
    """Wire a loader, text splitter, embedding, and LLM into a RAG chain.

    On construction the pipeline eagerly loads documents, splits them,
    indexes the splits in an in-memory Chroma vector store, and composes
    the chain: retriever -> prompt -> llm -> string output parser.
    """

    def __init__(
            self,
            llm,
            loader,
            text_splitter,
            embedding,
            prompt=None
        ):
        """
        Args:
            llm: LangChain-compatible LLM/chat runnable.
            loader: document loader exposing ``load()``.
            text_splitter: splitter exposing ``split_documents(docs)``.
            embedding: an embedding class (instantiated here) or an
                already-constructed embedding instance.
            prompt: optional prompt runnable; defaults to the community
                ``rlm/rag-prompt`` pulled from the LangChain hub.
        """
        self.llm = llm
        self.embedding = embedding
        self.loader = loader
        self.text_splitter = text_splitter
        # `is None` (not truthiness) so a falsy-but-valid prompt object is kept.
        self.prompt = prompt if prompt is not None else hub.pull("rlm/rag-prompt")

        self.docs = self.load_docs()
        self.splits = self.create_splits()
        # Accept either an embedding class or a ready-made instance; the
        # original code assumed a class and unconditionally called it.
        embedding_fn = embedding() if isinstance(embedding, type) else embedding
        self.vectorstore = Chroma.from_documents(documents=self.splits, embedding=embedding_fn)
        self.retriever = self.get_retreiver()
        self.rag_chain = self.generate_rag_chain()

    def create_splits(self):
        """Split the loaded documents into chunks."""
        return self.text_splitter.split_documents(self.docs)

    def load_docs(self):
        """Load the raw documents via the configured loader."""
        return self.loader.load()

    def get_retreiver(self):
        """Return a retriever over the vector store.

        NOTE: misspelled name kept for backward compatibility; prefer
        the ``get_retriever`` alias.
        """
        return self.vectorstore.as_retriever()

    # Correctly spelled alias; same bound behavior as get_retreiver.
    get_retriever = get_retreiver

    def format_docs(self, docs):
        """Join each document's page_content with blank lines for the prompt."""
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_rag_chain(self):
        """Compose retriever -> prompt -> llm -> string parser as one runnable."""
        return (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def invoke(self, question):
        """Answer *question* with the RAG chain; returns the parsed string."""
        return self.rag_chain.invoke(question)