Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,16 +4,45 @@ import time
|
|
4 |
import pandas as pd
|
5 |
import sqlite3
|
6 |
import logging
|
|
|
7 |
|
8 |
-
from langchain.document_loaders import OnlinePDFLoader # for loading
|
9 |
from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model
|
10 |
from langchain.text_splitter import CharacterTextSplitter
|
11 |
-
from langchain_community.vectorstores import Chroma #
|
12 |
from langchain.chains import RetrievalQA # for QA chain
|
13 |
-
from langchain_community.chat_models import ChatOpenAI # updated import for ChatOpenAI
|
14 |
from langchain_core.prompts import PromptTemplate # prompt template import
|
15 |
|
16 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
19 |
log_messages = "" # global log collector
|
@@ -23,12 +52,12 @@ def update_log(message):
|
|
23 |
log_messages += message + "\n"
|
24 |
logger.info(message)
|
25 |
|
26 |
-
|
|
|
|
|
|
|
27 |
try:
|
28 |
-
|
29 |
-
os.environ['OPENAI_API_KEY'] = open_ai_key
|
30 |
-
|
31 |
-
# Use the file path directly as OCR is removed; text is extracted via the document loader.
|
32 |
pdf_path = pdf_doc.name
|
33 |
loader = OnlinePDFLoader(pdf_path)
|
34 |
pages = loader.load_and_split()
|
@@ -59,18 +88,21 @@ def load_pdf_and_generate_embeddings(pdf_doc, open_ai_key, relevant_pages):
|
|
59 |
|
60 |
global pdf_qa
|
61 |
pdf_qa = RetrievalQA.from_chain_type(
|
62 |
-
llm=
|
63 |
chain_type="stuff",
|
64 |
retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
|
65 |
chain_type_kwargs=chain_type_kwargs,
|
66 |
return_source_documents=False
|
67 |
)
|
68 |
-
update_log("PDF embeddings generated and QA chain initialized.")
|
69 |
return "Ready"
|
70 |
except Exception as e:
|
71 |
update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
|
72 |
return f"Error: {str(e)}"
|
73 |
|
|
|
|
|
|
|
74 |
def create_db_connection():
|
75 |
DB_FILE = "./questionset.db"
|
76 |
connection = sqlite3.connect(DB_FILE, check_same_thread=False)
|
@@ -226,6 +258,9 @@ def answer_query(query):
|
|
226 |
def get_log():
|
227 |
return log_messages
|
228 |
|
|
|
|
|
|
|
229 |
css = """
|
230 |
#col-container {max-width: 700px; margin: auto;}
|
231 |
"""
|
@@ -243,7 +278,7 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
|
|
243 |
|
244 |
with gr.Tab("Chatbot"):
|
245 |
with gr.Column():
|
246 |
-
|
247 |
pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
|
248 |
relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
|
249 |
|
@@ -272,15 +307,13 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
|
|
272 |
|
273 |
log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)
|
274 |
|
275 |
-
with gr.Tab("
|
276 |
-
# This tab is now repurposed (or can be removed)
|
277 |
with gr.Column():
|
278 |
-
image_pdf = gr.File(label="Load PDF for
|
279 |
with gr.Row():
|
280 |
extracted_text = gr.Textbox(label="Extracted Text", lines=10)
|
281 |
extract_btn = gr.Button("Extract Text")
|
282 |
|
283 |
-
# For demonstration, extract text using OnlinePDFLoader
|
284 |
def extract_text(pdf_file):
|
285 |
try:
|
286 |
loader = OnlinePDFLoader(pdf_file.name)
|
@@ -306,7 +339,7 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
|
|
306 |
refresh_log_btn = gr.Button("Refresh Log")
|
307 |
refresh_log_btn.click(get_log, outputs=log_window)
|
308 |
|
309 |
-
load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc,
|
310 |
summarize_pdf_btn.click(summarize_contents, outputs=summary)
|
311 |
submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
|
312 |
|
|
|
4 |
import pandas as pd
|
5 |
import sqlite3
|
6 |
import logging
|
7 |
+
import requests # for HTTP calls to Gemini
|
8 |
|
9 |
+
from langchain.document_loaders import OnlinePDFLoader # for loading PDF text
|
10 |
from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model
|
11 |
from langchain.text_splitter import CharacterTextSplitter
|
12 |
+
from langchain_community.vectorstores import Chroma # vectorization from langchain_community
|
13 |
from langchain.chains import RetrievalQA # for QA chain
|
|
|
14 |
from langchain_core.prompts import PromptTemplate # prompt template import
|
15 |
|
16 |
+
# ------------------------------
|
17 |
+
# Gemini API Wrapper
|
18 |
+
# ------------------------------
|
19 |
+
class ChatGemini:
    """Minimal HTTP client for the Gemini ``generateContent`` REST endpoint.

    Instances are callable: ``llm(prompt)`` is the same as
    ``llm.generate(prompt)`` and returns the generated text as a ``str``.

    NOTE(review): this class does not derive from any LangChain base class;
    confirm that the chain it is passed to accepts a plain callable as ``llm``.
    """

    _API_ROOT = "https://generativelanguage.googleapis.com/v1beta/models"
    _NO_OUTPUT = "No output from Gemini API"

    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash", timeout=60):
        """Store the connection settings.

        Args:
            api_key: Gemini API key sent with every request.
            temperature: Sampling temperature forwarded in ``generationConfig``.
            model_name: Model identifier placed in the request URL.
            timeout: Seconds to wait for the HTTP response before giving up.
        """
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name
        self.timeout = timeout

    def _build_payload(self, prompt):
        """Return the JSON request body for a single-turn text prompt."""
        return {
            "contents": [{"parts": [{"text": prompt}]}],
            # Previously the temperature was stored but never sent.
            "generationConfig": {"temperature": self.temperature},
        }

    @staticmethod
    def _extract_text(data):
        """Return the text of the first candidate in a decoded API response.

        The text lives under ``candidates[0].content.parts[*].text``; all text
        parts are concatenated. Falls back to a fixed placeholder string when
        the response carries no candidates or no text.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            return ChatGemini._NO_OUTPUT
        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = "".join(part.get("text", "") for part in parts)
        return text or ChatGemini._NO_OUTPUT

    def generate(self, prompt):
        """Send ``prompt`` to the model and return the generated text.

        Raises:
            RuntimeError: If the API answers with a non-200 status code.
        """
        url = f"{self._API_ROOT}/{self.model_name}:generateContent"
        headers = {
            "Content-Type": "application/json",
            # Key goes in a header rather than the query string so it cannot
            # leak through URLs embedded in connection-error messages.
            "x-goog-api-key": self.api_key,
        }
        response = requests.post(
            url,
            json=self._build_payload(prompt),
            headers=headers,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")
        return self._extract_text(response.json())

    def __call__(self, prompt, **kwargs):
        """Allow the instance to be used as a function; extra kwargs are ignored."""
        return self.generate(prompt)
|
42 |
+
|
43 |
+
# ------------------------------
|
44 |
+
# Setup Logging
|
45 |
+
# ------------------------------
|
46 |
logging.basicConfig(level=logging.INFO)
|
47 |
logger = logging.getLogger(__name__)
|
48 |
log_messages = "" # global log collector
|
|
|
52 |
log_messages += message + "\n"
|
53 |
logger.info(message)
|
54 |
|
55 |
+
# ------------------------------
|
56 |
+
# PDF Embedding & QA Chain (No OCR)
|
57 |
+
# ------------------------------
|
58 |
+
def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
|
59 |
try:
|
60 |
+
# Use the PDF file's path to extract text.
|
|
|
|
|
|
|
61 |
pdf_path = pdf_doc.name
|
62 |
loader = OnlinePDFLoader(pdf_path)
|
63 |
pages = loader.load_and_split()
|
|
|
88 |
|
89 |
global pdf_qa
|
90 |
pdf_qa = RetrievalQA.from_chain_type(
|
91 |
+
llm=ChatGemini(api_key=gemini_api_key, temperature=0, model_name="gemini-2.0-flash"),
|
92 |
chain_type="stuff",
|
93 |
retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
|
94 |
chain_type_kwargs=chain_type_kwargs,
|
95 |
return_source_documents=False
|
96 |
)
|
97 |
+
update_log("PDF embeddings generated and QA chain initialized using Gemini.")
|
98 |
return "Ready"
|
99 |
except Exception as e:
|
100 |
update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
|
101 |
return f"Error: {str(e)}"
|
102 |
|
103 |
+
# ------------------------------
|
104 |
+
# SQLite Question Set Functions
|
105 |
+
# ------------------------------
|
106 |
def create_db_connection():
|
107 |
DB_FILE = "./questionset.db"
|
108 |
connection = sqlite3.connect(DB_FILE, check_same_thread=False)
|
|
|
258 |
def get_log():
    """Return the text currently held in the module-level log collector."""
    current_log = log_messages
    return current_log
|
260 |
|
261 |
+
# ------------------------------
|
262 |
+
# Gradio Interface
|
263 |
+
# ------------------------------
|
264 |
css = """
|
265 |
#col-container {max-width: 700px; margin: auto;}
|
266 |
"""
|
|
|
278 |
|
279 |
with gr.Tab("Chatbot"):
|
280 |
with gr.Column():
|
281 |
+
gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
|
282 |
pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
|
283 |
relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
|
284 |
|
|
|
307 |
|
308 |
log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)
|
309 |
|
310 |
+
with gr.Tab("Text Extractor"):
|
|
|
311 |
with gr.Column():
|
312 |
+
image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=['.pdf'], type='filepath')
|
313 |
with gr.Row():
|
314 |
extracted_text = gr.Textbox(label="Extracted Text", lines=10)
|
315 |
extract_btn = gr.Button("Extract Text")
|
316 |
|
|
|
317 |
def extract_text(pdf_file):
|
318 |
try:
|
319 |
loader = OnlinePDFLoader(pdf_file.name)
|
|
|
339 |
refresh_log_btn = gr.Button("Refresh Log")
|
340 |
refresh_log_btn.click(get_log, outputs=log_window)
|
341 |
|
342 |
+
load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, gemini_api_key, relevant_pages], outputs=status)
|
343 |
summarize_pdf_btn.click(summarize_contents, outputs=summary)
|
344 |
submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
|
345 |
|