File size: 4,158 Bytes
30c6ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import warnings

from langchain import hub
from langchain.tools import Tool
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Neo4jVector
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import BaseModel, Field

# Initialize LLM credentials.
# SECURITY FIX: a live OpenAI API key was previously hard-coded on this line
# and committed to source control; that key must be treated as leaked and
# revoked.  The key is now read from the environment instead of being
# embedded in the file.
key = os.environ.get("OPENAI_API_KEY", "")
if not key:
    warnings.warn(
        "OPENAI_API_KEY is not set; export it in the environment before "
        "using this module (do not hard-code API keys in source)."
    )

class RAGToolConfig(BaseModel):
    """Connection and data-source settings for the RAG tool.

    SECURITY FIX: the Neo4j password was previously hard-coded as a field
    default and committed to source control; it must be treated as leaked
    and rotated.  All settings now default to environment variables
    (``NEO4J_URI``, ``NEO4J_USERNAME``, ``NEO4J_PASSWORD``, ``RAG_PDF_PATH``),
    with non-secret fallbacks matching the original deployment.
    """

    # Neo4j Aura connection endpoint (not a secret, but overridable via env).
    NEO4J_URI: str = Field(
        default_factory=lambda: os.environ.get(
            "NEO4J_URI", "neo4j+s://741a3118.databases.neo4j.io"
        )
    )
    NEO4J_USERNAME: str = Field(
        default_factory=lambda: os.environ.get("NEO4J_USERNAME", "neo4j")
    )
    # Secret: must come from the environment; empty default fails fast on
    # connect rather than shipping a credential in source.
    NEO4J_PASSWORD: str = Field(
        default_factory=lambda: os.environ.get("NEO4J_PASSWORD", "")
    )
    # Path to the PDF that seeds the vector store on first run.
    pdf_path: str = Field(
        default_factory=lambda: os.environ.get(
            "RAG_PDF_PATH", "/mnt/d/atx/hragent/rag/Sirca_Paints.pdf"
        )
    )

class RAGToolImplementation:
    """Backs the RAG tool: maintains a Neo4j vector index built from the
    configured PDF and answers queries through a retrieval-augmented chain."""

    def __init__(self, config: RAGToolConfig, llm):
        """Connect (or build) the vector store and wire up the RAG chain.

        Args:
            config: Neo4j connection details and the source PDF path.
            llm: chat model instance used as the generation step of the chain.
        """
        self.config = config
        self.llm = llm  # injected model; consumed by _setup_rag_chain
        self.embedding_model = OpenAIEmbeddings()
        self.vectorstore = self._initialize_vectorstore()
        self.rag_chain = self._setup_rag_chain()

    def _initialize_vectorstore(self):
        """Return the existing Neo4j index, or build it from the PDF.

        A probe query decides which path to take: if connecting to the
        existing index (or searching it) raises, the index is assumed to be
        missing and is created from scratch.
        """
        # Identical index settings are used for both the load and the
        # create path, so hoist them into one mapping.
        index_kwargs = dict(
            url=self.config.NEO4J_URI,
            username=self.config.NEO4J_USERNAME,
            password=self.config.NEO4J_PASSWORD,
            index_name="pdf_embeddings",
            node_label="PDFChunk",
            text_node_property="text",
            embedding_node_property="embedding",
        )
        try:
            store = Neo4jVector(embedding=self.embedding_model, **index_kwargs)
            # Probe the index; a failure here falls through to the build path.
            store.similarity_search("Test query", k=1)
            print("Existing vector store loaded.")
        except Exception as e:
            print(f"Creating new vector store. Error: {e}")
            # Load the source PDF and split it into overlapping chunks.
            pages = PyPDFLoader(self.config.pdf_path).load()
            chunks = RecursiveCharacterTextSplitter(
                chunk_size=1000, chunk_overlap=200
            ).split_documents(pages)
            # Embed the chunks and persist them as the new index.
            store = Neo4jVector.from_documents(
                documents=chunks,
                embedding=self.embedding_model,
                **index_kwargs,
            )
            print("New vector store created and loaded.")
        return store

    def _setup_rag_chain(self):
        """Compose retriever -> prompt -> LLM -> plain-string output."""
        prompt = hub.pull("rlm/rag-prompt")
        retriever = self.vectorstore.as_retriever()

        def join_chunks(docs):
            # Concatenate retrieved chunk texts, blank-line separated,
            # to fill the prompt's context slot.
            return "\n\n".join(doc.page_content for doc in docs)

        return (
            {"context": retriever | join_chunks, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def run(self, query: str) -> str:
        """Answer *query* via the RAG chain; errors come back as a string
        (the Tool contract) rather than propagating."""
        try:
            return self.rag_chain.invoke(query)
        except Exception as e:
            return f"An error occurred while processing the query: {str(e)}"
        

def create_rag_tool(config=None, llm=None):
    """Build a LangChain ``Tool`` that answers questions about the PDF.

    Args:
        config: optional ``RAGToolConfig``; a fresh one is created when
            omitted.  (The previous ``config=RAGToolConfig()`` default was a
            mutable default argument — evaluated once at definition time and
            shared, mutations included, across every call.)
        llm: chat model instance forwarded to the implementation's chain.

    Returns:
        A ``Tool`` whose ``func`` runs the RAG query pipeline.
    """
    if config is None:
        config = RAGToolConfig()  # fresh instance per call, not shared
    implementation = RAGToolImplementation(config, llm)
    return Tool(
        name="RAGTool",
        description="Retrieval-Augmented Generation Tool for querying PDF content about Sirca Paints",
        func=implementation.run
    )

# # Example Usage
# if __name__ == "__main__":
#     llm = ChatOpenAI(model="gpt-4", temperature=0)
#     rag_tool = create_rag_tool(llm=llm)

#     # Test the tool
#     result = rag_tool.run("What is spil ethics?")
#     print(result)