Spaces:
Sleeping
Sleeping
File size: 14,699 Bytes
e603781 26cbdf6 203d168 e603781 203d168 e603781 203d168 8c65552 8b91948 e603781 203d168 26cbdf6 8c65552 26cbdf6 203d168 26cbdf6 203d168 8b91948 26cbdf6 8b91948 26cbdf6 8c65552 26cbdf6 e603781 26cbdf6 8c65552 26cbdf6 8c65552 26cbdf6 203d168 26cbdf6 203d168 26cbdf6 e603781 203d168 e603781 26cbdf6 e603781 26cbdf6 e603781 26cbdf6 e603781 8c65552 26cbdf6 8c65552 e603781 8c65552 e603781 8c65552 e603781 8c65552 e603781 8c65552 e603781 26cbdf6 e603781 8c65552 8b91948 8c65552 e603781 8c65552 8b91948 8c65552 e603781 26cbdf6 8c65552 e603781 8c65552 e603781 8c65552 26cbdf6 e603781 8c65552 e603781 8c65552 e603781 26cbdf6 8c65552 26cbdf6 8c65552 e603781 26cbdf6 e603781 26cbdf6 8c65552 26cbdf6 8c65552 26cbdf6 8c65552 26cbdf6 8c65552 e603781 8c65552 26cbdf6 8c65552 e603781 8c65552 26cbdf6 8c65552 26cbdf6 203d168 e603781 8c65552 e603781 8c65552 e603781 8c65552 e603781 cf79553 e603781 203d168 26cbdf6 8c65552 e603781 26cbdf6 8c65552 21bef18 e603781 26cbdf6 21bef18 e603781 8c65552 26cbdf6 8c65552 21bef18 e603781 26cbdf6 8c65552 e603781 8c65552 26cbdf6 203d168 e603781 203d168 e603781 8b91948 26cbdf6 e603781 8c65552 e603781 26cbdf6 8c65552 203d168 26cbdf6 e603781 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 |
import gradio as gr
import os
import time
import pandas as pd
import sqlite3
import logging
import requests # for HTTP calls to Gemini
from langchain.document_loaders import OnlinePDFLoader # for loading PDF text
from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma # vectorization from langchain_community
from langchain.chains import RetrievalQA # for QA chain
from langchain_core.prompts import PromptTemplate # prompt template import
# ------------------------------
# Gemini API Wrapper
# ------------------------------
class ChatGemini:
    """Minimal client for the Gemini ``generateContent`` REST endpoint.

    Instances are callable (``llm(prompt) -> str``), which is how this file
    hands the model to ``RetrievalQA``.

    NOTE(review): ``RetrievalQA.from_chain_type`` validates ``llm`` as a
    LangChain language model / Runnable; a plain class like this may be
    rejected at chain construction. If so, subclass
    ``langchain_core.language_models.llms.LLM`` instead — confirm against the
    installed LangChain version.
    """

    API_ROOT = "https://generativelanguage.googleapis.com/v1beta/models"
    NO_OUTPUT = "No output from Gemini API"

    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash", timeout=60):
        """
        Args:
            api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
            temperature: Sampling temperature forwarded as
                ``generationConfig.temperature``; ``None`` omits it and lets the
                service default apply.
            model_name: Model id inserted into the endpoint path.
            timeout: Seconds to wait for the HTTP response before
                ``requests`` raises a timeout error.
        """
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name
        self.timeout = timeout

    def _build_payload(self, prompt):
        """Return the JSON request body for a single-turn text *prompt*."""
        payload = {"contents": [{"parts": [{"text": prompt}]}]}
        if self.temperature is not None:
            # Previously the temperature was stored but never transmitted.
            payload["generationConfig"] = {"temperature": self.temperature}
        return payload

    @staticmethod
    def _extract_text(data):
        """Concatenate the text parts of the first candidate in a response body.

        Gemini returns ``candidates[i].content.parts[j].text``. If there is no
        candidate (e.g. the prompt was blocked) or no text part, the
        ``NO_OUTPUT`` sentinel string is returned instead.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            return ChatGemini.NO_OUTPUT
        content = candidates[0].get("content") or {}
        parts = content.get("parts") or []
        text = "".join(part.get("text", "") for part in parts)
        return text or ChatGemini.NO_OUTPUT

    def generate(self, prompt):
        """Send *prompt* to the model and return the generated text.

        Raises:
            RuntimeError: the API answered with a non-200 status.
            requests.RequestException: network failure or timeout.
        """
        url = f"{self.API_ROOT}/{self.model_name}:generateContent"
        # The key goes in a header rather than the query string so it cannot
        # leak through exception messages that embed the request URL.
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        response = requests.post(
            url, json=self._build_payload(prompt), headers=headers, timeout=self.timeout
        )
        if response.status_code != 200:
            raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")
        return self._extract_text(response.json())

    def __call__(self, prompt, **kwargs):
        """Alias for :meth:`generate`; extra keyword arguments are ignored."""
        return self.generate(prompt)
# ------------------------------
# Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Running transcript of every message passed to update_log, newline-separated;
# get_log returns it for the UI's log window.
log_messages = ""


def update_log(message):
    """Append *message* plus a newline to ``log_messages`` and log it at INFO."""
    global log_messages
    log_messages = "".join((log_messages, message, "\n"))
    logger.info(message)
# ------------------------------
# PDF Embedding & QA Chain (No OCR)
# ------------------------------
def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
    """Index an uploaded PDF in Chroma and (re)build the global ``pdf_qa`` chain.

    Args:
        pdf_doc: Value of the ``gr.File`` input. With ``type='filepath'``
            Gradio passes a path string; older Gradio versions pass a
            tempfile-like object with a ``.name`` attribute. Both are accepted.
        gemini_api_key: Gemini API key handed to ``ChatGemini``.
        relevant_pages: Optional comma-separated, 1-based numbers selecting
            which loaded documents to index. Non-numeric, out-of-range and
            repeated entries are ignored; if nothing valid remains, every
            loaded document is indexed.

    Returns:
        "Ready" on success, otherwise a message starting with "Error:".

    Side effects:
        On success rebinds the module-level ``pdf_qa`` used by the query
        handlers; progress and errors are appended via ``update_log``.
    """
    global pdf_qa
    if pdf_doc is None:
        update_log("load_pdf_and_generate_embeddings called without a PDF.")
        return "Error: Please upload a PDF first."
    if not gemini_api_key or not gemini_api_key.strip():
        update_log("load_pdf_and_generate_embeddings called without a Gemini API key.")
        return "Error: Please provide a Gemini API key."
    try:
        # Gradio 4 with type='filepath' gives a str, which has no `.name`.
        pdf_path = pdf_doc if isinstance(pdf_doc, (str, os.PathLike)) else pdf_doc.name
        loader = OnlinePDFLoader(pdf_path)
        # NOTE(review): load_and_split() presumably runs the loader output
        # through LangChain's default text splitter, so these are text chunks
        # rather than guaranteed PDF pages, and relevant_pages indexes chunks.
        # Confirm against the installed LangChain version; switch to a
        # per-page loader if real page numbers matter.
        pages = loader.load_and_split()
        update_log(f"Extracted text from {len(pages)} pages in {pdf_path}")
        if not pages:
            update_log("No text extracted from PDF; nothing to index.")
            return "Error: No text could be extracted from the PDF."

        selected = []
        seen_indices = set()
        for token in (relevant_pages or "").split(","):
            token = token.strip()
            # isdecimal (unlike isdigit) only accepts characters int() parses.
            if not token.isdecimal():
                continue
            index = int(token) - 1
            if 0 <= index < len(pages) and index not in seen_indices:
                seen_indices.add(index)
                selected.append(pages[index])
        if not selected:
            selected = list(pages)
            update_log("No specific pages selected; using entire PDF.")

        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        vectordb = Chroma.from_documents(selected, embedding=embeddings)
        prompt_template = (
            """Use the following context to answer the question. If you do not know the answer, return N/A.
{context}
Question: {question}
Return the answer in JSON format."""
        )
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        pdf_qa = RetrievalQA.from_chain_type(
            llm=ChatGemini(api_key=gemini_api_key.strip(), temperature=0, model_name="gemini-2.0-flash"),
            chain_type="stuff",
            retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=False,
        )
        update_log("PDF embeddings generated and QA chain initialized using Gemini.")
        return "Ready"
    except Exception as e:
        # UI boundary: surface the failure in the Status box instead of raising.
        update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
        return f"Error: {str(e)}"
# ------------------------------
# SQLite Question Set Functions
# ------------------------------
def create_db_connection(db_file="./questionset.db"):
    """Open a new connection to the SQLite question-set database.

    Args:
        db_file: Path of the database file; SQLite creates it if missing.
            Defaults to ``./questionset.db`` relative to the working
            directory. ``":memory:"`` gives a throwaway in-memory database.

    Returns:
        A fresh ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    # check_same_thread=False disables sqlite3's check that the connection is
    # only used by the thread that created it. NOTE(review): presumably needed
    # because Gradio runs handlers on worker threads — confirm.
    return sqlite3.connect(db_file, check_same_thread=False)
def create_sqlite_table(connection):
    """Ensure the ``questions`` table exists on *connection*.

    Existence is checked in ``sqlite_master`` instead of reading the whole
    table, so the check stays cheap as the table grows. Unrelated
    ``OperationalError``s (e.g. a locked database) now propagate instead of
    being treated as "table missing". ``CREATE TABLE IF NOT EXISTS`` keeps
    the create step safe if another connection creates the table first.
    """
    update_log("Creating/Verifying SQLite table for questions.")
    cursor = connection.cursor()
    table_exists = cursor.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name='questions'"
    ).fetchone() is not None
    if not table_exists:
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS questions ("
            "document_type TEXT NOT NULL, questionset_tag TEXT NOT NULL, "
            "field TEXT NOT NULL, question TEXT NOT NULL)"
        )
        update_log("Questions table created.")
    connection.commit()
def load_master_questionset_into_sqlite(connection):
    """Ensure the questions table exists and seed the built-in "masterlist" sets.

    Each document type is seeded independently, and only if it has no
    "masterlist" rows yet. Re-running is therefore harmless, and a database
    that lacks just one type's masterlist gets repaired. (Previously DOC_B
    seeding depended on the DOC_A count.) The total row count is logged.
    """
    create_sqlite_table(connection)
    cursor = connection.cursor()
    seed_sources = (
        ("DOC_A", create_field_and_question_list_for_DOC_A),
        ("DOC_B", create_field_and_question_list_for_DOC_B),
    )
    for document_type, source in seed_sources:
        existing = cursor.execute(
            "SELECT COUNT(*) FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, "masterlist"),
        ).fetchone()[0]
        if existing:
            continue
        update_log(f"Loading {document_type} masterlist into DB.")
        fields, queries = source()
        cursor.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            [(document_type, "masterlist", field, query) for field, query in zip(fields, queries)],
        )
    connection.commit()
    total_questions = cursor.execute("SELECT COUNT(*) FROM questions").fetchone()[0]
    update_log(f"Total questions in DB: {total_questions}")
def create_field_and_question_list_for_DOC_A():
    """Return the built-in DOC_A masterlist as parallel ``(fields, questions)`` lists.

    There are two sample entries; ``fields[i]`` is the label stored with
    ``questions[i]``.
    """
    sample_entries = [
        ("Loan Number", "What is the Loan Number?"),
        ("Borrower", "Who is the Borrower?"),
    ]
    field_names = [name for name, _ in sample_entries]
    prompts = [prompt for _, prompt in sample_entries]
    return field_names, prompts
def create_field_and_question_list_for_DOC_B():
    """Return the built-in DOC_B masterlist as parallel ``(fields, questions)`` lists.

    There are two sample entries; ``fields[i]`` is the label stored with
    ``questions[i]``.
    """
    sample_entries = [
        ("Property Address", "What is the Property Address?"),
        ("Signed Date", "What is the Signed Date?"),
    ]
    field_names = [name for name, _ in sample_entries]
    prompts = [prompt for _, prompt in sample_entries]
    return field_names, prompts
def retrieve_document_type_and_questionsettag_from_sqlite():
    """Build the question-set dropdown update from the database.

    Seeds the masterlists if needed, then collects the distinct
    ``"<document_type>:<questionset_tag>"`` pairs, ordered by document type
    and case-insensitive tag. Each pair is logged once.

    Returns:
        A Gradio update setting the dropdown's choices and preselecting the
        first one (or no selection when there are none).
    """
    connection = create_db_connection()
    try:
        load_master_questionset_into_sqlite(connection)
        rows = connection.execute(
            "SELECT document_type, questionset_tag FROM questions "
            "ORDER BY document_type, UPPER(questionset_tag)"
        ).fetchall()
    finally:
        connection.close()

    choices = []
    seen = set()
    for document_type, tag in rows:
        value = f"{document_type}:{tag}"
        if value in seen:
            continue
        seen.add(value)
        choices.append(value)
        update_log(f"Found question set: {value}")
    # gr.Dropdown.update() was removed in Gradio 4; gr.update works in 3.x and 4.x.
    return gr.update(choices=choices, value=choices[0] if choices else None)
def retrieve_fields_and_questions(dropdownoption):
    """Return the stored fields and questions for a dropdown selection.

    Args:
        dropdownoption: ``"<document_type>:<questionset_tag>"`` as produced by
            the question-set dropdown. It is split on the FIRST ':' only, so
            tags that themselves contain ':' are looked up intact.

    Returns:
        DataFrame with columns ``documentType``, ``field``, ``question``. It is
        empty when nothing is selected or the value has no ':'.
    """
    columns = ["documentType", "field", "question"]
    if not dropdownoption or ":" not in dropdownoption:
        return pd.DataFrame(columns=columns)
    document_type, _, questionset_tag = dropdownoption.partition(":")
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT document_type, field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, questionset_tag),
        ).fetchall()
    finally:
        connection.close()
    return pd.DataFrame(rows, columns=columns)
def add_questionset(data, document_type, tag_for_questionset):
    """Insert every ``(field, question)`` row of *data* under one document type and tag.

    Args:
        data: DataFrame with ``field`` and ``question`` columns; other columns
            are ignored.
        document_type: Value stored in ``document_type`` for every row.
        tag_for_questionset: Value stored in ``questionset_tag`` for every row.

    Raises:
        KeyError: *data* lacks a ``field`` or ``question`` column.
        sqlite3.Error: the insert fails. Nothing is committed, and the
            connection is closed either way.
    """
    rows = [
        (document_type, tag_for_questionset, field, question)
        for field, question in zip(data["field"], data["question"])
    ]
    connection = create_db_connection()
    try:
        create_sqlite_table(connection)
        connection.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            rows,
        )
        connection.commit()
    finally:
        connection.close()
def load_csv_and_store_questionset_into_sqlite(csv_file, document_type, tag_for_questionset):
    """Read an uploaded CSV of fields/questions and store it as a named question set.

    Args:
        csv_file: Value of the ``gr.File`` input. A path string with
            ``type='filepath'``, or a tempfile-like object with ``.name``.
        document_type: Selected document type (e.g. "DOC_A").
        tag_for_questionset: Name for the new question set.

    Returns:
        A status string for the UI. Read or insert failures are reported in
        this string (and logged) instead of being raised.
    """
    if not (tag_for_questionset and document_type):
        return "Please select a Document Type and provide a name for the Question Set"
    if csv_file is None:
        return "Please upload a CSV file with 'field' and 'question' columns."
    # Gradio 4 with type='filepath' gives a str, which has no `.name`.
    csv_path = csv_file if isinstance(csv_file, (str, os.PathLike)) else csv_file.name
    try:
        data = pd.read_csv(csv_path)
        add_questionset(data, document_type, tag_for_questionset)
    except Exception as e:
        err = f"Error loading question set: {str(e)}"
        update_log(err)
        return err
    response = f"Uploaded {data.shape[0]} fields and questions for {document_type}:{tag_for_questionset}"
    update_log(response)
    return response
def answer_predefined_questions(document_type_and_questionset):
    """Run every question of the selected set through the QA chain.

    Args:
        document_type_and_questionset: ``"<document_type>:<questionset_tag>"``
            from the dropdown. It is split on the FIRST ':' only, so tags
            containing ':' still match.

    Returns:
        DataFrame with columns ``Field``, ``Question``, ``Response``. It is
        empty when nothing is selected. A per-question failure, or a missing
        PDF index, is reported in that row's ``Response``.
    """
    columns = ["Field", "Question", "Response"]
    if not document_type_and_questionset or ":" not in document_type_and_questionset:
        return pd.DataFrame(columns=columns)
    document_type, _, question_set = document_type_and_questionset.partition(":")

    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, question_set),
        ).fetchall()
    finally:
        connection.close()

    # Same "not loaded" check and message as summarize_contents/answer_query,
    # instead of a NameError text on every row.
    qa_chain = globals().get("pdf_qa")
    not_loaded = "Error: PDF embeddings not generated. Load a PDF first."
    if qa_chain is None and rows:
        update_log(not_loaded)

    fields, questions, responses = [], [], []
    for field, question in rows:
        fields.append(field)
        questions.append(question)
        if qa_chain is None:
            responses.append(not_loaded)
            continue
        try:
            responses.append(qa_chain.run(question))
        except Exception as e:
            err = f"Error: {str(e)}"
            update_log(err)
            responses.append(err)
    return pd.DataFrame({"Field": fields, "Question": questions, "Response": responses})
def summarize_contents():
    """Ask the active QA chain for a short summary and up to three example questions.

    Returns:
        The chain's answer. Otherwise an error string, either because no PDF
        has been indexed yet or because the chain raised; the latter is
        also logged.
    """
    if "pdf_qa" in globals():
        request_text = "Generate a short summary of the contents along with up to 3 example questions."
        try:
            summary_text = pdf_qa.run(request_text)
            update_log("Summarization successful.")
            return summary_text
        except Exception as exc:
            failure = f"Error in summarization: {str(exc)}"
            update_log(failure)
            return failure
    return "Error: PDF embeddings not generated. Load a PDF first."
def answer_query(query):
    """Answer a free-form *query* against the indexed PDF.

    Returns:
        The chain's answer. Otherwise an error string, either because no PDF
        has been indexed yet or because the chain raised; the latter is
        also logged.
    """
    chain_ready = "pdf_qa" in globals()
    if not chain_ready:
        return "Error: PDF embeddings not generated. Load a PDF first."
    try:
        answer = pdf_qa.run(query)
        update_log(f"Query answered: {query}")
        return answer
    except Exception as exc:
        failure = f"Error in answering query: {str(exc)}"
        update_log(failure)
        return failure
def get_log():
    """Return every message recorded by ``update_log`` so far, one per line."""
    transcript = log_messages
    return transcript
# ------------------------------
# Gradio Interface
# ------------------------------
# Stylesheet passed to gr.Blocks: the column created with
# elem_id="col-container" is capped at 700px wide and centred horizontally.
css = """
#col-container {max-width: 700px; margin: auto;}
"""
# HTML banner rendered at the top of the app through gr.HTML(title).
title = """
<div style="text-align: center;">
<h1>AskMoli - Chatbot for PDFs</h1>
<p>Upload a PDF and generate embeddings. Then ask questions or use a predefined set.</p>
</div>
"""
# Three tabs: Chatbot (index a PDF, summarize, ask, run predefined question
# sets), Text Extractor, and Upload Question Set. Event handlers are wired
# at the end of the Blocks context.
with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        with gr.Tab("Chatbot"):
            with gr.Column():
                gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
                pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
                relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
                with gr.Row():
                    status = gr.Textbox(label="Status", interactive=False)
                    load_pdf_btn = gr.Button("Upload PDF & Generate Embeddings")
                with gr.Row():
                    summary = gr.Textbox(label="Summary")
                    summarize_pdf_btn = gr.Button("Summarize Contents")
                with gr.Row():
                    input_query = gr.Textbox(label="Your Question")
                    output_answer = gr.Textbox(label="Answer")
                submit_query_btn = gr.Button("Submit Question")
                with gr.Row():
                    questionsets = gr.Dropdown(label="Pre-defined Question Sets", choices=[])
                    load_questionsets_btn = gr.Button("Retrieve Sets")
                fields_and_questions = gr.Dataframe(label="Fields & Questions")
                load_fields_btn = gr.Button("Retrieve Questions")
                with gr.Row():
                    answers_df = gr.Dataframe(label="Pre-defined Answers")
                    answer_predefined_btn = gr.Button("Get Answers")
                log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)

        with gr.Tab("Text Extractor"):
            with gr.Column():
                image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=['.pdf'], type='filepath')
                with gr.Row():
                    extracted_text = gr.Textbox(label="Extracted Text", lines=10)
                    extract_btn = gr.Button("Extract Text")

            def extract_text(pdf_file):
                """Return the text of the uploaded PDF joined with newlines, or an error message.

                Accepts a path string (``type='filepath'``) or a tempfile-like
                object exposing ``.name``.
                """
                if pdf_file is None:
                    return "Error extracting text: no PDF uploaded."
                try:
                    # Gradio 4 with type='filepath' gives a str, which has no `.name`.
                    pdf_path = pdf_file if isinstance(pdf_file, (str, os.PathLike)) else pdf_file.name
                    loader = OnlinePDFLoader(pdf_path)
                    docs = loader.load_and_split()
                    text = "\n".join(doc.page_content for doc in docs)
                    update_log(f"Extracted text from {len(docs)} pages.")
                    return text
                except Exception as e:
                    err = f"Error extracting text: {str(e)}"
                    update_log(err)
                    return err

            extract_btn.click(extract_text, inputs=image_pdf, outputs=extracted_text)

        with gr.Tab("Upload Question Set"):
            with gr.Column():
                document_type_for_questionset = gr.Dropdown(choices=["DOC_A", "DOC_B"], label="Select Document Type")
                tag_for_questionset = gr.Textbox(label="Name for Question Set (e.g., basic-set)")
                # add_questionset reads the columns `field` and `question`.
                csv_file = gr.File(label="Load CSV (field,question)", file_types=['.csv'], type='filepath')
                with gr.Row():
                    status_for_csv = gr.Textbox(label="Status", interactive=False)
                    load_csv_btn = gr.Button("Upload CSV into DB")

        refresh_log_btn = gr.Button("Refresh Log")

    # Event wiring.
    refresh_log_btn.click(get_log, outputs=log_window)
    load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, gemini_api_key, relevant_pages], outputs=status)
    summarize_pdf_btn.click(summarize_contents, outputs=summary)
    submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
    load_questionsets_btn.click(retrieve_document_type_and_questionsettag_from_sqlite, outputs=questionsets)
    load_fields_btn.click(retrieve_fields_and_questions, inputs=questionsets, outputs=fields_and_questions)
    answer_predefined_btn.click(answer_predefined_questions, inputs=questionsets, outputs=answers_df)
    load_csv_btn.click(
        load_csv_and_store_questionset_into_sqlite,
        inputs=[csv_file, document_type_for_questionset, tag_for_questionset],
        outputs=status_for_csv,
    )

demo.launch(debug=True)
|