Ley_Fill7 committed
Commit 00621a0 · 1 Parent(s): b24caf7

Add the HyDE RAG app file

Files changed (1)
app.py  +152 -2
app.py CHANGED
@@ -1,7 +1,157 @@
+ # Import modules and classes
+ from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
+ from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+ from llama_index.core.query_engine import TransformQueryEngine
+ from langchain_core.documents import Document as LangDocument
+ from llama_index.core import Document as LlamaDocument
+ from llama_index.core import Settings
+ from llama_parse import LlamaParse
  import streamlit as st
+ import os

- x = st.slider('Select the value')
- st.write(x, 'squared is', x * x)
+ # Get API keys from environment variables
+ nvidia_api_key = os.getenv("NVIDIA_KEY")
+ llamaparse_api_key = os.getenv("PARSE_KEY")

+ # Initialize ChatNVIDIA, NVIDIARerank, and NVIDIAEmbeddings
+ client = ChatNVIDIA(
+     model="meta/llama-3.1-8b-instruct",
+     api_key=nvidia_api_key,
+     temperature=0.2,
+     top_p=0.7,
+     max_tokens=1024
+ )
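+ # temperature and top_p control sampling randomness; max_tokens caps each reply.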

+ embed_model = NVIDIAEmbeddings(
+     model="nvidia/nv-embedqa-e5-v5",
+     api_key=nvidia_api_key,
+     truncate="NONE"
+ )

+ reranker = NVIDIARerank(
+     model="nvidia/nv-rerankqa-mistral-4b-v3",
+     api_key=nvidia_api_key,
+ )
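+ # compress_documents (used below) re-scores candidate passages against the query
+ # and returns them ordered most-relevant first.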
+
+ # Set the NVIDIA models globally
+ Settings.embed_model = embed_model
+ Settings.llm = client
+
+ # Parse the local PDF document
+ parser = LlamaParse(
+     api_key=llamaparse_api_key,
+     result_type="markdown",
+     verbose=True
+ )
+
+ documents = parser.load_data("C:\\Users\\user\\Documents\\Jan 2024\\Projects\\RAGs\\Files\\PhilDataset.pdf")
+ print("Document Parsed")
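+ # NOTE: this absolute Windows path only resolves on the author's machine; a
+ # deployed Space would need a path relative to the repo (or an upload widget).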
+
+ # Split parsed text into chunks for the embedding model
+ # (chunk size is approximated by character count rather than true tokens)
+ def split_text(text, max_tokens=512):
+     words = text.split()
+     chunks = []
+     current_chunk = []
+     current_length = 0
+
+     for word in words:
+         word_length = len(word)
+         if current_length + word_length + 1 > max_tokens:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = [word]
+             current_length = word_length + 1
+         else:
+             current_chunk.append(word)
+             current_length += word_length + 1
+
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     return chunks
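+ # Worked example: split_text("aa bb cc", max_tokens=6) -> ["aa bb", "cc"]
+ # (each word contributes len(word) + 1 toward the limit, one "+1" per separator).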
+
+ # Chunk each parsed document and wrap the chunks as LlamaDocuments; the index
+ # embeds them via the configured embed_model when they are ingested below.
+ all_documents = []
+
+ for doc in documents:
+     text_chunks = split_text(doc.text)
+     for chunk in text_chunks:
+         all_documents.append(LlamaDocument(text=chunk))
+ print("Documents chunked")
+
+ # Create and persist the index with NVIDIAEmbeddings
+ index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
+ index.set_index_id("vector_index")
+ index.storage_context.persist("./storage")
+ print("Index created")
+
+ # Load the index back from storage
+ storage_context = StorageContext.from_defaults(persist_dir="storage")
+ index = load_index_from_storage(storage_context, index_id="vector_index")
+ print("Index loaded")
+
+ # Initialize HyDEQueryTransform and TransformQueryEngine
+ hyde = HyDEQueryTransform(include_original=True)
+ query_engine = index.as_query_engine()
+ hyde_query_engine = TransformQueryEngine(query_engine, hyde)
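+ # HyDE asks the LLM to write a hypothetical answer to the question and embeds
+ # that synthetic passage for retrieval; include_original=True keeps the original
+ # query alongside it, which usually retrieves better than the bare question.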
+
+ # Query the index with HyDE and use the output as LLM context
+ def query_model_with_context(question):
+     # Generate a hypothetical document using HyDE
+     hyde_response = hyde_query_engine.query(question)
+     print(f"HyDE Response: {hyde_response}")
+
+     if isinstance(hyde_response, str):
+         hyde_query = hyde_response
+     else:
+         hyde_query = hyde_response.response
+
+     # Use the hypothetical document to retrieve relevant documents
+     retriever = index.as_retriever(similarity_top_k=3)
+     nodes = retriever.retrieve(hyde_query)
+
+     for node in nodes:
+         print(node)
+
+     # Rerank the retrieved documents
+     ranked_documents = reranker.compress_documents(
+         query=question,
+         documents=[LangDocument(page_content=node.text) for node in nodes]
+     )
+
+     # Print the most relevant node
+     print(f"Most relevant node: {ranked_documents[0].page_content}")
+
+     # Use the most relevant node as context
+     context = ranked_documents[0].page_content
+
+     # Send the context and question to the client (NVIDIA Llama 3.1 8B model)
+     messages = [
+         {"role": "system", "content": context},
+         {"role": "user", "content": str(question)}
+     ]
+     completion = client.stream(messages)
+
+     # Accumulate the streamed response
+     response_text = ""
+     for chunk in completion:
+         if chunk.content is not None:
+             response_text += chunk.content
+
+     return response_text
+
+
+ # Streamlit UI
+ st.title("Chat with HyDE + Rerank RAG")
+ question = st.text_input("Enter your question:")
+
+ if st.button("Submit"):
+     if question:
+         st.write("**RAG Response:**")
+         response = query_model_with_context(question)
+         st.write(response)
+     else:
+         st.warning("Please enter a question.")
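
One practical caveat: Streamlit re-executes the whole script on every widget interaction, so the parsing, embedding, and indexing above run again for each question. A minimal sketch of how the expensive setup could be cached with st.cache_resource (an illustrative refactor, not part of this commit):

import streamlit as st
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine

@st.cache_resource  # build once per process, reuse across reruns
def build_query_engine():
    # Reload the persisted index and wrap it with HyDE, as app.py does above
    storage_context = StorageContext.from_defaults(persist_dir="storage")
    index = load_index_from_storage(storage_context, index_id="vector_index")
    hyde = HyDEQueryTransform(include_original=True)
    return TransformQueryEngine(index.as_query_engine(), hyde)

hyde_query_engine = build_query_engine()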