Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import time | |
import pandas as pd | |
import sqlite3 | |
import logging | |
import requests # for HTTP calls to Gemini | |
from langchain.document_loaders import OnlinePDFLoader # for loading PDF text | |
from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain_community.vectorstores import Chroma # vectorization from langchain_community | |
from langchain.chains import RetrievalQA # for QA chain | |
from langchain_core.prompts import PromptTemplate # prompt template import | |
# ------------------------------ | |
# Gemini API Wrapper | |
# ------------------------------ | |
class ChatGemini:
    """Minimal client for the Gemini ``generateContent`` REST endpoint.

    Instances are callable (``llm(prompt)``) and return the model's text reply,
    or ``NO_OUTPUT`` when the response contains no text.

    NOTE(review): recent LangChain releases validate the ``llm`` passed to
    ``RetrievalQA.from_chain_type`` as a LangChain runnable / language model; a
    plain class like this one may be rejected there. If so, wrap it in a
    ``langchain_core.language_models.llms.LLM`` subclass — confirm against the
    installed LangChain version.
    """

    API_ROOT = "https://generativelanguage.googleapis.com/v1beta/models"
    NO_OUTPUT = "No output from Gemini API"

    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash", timeout=60):
        """Store connection settings.

        Args:
            api_key: Gemini API key. It is sent in the ``x-goog-api-key`` header,
                not the URL, so it cannot leak through URL-bearing error messages.
            temperature: Sampling temperature, forwarded as
                ``generationConfig.temperature``.
            model_name: Model id inserted into the endpoint path.
            timeout: Seconds ``requests`` waits for the HTTP response.
        """
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name
        self.timeout = timeout

    def _build_payload(self, prompt):
        """Return the JSON request body for a single-turn text *prompt*."""
        return {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {"temperature": self.temperature},
        }

    @classmethod
    def _extract_text(cls, data):
        """Return the concatenated text parts of the first candidate in *data*.

        Gemini places the reply under ``candidates[i].content.parts[j].text``.
        Falls back to ``NO_OUTPUT`` when there are no candidates (including an
        empty list), no ``content``, or no text parts.
        """
        candidates = data.get("candidates") or [{}]
        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = "".join(part.get("text", "") for part in parts)
        return text or cls.NO_OUTPUT

    def generate(self, prompt):
        """Send *prompt* to Gemini and return the reply text.

        Raises:
            RuntimeError: if the API answers with a non-200 status.
            requests.RequestException: on network failures or timeout.
        """
        url = f"{self.API_ROOT}/{self.model_name}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }
        response = requests.post(
            url,
            json=self._build_payload(prompt),
            headers=headers,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")
        return self._extract_text(response.json())

    def __call__(self, prompt, **kwargs):
        """Alias for :meth:`generate`; extra keyword arguments are ignored."""
        return self.generate(prompt)
# ------------------------------ | |
# Setup Logging | |
# ------------------------------ | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Running transcript of every message passed to update_log(); returned by get_log().
log_messages = ""


def update_log(message):
    """Record *message*.

    Appends it, newline-terminated, to the module-level ``log_messages``
    transcript, and also emits it through the module logger at INFO level.
    """
    global log_messages
    entry = message + "\n"
    log_messages += entry
    logger.info(message)
# ------------------------------ | |
# PDF Embedding & QA Chain (No OCR) | |
# ------------------------------ | |
# Embedding models keyed by model name, so the model is loaded once per process.
_embeddings_cache = {}


def _get_embeddings(model_name="all-MiniLM-L6-v2"):
    """Return a HuggingFaceEmbeddings for *model_name*, creating it only on first use."""
    embeddings = _embeddings_cache.get(model_name)
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        _embeddings_cache[model_name] = embeddings
    return embeddings


def _select_pages(pages, relevant_pages):
    """Return the items of *pages* named by a comma-separated list of 1-based indices.

    Non-numeric tokens and out-of-range indices are ignored. Repeated indices
    are kept once, in first-mention order. Returns [] when nothing valid was
    selected, including when *relevant_pages* is empty or None.
    """
    selected = []
    seen = set()
    for token in (relevant_pages or "").split(","):
        token = token.strip()
        # isdecimal (not isdigit): chars like "²" pass isdigit but make int() raise.
        if not token.isdecimal():
            continue
        index = int(token) - 1
        if 0 <= index < len(pages) and index not in seen:
            seen.add(index)
            selected.append(pages[index])
    return selected


def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
    """Index a PDF into an in-memory Chroma store and build the global QA chain.

    Args:
        pdf_doc: Value of the ``gr.File`` component. This is a path string, or an
            object exposing ``.name``, depending on the Gradio version.
        gemini_api_key: Key handed to ChatGemini.
        relevant_pages: Optional comma-separated 1-based indices into the loaded
            documents. If it is empty or all indices are invalid, everything is indexed.

    Returns:
        "Ready" on success; otherwise a message starting with "Error: ".

    Side effects:
        Rebinds the module-level ``pdf_qa`` used by the query handlers.
    """
    global pdf_qa
    if pdf_doc is None:
        return "Error: Please upload a PDF first."
    if not gemini_api_key or not gemini_api_key.strip():
        return "Error: Please provide a Gemini API key."
    try:
        import uuid

        pdf_path = getattr(pdf_doc, "name", pdf_doc)
        loader = OnlinePDFLoader(pdf_path)
        # NOTE(review): load_and_split() returns splitter chunks, which may not map
        # 1:1 onto physical PDF pages. "page numbers" therefore index chunks — confirm.
        pages = loader.load_and_split()
        update_log(f"Extracted text from {len(pages)} pages in {pdf_path}")
        if not pages:
            update_log(f"No text could be extracted from {pdf_path}")
            return "Error: No text could be extracted from the PDF."

        pages_to_be_loaded = _select_pages(pages, relevant_pages)
        if not pages_to_be_loaded:
            pages_to_be_loaded = list(pages)
            update_log("No specific pages selected; using entire PDF.")

        # Chroma's in-process client may share state between instances. Reusing the
        # default collection name could then mix in chunks from previously loaded
        # PDFs, so each load gets its own collection.
        vectordb = Chroma.from_documents(
            pages_to_be_loaded,
            embedding=_get_embeddings(),
            collection_name=f"pdf-{uuid.uuid4().hex}",
        )
        prompt_template = (
            "Use the following context to answer the question. "
            "If you do not know the answer, return N/A.\n"
            "{context}\n"
            "Question: {question}\n"
            "Return the answer in JSON format."
        )
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        pdf_qa = RetrievalQA.from_chain_type(
            llm=ChatGemini(api_key=gemini_api_key, temperature=0, model_name="gemini-2.0-flash"),
            chain_type="stuff",
            retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=False,
        )
        update_log("PDF embeddings generated and QA chain initialized using Gemini.")
        return "Ready"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        update_log(f"Error in load_pdf_and_generate_embeddings: {e}")
        return f"Error: {e}"
# ------------------------------ | |
# SQLite Question Set Functions | |
# ------------------------------ | |
def create_db_connection():
    """Open and return a connection to the local question-set database.

    The database file is ``./questionset.db``, relative to the current working
    directory. ``check_same_thread=False`` lets the connection be used from
    threads other than the one that opened it.
    """
    return sqlite3.connect(
        "./questionset.db",
        check_same_thread=False,
    )
def create_sqlite_table(connection):
    """Create the ``questions`` table on *connection* if it does not exist yet.

    Existence is checked through ``sqlite_master``. This avoids reading every
    row of the table. It also stops unrelated OperationalErrors (for example a
    locked database) from being mistaken for "table missing".

    Args:
        connection: An open sqlite3 connection. This function commits on it
            but does not close it.
    """
    update_log("Creating/Verifying SQLite table for questions.")
    exists = connection.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name='questions'"
    ).fetchone()
    if exists is None:
        # IF NOT EXISTS guards against another connection creating it in between.
        connection.execute(
            "CREATE TABLE IF NOT EXISTS questions ("
            "document_type TEXT NOT NULL, "
            "questionset_tag TEXT NOT NULL, "
            "field TEXT NOT NULL, "
            "question TEXT NOT NULL)"
        )
        update_log("Questions table created.")
    connection.commit()
def load_master_questionset_into_sqlite(connection):
    """Seed the built-in "masterlist" question sets into *connection* if missing.

    First ensures the ``questions`` table exists. Then, separately for each
    built-in document type (DOC_A, DOC_B), inserts that type's masterlist only
    when no masterlist rows exist for it yet. Finally logs the total row count.
    The caller owns *connection* and is responsible for closing it.
    """
    create_sqlite_table(connection)
    cursor = connection.cursor()
    # Check each document type separately, so a DB that has only one type
    # seeded still receives the other.
    masterlist_builders = (
        ("DOC_A", create_field_and_question_list_for_DOC_A),
        ("DOC_B", create_field_and_question_list_for_DOC_B),
    )
    for document_type, build_masterlist in masterlist_builders:
        (existing,) = cursor.execute(
            "SELECT COUNT(*) FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, "masterlist"),
        ).fetchone()
        if existing:
            continue
        update_log(f"Loading {document_type} masterlist into DB.")
        fields, queries = build_masterlist()
        cursor.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            [(document_type, "masterlist", field, query) for field, query in zip(fields, queries)],
        )
    connection.commit()
    total_questions = cursor.execute("SELECT COUNT(*) FROM questions").fetchone()[0]
    update_log(f"Total questions in DB: {total_questions}")
def create_field_and_question_list_for_DOC_A():
    """Return the built-in DOC_A masterlist as parallel ``(fields, queries)`` lists."""
    pairs = [
        ("Loan Number", "What is the Loan Number?"),
        ("Borrower", "Who is the Borrower?"),
    ]
    field_names = [name for name, _ in pairs]
    question_texts = [text for _, text in pairs]
    return field_names, question_texts
def create_field_and_question_list_for_DOC_B():
    """Return the built-in DOC_B masterlist as parallel ``(fields, queries)`` lists."""
    pairs = (
        ("Property Address", "What is the Property Address?"),
        ("Signed Date", "What is the Signed Date?"),
    )
    field_names, question_texts = (list(column) for column in zip(*pairs))
    return field_names, question_texts
def retrieve_document_type_and_questionsettag_from_sqlite():
    """Build the question-set dropdown from the distinct ``document_type:tag`` pairs in the DB.

    Seeds the built-in masterlists first (via load_master_questionset_into_sqlite),
    then lists every stored pair, ordered by document type and then
    case-insensitively by tag.

    Returns:
        A ``gr.update`` payload setting the dropdown's choices. It selects the
        first choice, or clears the selection when there are none.
    """
    connection = create_db_connection()
    try:
        load_master_questionset_into_sqlite(connection)
        rows = connection.execute(
            "SELECT document_type, questionset_tag FROM questions "
            "ORDER BY document_type, UPPER(questionset_tag)"
        ).fetchall()
    finally:
        connection.close()
    # dict.fromkeys de-duplicates in O(n) while keeping the ORDER BY order.
    choices = list(dict.fromkeys(f"{doc_type}:{tag}" for doc_type, tag in rows))
    for value in choices:
        update_log(f"Found question set: {value}")
    # gr.update works across Gradio versions; gr.Dropdown.update no longer exists in Gradio 4,
    # which this file targets (gr.File(type='filepath')).
    return gr.update(choices=choices, value=choices[0] if choices else None)
def retrieve_fields_and_questions(dropdownoption):
    """Return the fields and questions stored for a ``"<document_type>:<tag>"`` selection.

    Only the first ":" separates document type from tag, so tags that contain
    ":" still match.

    Returns:
        DataFrame with columns ``documentType``, ``field`` and ``question``. It is
        empty when nothing (or a value without ":") is selected.
    """
    columns = ["documentType", "field", "question"]
    if not dropdownoption or ":" not in dropdownoption:
        return pd.DataFrame(columns=columns)
    document_type, questionset_tag = dropdownoption.split(":", 1)
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT document_type, field, question FROM questions "
            "WHERE document_type=? AND questionset_tag=?",
            (document_type, questionset_tag),
        ).fetchall()
    finally:
        connection.close()
    return pd.DataFrame(rows, columns=columns)
def add_questionset(data, document_type, tag_for_questionset):
    """Store every ``field``/``question`` row of *data* under the given type and tag.

    Args:
        data: DataFrame with at least ``field`` and ``question`` columns. A missing
            column raises KeyError before the database is opened.
        document_type: Value written to ``document_type``.
        tag_for_questionset: Value written to ``questionset_tag``.

    All rows are inserted in one transaction. The connection is closed even if
    the insert fails, in which case the uncommitted rows are discarded.
    """
    rows = [
        (document_type, tag_for_questionset, field, question)
        for field, question in zip(data["field"], data["question"])
    ]
    connection = create_db_connection()
    try:
        create_sqlite_table(connection)
        connection.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            rows,
        )
        connection.commit()
    finally:
        connection.close()
def load_csv_and_store_questionset_into_sqlite(csv_file, document_type, tag_for_questionset):
    """Import a CSV of ``field``/``question`` rows as a new question set.

    Args:
        csv_file: Value of the ``gr.File`` component. This is a path string, or an
            object exposing ``.name``, depending on the Gradio version.
        document_type: Selected document type (e.g. "DOC_A").
        tag_for_questionset: Name for the set. Surrounding whitespace is removed,
            and a blank name is rejected.

    Returns:
        A status message for the UI. Read and insert failures are logged and
        reported in this message rather than raised.
    """
    tag = (tag_for_questionset or "").strip()
    if not (tag and document_type):
        return "Please select a Document Type and provide a name for the Question Set"
    if csv_file is None:
        return "Please upload a CSV file with 'field' and 'question' columns"
    csv_path = getattr(csv_file, "name", csv_file)
    try:
        data = pd.read_csv(csv_path)
        add_questionset(data, document_type, tag)
    except Exception as e:
        # UI boundary: a malformed CSV or DB error should surface in the status box.
        err = f"Error uploading question set: {e}"
        update_log(err)
        return err
    response = f"Uploaded {data.shape[0]} fields and questions for {document_type}:{tag}"
    update_log(response)
    return response
def answer_predefined_questions(document_type_and_questionset):
    """Run every question of the selected set through the loaded QA chain.

    Args:
        document_type_and_questionset: ``"<document_type>:<tag>"`` from the
            dropdown. Only the first ":" is treated as the separator.

    Returns:
        DataFrame with ``Field``, ``Question`` and ``Response`` columns, one row per
        stored question. A response is an error string when no PDF has been
        loaded or its query raised. The frame is empty when nothing is selected.
    """
    columns = ["Field", "Question", "Response"]
    if not document_type_and_questionset or ":" not in document_type_and_questionset:
        return pd.DataFrame(columns=columns)
    document_type, question_set = document_type_and_questionset.split(":", 1)
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, question_set),
        ).fetchall()
    finally:
        connection.close()

    not_loaded = "Error: PDF embeddings not generated. Load a PDF first."
    # Same "is a PDF loaded?" test as summarize_contents/answer_query, done once up front.
    qa_chain = globals().get("pdf_qa")
    if qa_chain is None and rows:
        update_log(not_loaded)

    records = []
    for field, question in rows:
        if qa_chain is None:
            response = not_loaded
        else:
            try:
                response = qa_chain.run(question)
            except Exception as e:
                # Keep going: one failing question should not abort the whole set.
                response = f"Error: {e}"
                update_log(response)
        records.append((field, question, response))
    return pd.DataFrame(records, columns=columns)
def summarize_contents():
    """Ask the loaded QA chain for a short summary plus up to three example questions.

    Returns:
        The chain's reply. Returns an error string instead when no PDF has been
        indexed yet (the ``pdf_qa`` global is missing) or when the chain raises.
    """
    if "pdf_qa" in globals():
        summary_prompt = (
            "Generate a short summary of the contents along with up to 3 example questions."
        )
        try:
            summary_text = pdf_qa.run(summary_prompt)
            update_log("Summarization successful.")
        except Exception as exc:
            failure = "Error in summarization: " + str(exc)
            update_log(failure)
            return failure
        return summary_text
    return "Error: PDF embeddings not generated. Load a PDF first."
def answer_query(query):
    """Answer a free-form *query* against the loaded PDF.

    Returns:
        The chain's answer. Returns a help string for a blank or None query
        (the LLM is not called). Returns an error string when no PDF has been
        indexed yet or when the chain raises.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if "pdf_qa" not in globals():
        return "Error: PDF embeddings not generated. Load a PDF first."
    try:
        response = pdf_qa.run(query)
        update_log(f"Query answered: {query}")
        return response
    except Exception as e:
        # UI boundary: show the failure in the answer box instead of raising.
        err = f"Error in answering query: {e}"
        update_log(err)
        return err
def get_log():
    """Return the transcript accumulated by update_log() (newline-terminated messages)."""
    transcript = log_messages
    return transcript
# ------------------------------ | |
# Gradio Interface | |
# ------------------------------ | |
css = """
#col-container {max-width: 700px; margin: auto;}
"""
title = """
<div style="text-align: center;">
<h1>AskMoli - Chatbot for PDFs</h1>
<p>Upload a PDF and generate embeddings. Then ask questions or use a predefined set.</p>
</div>
"""


def extract_text(pdf_file):
    """Return all text extracted from *pdf_file*, for the "Text Extractor" tab.

    *pdf_file* is the ``gr.File`` value. This is a path string, or an object
    exposing ``.name``, depending on the Gradio version.

    ``loader.load()`` is used rather than ``load_and_split()``. The latter
    re-chunks the text, and its chunks may overlap (LangChain's default
    splitter does so — confirm for the installed version), so joining them
    could repeat text at chunk boundaries.

    Errors are logged and returned as a message rather than raised.
    """
    if pdf_file is None:
        return "Error extracting text: no PDF uploaded."
    try:
        pdf_path = getattr(pdf_file, "name", pdf_file)
        docs = OnlinePDFLoader(pdf_path).load()
        text = "\n".join(doc.page_content for doc in docs)
        update_log(f"Extracted text from {len(docs)} document(s) in {pdf_path}.")
        return text
    except Exception as e:
        err = f"Error extracting text: {str(e)}"
        update_log(err)
        return err


with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        with gr.Tab("Chatbot"):
            with gr.Column():
                gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
                pdf_doc = gr.File(label="Load a PDF", file_types=[".pdf"], type="filepath")
                relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
                with gr.Row():
                    status = gr.Textbox(label="Status", interactive=False)
                    load_pdf_btn = gr.Button("Upload PDF & Generate Embeddings")
                with gr.Row():
                    summary = gr.Textbox(label="Summary")
                    summarize_pdf_btn = gr.Button("Summarize Contents")
                with gr.Row():
                    input_query = gr.Textbox(label="Your Question")
                    output_answer = gr.Textbox(label="Answer")
                    submit_query_btn = gr.Button("Submit Question")
                with gr.Row():
                    questionsets = gr.Dropdown(label="Pre-defined Question Sets", choices=[])
                    load_questionsets_btn = gr.Button("Retrieve Sets")
                    fields_and_questions = gr.Dataframe(label="Fields & Questions")
                    load_fields_btn = gr.Button("Retrieve Questions")
                with gr.Row():
                    answers_df = gr.Dataframe(label="Pre-defined Answers")
                    answer_predefined_btn = gr.Button("Get Answers")
                log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)
                # Kept next to the window it refreshes.
                refresh_log_btn = gr.Button("Refresh Log")

        with gr.Tab("Text Extractor"):
            with gr.Column():
                image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=[".pdf"], type="filepath")
                with gr.Row():
                    extracted_text = gr.Textbox(label="Extracted Text", lines=10)
                    extract_btn = gr.Button("Extract Text")

        with gr.Tab("Upload Question Set"):
            with gr.Column():
                document_type_for_questionset = gr.Dropdown(choices=["DOC_A", "DOC_B"], label="Select Document Type")
                tag_for_questionset = gr.Textbox(label="Name for Question Set (e.g., basic-set)")
                csv_file = gr.File(label="Load CSV (fields,question)", file_types=[".csv"], type="filepath")
                with gr.Row():
                    status_for_csv = gr.Textbox(label="Status", interactive=False)
                    load_csv_btn = gr.Button("Upload CSV into DB")

    # Event wiring, done inside the Blocks context.
    extract_btn.click(extract_text, inputs=image_pdf, outputs=extracted_text)
    refresh_log_btn.click(get_log, outputs=log_window)
    load_pdf_btn.click(
        load_pdf_and_generate_embeddings,
        inputs=[pdf_doc, gemini_api_key, relevant_pages],
        outputs=status,
    )
    summarize_pdf_btn.click(summarize_contents, outputs=summary)
    submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
    load_questionsets_btn.click(retrieve_document_type_and_questionsettag_from_sqlite, outputs=questionsets)
    load_fields_btn.click(retrieve_fields_and_questions, inputs=questionsets, outputs=fields_and_questions)
    answer_predefined_btn.click(answer_predefined_questions, inputs=questionsets, outputs=answers_df)
    load_csv_btn.click(
        load_csv_and_store_questionset_into_sqlite,
        inputs=[csv_file, document_type_for_questionset, tag_for_questionset],
        outputs=status_for_csv,
    )

# Launch only when run as a script (e.g. `python app.py`). Importing the module
# (tests, `gradio` reload mode) then does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)