Omarrran committed
Commit 203d168 · verified · 1 Parent(s): 23b23e6

Update app.py

Files changed (1)
  1. app.py +50 -17
app.py CHANGED
@@ -4,16 +4,45 @@ import time
 import pandas as pd
 import sqlite3
 import logging
+import requests # for HTTP calls to Gemini
 
-from langchain.document_loaders import OnlinePDFLoader # for loading the PDF text
+from langchain.document_loaders import OnlinePDFLoader # for loading PDF text
 from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model
 from langchain.text_splitter import CharacterTextSplitter
-from langchain_community.vectorstores import Chroma # updated import for vectorization
+from langchain_community.vectorstores import Chroma # vectorization from langchain_community
 from langchain.chains import RetrievalQA # for QA chain
-from langchain_community.chat_models import ChatOpenAI # updated import for ChatOpenAI
 from langchain_core.prompts import PromptTemplate # prompt template import
 
-# Setup basic logging
+# ------------------------------
+# Gemini API Wrapper
+# ------------------------------
+class ChatGemini:
+    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash"):
+        self.api_key = api_key
+        self.temperature = temperature
+        self.model_name = model_name
+
+    def generate(self, prompt):
+        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={self.api_key}"
+        payload = {
+            "contents": [{
+                "parts": [{"text": prompt}]
+            }]
+        }
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(url, json=payload, headers=headers)
+        if response.status_code != 200:
+            raise Exception(f"Gemini API error: {response.status_code} - {response.text}")
+        data = response.json()
+        candidate = data.get("candidates", [{}])[0]
+        return candidate.get("output", {}).get("text", "No output from Gemini API")
+
+    def __call__(self, prompt, **kwargs):
+        return self.generate(prompt)
+
+# ------------------------------
+# Setup Logging
+# ------------------------------
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 log_messages = "" # global log collector
@@ -23,12 +52,12 @@ def update_log(message):
     log_messages += message + "\n"
     logger.info(message)
 
-def load_pdf_and_generate_embeddings(pdf_doc, open_ai_key, relevant_pages):
+# ------------------------------
+# PDF Embedding & QA Chain (No OCR)
+# ------------------------------
+def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
     try:
-        if open_ai_key is not None:
-            os.environ['OPENAI_API_KEY'] = open_ai_key
-
-        # Use the file path directly as OCR is removed; text is extracted via the document loader.
+        # Use the PDF file's path to extract text.
         pdf_path = pdf_doc.name
         loader = OnlinePDFLoader(pdf_path)
         pages = loader.load_and_split()
@@ -59,18 +88,21 @@ def load_pdf_and_generate_embeddings(pdf_doc, open_ai_key, relevant_pages):
 
         global pdf_qa
         pdf_qa = RetrievalQA.from_chain_type(
-            llm=ChatOpenAI(temperature=0, model_name="gpt-4"),
+            llm=ChatGemini(api_key=gemini_api_key, temperature=0, model_name="gemini-2.0-flash"),
             chain_type="stuff",
             retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
             chain_type_kwargs=chain_type_kwargs,
             return_source_documents=False
         )
-        update_log("PDF embeddings generated and QA chain initialized.")
+        update_log("PDF embeddings generated and QA chain initialized using Gemini.")
         return "Ready"
     except Exception as e:
         update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
         return f"Error: {str(e)}"
 
+# ------------------------------
+# SQLite Question Set Functions
+# ------------------------------
 def create_db_connection():
     DB_FILE = "./questionset.db"
     connection = sqlite3.connect(DB_FILE, check_same_thread=False)
@@ -226,6 +258,9 @@ def answer_query(query):
 def get_log():
     return log_messages
 
+# ------------------------------
+# Gradio Interface
+# ------------------------------
 css = """
 #col-container {max-width: 700px; margin: auto;}
 """
@@ -243,7 +278,7 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
 
     with gr.Tab("Chatbot"):
        with gr.Column():
-            open_ai_key = gr.Textbox(label="Your GPT-4 API Key", type="password")
+            gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
            pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
            relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
 
@@ -272,15 +307,13 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
 
    log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)
 
-    with gr.Tab("OCR Converter"):
-        # This tab is now repurposed (or can be removed)
+    with gr.Tab("Text Extractor"):
        with gr.Column():
-            image_pdf = gr.File(label="Load PDF for Conversion", file_types=['.pdf'], type='filepath')
+            image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=['.pdf'], type='filepath')
            with gr.Row():
                extracted_text = gr.Textbox(label="Extracted Text", lines=10)
                extract_btn = gr.Button("Extract Text")
 
-        # For demonstration, extract text using OnlinePDFLoader
        def extract_text(pdf_file):
            try:
                loader = OnlinePDFLoader(pdf_file.name)
@@ -306,7 +339,7 @@ with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
    refresh_log_btn = gr.Button("Refresh Log")
    refresh_log_btn.click(get_log, outputs=log_window)
 
-    load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, open_ai_key, relevant_pages], outputs=status)
+    load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, gemini_api_key, relevant_pages], outputs=status)
    summarize_pdf_btn.click(summarize_contents, outputs=summary)
    submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
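
For quick local verification, the Gemini call introduced above can be exercised outside the Gradio app. The following is a minimal standalone sketch, not part of the commit: it assumes the requests package is installed and a GEMINI_API_KEY environment variable is set, and it reads the generated text from candidates[0]["content"]["parts"][0]["text"], which is how the public v1beta generateContent response is typically structured.

import os
import requests

def gemini_generate(prompt, model_name="gemini-2.0-flash"):
    # Same endpoint as the ChatGemini wrapper in the diff above.
    api_key = os.environ["GEMINI_API_KEY"]  # assumed to be provided by the caller
    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"{model_name}:generateContent?key={api_key}"
    )
    payload = {"contents": [{"parts": [{"text": prompt}]}]}
    response = requests.post(url, json=payload, timeout=60)
    response.raise_for_status()
    data = response.json()
    # The generated text is nested under the first candidate's content parts.
    parts = data["candidates"][0]["content"]["parts"]
    return "".join(part.get("text", "") for part in parts)

if __name__ == "__main__":
    print(gemini_generate("Reply with the single word: ready"))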