Divyansh12 committed on
Commit 11c7c99 · verified · 1 Parent(s): 756320e

Update app.py

Files changed (1):
  1. app.py +16 -40
app.py CHANGED
@@ -1,12 +1,6 @@
 import os
-import asyncio
 import nest_asyncio
-import pinecone
 import time
-import fitz
-import base64
-from pathlib import Path
-from typing import List, Tuple
 from dotenv import find_dotenv, load_dotenv
 from langchain_groq import ChatGroq
 from langchain_huggingface import HuggingFaceEmbeddings
@@ -17,7 +11,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import PyPDFDirectoryLoader
 from langchain_pinecone import PineconeVectorStore
 from pinecone import Pinecone, ServerlessSpec
-import gradio as gr
+import gradio as gr
 from langchain import hub
 
 # Allow nested async calls
@@ -61,7 +55,9 @@ docs = text_splitter.split_documents(documents)
 
 def are_documents_indexed(index):
     try:
+        # Create a simple test embedding
         test_embedding = embedding_model.embed_query("test")
+        # Query the index
         results = index.query(vector=test_embedding, top_k=1)
         return len(results.matches) > 0
     except Exception as e:
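
Note: are_documents_indexed only probes whether the index already returns at least one match for an arbitrary embedding, so ingestion can be skipped on restart. A standalone sketch of the same check, assuming a Pinecone index named "rag-index" and a MiniLM embedding model (both names are illustrative, not taken from app.py):

    from pinecone import Pinecone
    from langchain_huggingface import HuggingFaceEmbeddings

    pc = Pinecone(api_key="YOUR_API_KEY")  # hypothetical key
    index = pc.Index("rag-index")          # hypothetical index name
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"  # assumed model
    )

    # Embed a throwaway query and ask for one nearest neighbor; any match
    # at all means the index already holds vectors.
    probe = embedding_model.embed_query("test")
    results = index.query(vector=probe, top_k=1)
    print("already indexed:", len(results.matches) > 0)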
@@ -95,7 +91,7 @@ relevance_prompt_template = PromptTemplate.from_template(
     Return ONLY the numeric score, without any additional text or explanation.
     Question: {question}
     Retrieved Context: {retrieved_context}
-    Relevance Score:"""
+    Relevance Score: """
 )
 
 def format_docs(docs):
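
Note: extract_score and conditional_answer are defined outside this hunk, so their bodies are not shown in the diff. A hypothetical sketch consistent with how they are used below (the real helpers in app.py may differ):

    import re

    def extract_score(text):
        # Pull the first number out of the model's reply; anything
        # unparseable counts as 0 so the gate fails closed.
        match = re.search(r"\d+(\.\d+)?", str(text))
        return float(match.group()) if match else 0.0

    def conditional_answer(x):
        # Same rule the diff shows: below a relevance score of 4,
        # refuse rather than answer from weak context.
        relevance_score = extract_score(x["relevance_score"])
        return "I don't know." if relevance_score < 4 else x["answer"]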
@@ -111,17 +107,6 @@ def conditional_answer(x):
     relevance_score = extract_score(x["relevance_score"])
     return "I don't know." if relevance_score < 4 else x["answer"]
 
-def highlight_pdf(pdf_path: str, text_chunks: List[str]) -> str:
-    doc = fitz.open(pdf_path)
-    for page in doc:
-        for chunk in text_chunks:
-            areas = page.search_for(chunk)
-            for rect in areas:
-                highlight = page.add_highlight_annot(rect)
-    buffer = doc.write()
-    doc.close()
-    return base64.b64encode(buffer).decode()
-
 # RAG pipeline
 rag_chain_from_docs = (
     RunnablePassthrough.assign(context=lambda x: format_docs(x["context"]))
@@ -145,41 +130,32 @@ rag_chain_from_docs = (
 )
 
 rag_chain_with_source = RunnableParallel(
-    {"context": retriever, "question": RunnablePassthrough()}
+    {"context": retriever,
+     "question": RunnablePassthrough()
+    }
 ).assign(answer=rag_chain_from_docs)
 
-async def process_question(question: str) -> Tuple[str, str, dict]:
+async def process_question(question):
     try:
         result = await rag_chain_with_source.ainvoke(question)
         final_answer = result["answer"]["final_answer"]
-        context_docs = result["context"]
-
-        sources = []
-        highlighted_pdfs = {}
-
-        for doc in context_docs:
-            source = doc.metadata.get("source")
-            if source and source.endswith('.pdf'):
-                sources.append(source)
-                if source not in highlighted_pdfs:
-                    highlighted_pdfs[source] = highlight_pdf(source, [doc.page_content])
-
-        return final_answer, ", ".join(sources), highlighted_pdfs
+        sources = [doc.metadata.get("source") for doc in result["context"]]
+        source_list = ", ".join(sources)
+        return final_answer, source_list
     except Exception as e:
-        return f"Error: {str(e)}", "Error retrieving sources", {}
+        return f"Error: {str(e)}", "Error retrieving sources"
 
-# Gradio interface
-print("Setting up Gradio interface...")
+# Gradio
+print("Gradio interface...")
 demo = gr.Interface(
     fn=process_question,
-    inputs=gr.Textbox(label="Enter your question"),
+    inputs=gr.Textbox(label="Enter your question", value=""),
     outputs=[
         gr.Textbox(label="Answer"),
         gr.Textbox(label="Sources"),
-        gr.Gallery(label="Highlighted PDFs")
     ],
     title="RAG Question Answering",
-    description="Enter a question to get an answer with highlighted relevant sections."
+    description="Enter a question and get an answer from the PDFs.",
 )
 
 if __name__ == "__main__":
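
Note: this change works because Gradio calls async functions directly and maps each element of the returned tuple onto one output component in order; dropping the gr.Gallery therefore requires process_question to return exactly two values. A minimal sketch of that contract, with a stub in place of the real chain:

    import gradio as gr

    async def process_question(question):
        # Stub standing in for rag_chain_with_source.ainvoke(question);
        # first value feeds the "Answer" box, second the "Sources" box.
        return f"Answer to: {question}", "doc1.pdf, doc2.pdf"

    demo = gr.Interface(
        fn=process_question,
        inputs=gr.Textbox(label="Enter your question"),
        outputs=[gr.Textbox(label="Answer"), gr.Textbox(label="Sources")],
    )

    if __name__ == "__main__":
        demo.launch()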
 