# ChromaDB_HNM / app.py
# Author: Omarrran
# Last change: "Update app.py" (commit 203d168, verified)
import logging
import os
import sqlite3
import time
import uuid

import gradio as gr
import pandas as pd
import requests # for HTTP calls to Gemini
from langchain.chains import RetrievalQA # for QA chain
from langchain.document_loaders import OnlinePDFLoader # for loading PDF text
from langchain.embeddings import HuggingFaceEmbeddings # open source embedding model
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma # vectorization from langchain_community
from langchain_core.prompts import PromptTemplate # prompt template import
# ------------------------------
# Gemini API Wrapper
# ------------------------------
class ChatGemini:
    """Minimal client for the Gemini ``generateContent`` REST endpoint.

    Instances are callable (``llm(prompt) -> str``) so they can be used where a
    plain prompt-to-text function is expected.

    NOTE(review): ``RetrievalQA.from_chain_type`` normally expects a LangChain
    language-model object; confirm the installed LangChain version accepts this
    plain class (it may need to subclass LangChain's ``LLM`` base class).
    """

    API_URL_TEMPLATE = "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    NO_OUTPUT = "No output from Gemini API"

    def __init__(self, api_key, temperature=0, model_name="gemini-2.0-flash", timeout=60):
        """Store connection settings.

        Args:
            api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
            temperature: Sampling temperature forwarded in ``generationConfig``.
            model_name: Gemini model identifier used in the endpoint path.
            timeout: Seconds to wait for the HTTP response before giving up.
        """
        self.api_key = api_key
        self.temperature = temperature
        self.model_name = model_name
        self.timeout = timeout

    def _build_payload(self, prompt):
        """Return the JSON request body for a single-turn text *prompt*."""
        return {
            "contents": [{"parts": [{"text": prompt}]}],
            # Forward the configured temperature; without generationConfig the API
            # would use its own default and ignore ``self.temperature``.
            "generationConfig": {"temperature": self.temperature},
        }

    @staticmethod
    def _extract_text(data):
        """Return the text of the first candidate in a generateContent response.

        Gemini returns ``candidates[0].content.parts[*].text``. All text parts are
        concatenated. ``NO_OUTPUT`` is returned when there is no candidate or no
        text, e.g. for a safety-blocked response.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            return ChatGemini.NO_OUTPUT
        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = "".join(part.get("text", "") for part in parts)
        return text or ChatGemini.NO_OUTPUT

    def generate(self, prompt):
        """Send *prompt* to Gemini and return the generated text.

        Raises:
            Exception: if the API answers with a non-200 status code.
            requests.RequestException: on network failure or timeout.
        """
        url = self.API_URL_TEMPLATE.format(model=self.model_name)
        # Send the key in a header rather than the query string so it does not
        # appear in URLs that end up in exception messages and logs.
        headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}
        response = requests.post(
            url, json=self._build_payload(prompt), headers=headers, timeout=self.timeout
        )
        if response.status_code != 200:
            raise Exception(f"Gemini API error: {response.status_code} - {response.text}")
        return self._extract_text(response.json())

    def __call__(self, prompt, **kwargs):
        """Alias for :meth:`generate`; extra keyword arguments are ignored."""
        return self.generate(prompt)
# ------------------------------
# Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Every message passed to update_log accumulates here; get_log returns it for the UI.
log_messages = ""


def update_log(message):
    """Record *message*: append it (newline-terminated) to log_messages and log it at INFO."""
    global log_messages
    log_messages = "".join((log_messages, message, "\n"))
    logger.info(message)
# ------------------------------
# PDF Embedding & QA Chain (No OCR)
# ------------------------------
def load_pdf_and_generate_embeddings(pdf_doc, gemini_api_key, relevant_pages):
    """Index an uploaded PDF in Chroma and (re)build the global ``pdf_qa`` chain.

    Args:
        pdf_doc: Upload from ``gr.File``: a filesystem path (str) or a
            tempfile-like object exposing ``.name``.
        gemini_api_key: Key passed to the ``ChatGemini`` LLM wrapper.
        relevant_pages: Optional comma-separated, 1-based indices (e.g. "1, 3")
            of the loaded documents to index. Non-numeric, out-of-range and
            repeated entries are ignored. If nothing valid remains, all loaded
            documents are indexed.

    Returns:
        "Ready" on success, otherwise a message starting with "Error: ".
    """
    global pdf_qa
    # Validate inputs before doing any work so the user gets an actionable message.
    if pdf_doc is None:
        return "Error: Please upload a PDF first."
    if not gemini_api_key or not gemini_api_key.strip():
        return "Error: Please provide a Gemini API key."
    # Broad catch: this is a UI boundary; every failure is logged and shown as status.
    try:
        # gr.File(type='filepath') yields a str; other modes yield an object with
        # a .name attribute. Accept either.
        pdf_path = getattr(pdf_doc, "name", pdf_doc)
        loader = OnlinePDFLoader(pdf_path)
        # NOTE(review): load_and_split() returns split text chunks. These may not map
        # 1:1 onto physical PDF pages, so "page numbers" below index these chunks.
        pages = loader.load_and_split()
        update_log(f"Extracted text from {len(pages)} pages in {pdf_path}")
        if not pages:
            err = "Error: No extractable text found in the PDF."
            update_log(err)
            return err
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        selected = _parse_page_selection(relevant_pages, len(pages))
        if selected:
            pages_to_be_loaded = [pages[index] for index in selected]
        else:
            pages_to_be_loaded = list(pages)
            update_log("No specific pages selected; using entire PDF.")
        # Use a fresh collection per upload. In-memory Chroma clients may share state
        # within a process, and reusing the default collection could mix in chunks
        # from previously uploaded PDFs.
        vectordb = Chroma.from_documents(
            pages_to_be_loaded,
            embedding=embeddings,
            collection_name=f"pdf-{uuid.uuid4().hex}",
        )
        prompt_template = (
            """Use the following context to answer the question. If you do not know the answer, return N/A.
{context}
Question: {question}
Return the answer in JSON format."""
        )
        PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        chain_type_kwargs = {"prompt": PROMPT}
        pdf_qa = RetrievalQA.from_chain_type(
            llm=ChatGemini(api_key=gemini_api_key, temperature=0, model_name="gemini-2.0-flash"),
            chain_type="stuff",
            retriever=vectordb.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=False,
        )
        update_log("PDF embeddings generated and QA chain initialized using Gemini.")
        return "Ready"
    except Exception as e:
        update_log(f"Error in load_pdf_and_generate_embeddings: {str(e)}")
        return f"Error: {str(e)}"


def _parse_page_selection(relevant_pages, page_count):
    """Return unique 0-based indices for the valid 1-based numbers in *relevant_pages*, in input order."""
    selected = []
    seen = set()
    for token in (relevant_pages or "").split(","):
        token = token.strip()
        # isdecimal rather than isdigit: characters such as "²" pass isdigit()
        # but make int() raise.
        if not token.isdecimal():
            continue
        index = int(token) - 1
        if 0 <= index < page_count and index not in seen:
            seen.add(index)
            selected.append(index)
    return selected
# ------------------------------
# SQLite Question Set Functions
# ------------------------------
def create_db_connection():
    """Open a connection to the local question-set database ``./questionset.db``.

    ``check_same_thread=False`` lets the connection be used from a thread other
    than the one that created it (Gradio may run handlers on worker threads).
    Callers are responsible for closing the connection.
    """
    return sqlite3.connect("./questionset.db", check_same_thread=False)
def create_sqlite_table(connection):
    """Make sure the ``questions`` table exists on *connection*, then commit.

    Each row holds one question: its document type, the question-set tag it
    belongs to, the field it targets, and the question text.
    """
    update_log("Creating/Verifying SQLite table for questions.")
    cursor = connection.cursor()
    # Query the schema catalogue rather than reading every row of the table. This
    # also avoids misreading an unrelated OperationalError (e.g. a locked
    # database) as "table missing".
    already_exists = cursor.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name='questions'"
    ).fetchone() is not None
    if not already_exists:
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS questions ("
            "document_type TEXT NOT NULL, questionset_tag TEXT NOT NULL, "
            "field TEXT NOT NULL, question TEXT NOT NULL)"
        )
        update_log("Questions table created.")
    connection.commit()
def load_master_questionset_into_sqlite(connection):
    """Ensure the questions table exists and seed the built-in "masterlist" sets.

    Each document type's masterlist is inserted only when that type has no
    masterlist rows yet, so repeated calls do not duplicate questions and a
    missing set is seeded even if the other one already exists.

    Args:
        connection: Open sqlite3 connection. It is committed but not closed here.
    """
    create_sqlite_table(connection)
    cursor = connection.cursor()
    seed_sources = (
        ("DOC_A", create_field_and_question_list_for_DOC_A),
        ("DOC_B", create_field_and_question_list_for_DOC_B),
    )
    for document_type, build_questions in seed_sources:
        existing = cursor.execute(
            "SELECT COUNT(*) FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, "masterlist"),
        ).fetchone()[0]
        if existing:
            continue
        update_log(f"Loading {document_type} masterlist into DB.")
        fields, queries = build_questions()
        cursor.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            [(document_type, "masterlist", field, question) for field, question in zip(fields, queries)],
        )
    connection.commit()
    total_questions = cursor.execute("SELECT COUNT(*) FROM questions").fetchone()[0]
    update_log(f"Total questions in DB: {total_questions}")
def create_field_and_question_list_for_DOC_A():
    """Return the built-in ``(fields, questions)`` masterlist for DOC_A.

    The two lists are parallel: ``questions[i]`` asks for ``fields[i]``.
    """
    pairs = [
        ("Loan Number", "What is the Loan Number?"),
        ("Borrower", "Who is the Borrower?"),
    ]
    field_names = [name for name, _ in pairs]
    question_texts = [text for _, text in pairs]
    return field_names, question_texts
def create_field_and_question_list_for_DOC_B():
    """Return the built-in ``(fields, questions)`` masterlist for DOC_B.

    The two lists are parallel: ``questions[i]`` asks for ``fields[i]``.
    """
    pairs = [
        ("Property Address", "What is the Property Address?"),
        ("Signed Date", "What is the Signed Date?"),
    ]
    field_names = [name for name, _ in pairs]
    question_texts = [text for _, text in pairs]
    return field_names, question_texts
def retrieve_document_type_and_questionsettag_from_sqlite():
    """Return a dropdown update listing every "document_type:questionset_tag" pair.

    The built-in masterlist sets are seeded first (via
    load_master_questionset_into_sqlite). Choices are unique and ordered by
    document type, then case-insensitively by tag. The first choice is
    pre-selected, or nothing when there are no sets.
    """
    connection = create_db_connection()
    try:
        load_master_questionset_into_sqlite(connection)
        rows = connection.execute(
            "SELECT document_type, questionset_tag FROM questions "
            "ORDER BY document_type, UPPER(questionset_tag)"
        ).fetchall()
    finally:
        connection.close()
    # dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
    choices = list(dict.fromkeys(f"{doc_type}:{tag}" for doc_type, tag in rows))
    for choice in choices:
        update_log(f"Found question set: {choice}")
    # gr.update works on Gradio 3.x and 4.x; gr.Dropdown.update was removed in 4.0.
    return gr.update(choices=choices, value=choices[0] if choices else None)
def retrieve_fields_and_questions(dropdownoption):
    """Return the fields and questions of the selected question set.

    Args:
        dropdownoption: "document_type:questionset_tag" string from the dropdown,
            or None/"" when nothing is selected.

    Returns:
        DataFrame with columns documentType, field, question. It is empty when
        no valid set is selected.
    """
    columns = ["documentType", "field", "question"]
    if not dropdownoption or ":" not in dropdownoption:
        return pd.DataFrame([], columns=columns)
    # Split on the first ':' only: user-defined tags may themselves contain ':'.
    document_type, questionset_tag = dropdownoption.split(":", 1)
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT document_type, field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, questionset_tag),
        ).fetchall()
    finally:
        connection.close()
    return pd.DataFrame(rows, columns=columns)
def add_questionset(data, document_type, tag_for_questionset):
    """Store every row of *data* as a question under (document_type, tag).

    Args:
        data: DataFrame with at least ``field`` and ``question`` columns.
        document_type: Document type the questions belong to (e.g. "DOC_A").
        tag_for_questionset: Name of the question set.

    All rows are inserted with one ``executemany`` and committed together. The
    connection is always closed, even if an insert fails.
    """
    # Build the rows before opening the DB. A missing column then fails fast
    # with KeyError, without touching the database.
    records = [
        (document_type, tag_for_questionset, field, question)
        for field, question in zip(data["field"], data["question"])
    ]
    connection = create_db_connection()
    try:
        create_sqlite_table(connection)
        connection.executemany(
            "INSERT INTO questions(document_type, questionset_tag, field, question) VALUES(?,?,?,?)",
            records,
        )
        connection.commit()
    finally:
        connection.close()
def load_csv_and_store_questionset_into_sqlite(csv_file, document_type, tag_for_questionset):
    """Read an uploaded CSV of questions and store it as a new question set.

    Args:
        csv_file: Upload from ``gr.File``: a str path or an object with
            ``.name``. It must contain ``field`` and ``question`` columns.
        document_type: Target document type selected in the UI.
        tag_for_questionset: Name for the new question set.

    Returns:
        A human-readable status message for the UI.
    """
    if not (tag_for_questionset and document_type):
        return "Please select a Document Type and provide a name for the Question Set"
    if csv_file is None:
        return "Please upload a CSV file with 'field' and 'question' columns."
    # gr.File(type='filepath') yields a str; other modes yield an object with .name.
    csv_path = getattr(csv_file, "name", csv_file)
    try:
        data = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:  # pandas parse errors subclass ValueError
        err = f"Error reading CSV: {e}"
        update_log(err)
        return err
    missing = [column for column in ("field", "question") if column not in data.columns]
    if missing:
        return f"CSV is missing required column(s): {', '.join(missing)}"
    add_questionset(data, document_type, tag_for_questionset)
    response = f"Uploaded {data.shape[0]} fields and questions for {document_type}:{tag_for_questionset}"
    update_log(response)
    return response
def answer_predefined_questions(document_type_and_questionset):
    """Run every question of the selected set through the global QA chain.

    Args:
        document_type_and_questionset: "document_type:questionset_tag" from the
            dropdown, or None/"" when nothing is selected.

    Returns:
        DataFrame with columns Field, Question, Response. Per-question failures
        appear in the Response column instead of aborting the whole run.
    """
    columns = ["Field", "Question", "Response"]
    if not document_type_and_questionset or ":" not in document_type_and_questionset:
        return pd.DataFrame(columns=columns)
    # Split on the first ':' only: user-defined tags may themselves contain ':'.
    document_type, question_set = document_type_and_questionset.split(":", 1)
    connection = create_db_connection()
    try:
        rows = connection.execute(
            "SELECT field, question FROM questions WHERE document_type=? AND questionset_tag=?",
            (document_type, question_set),
        ).fetchall()
    finally:
        connection.close()

    # Same guard as summarize_contents/answer_query: pdf_qa only exists after a PDF load.
    qa_chain = globals().get("pdf_qa")
    not_loaded = "Error: PDF embeddings not generated. Load a PDF first."
    if qa_chain is None and rows:
        update_log(not_loaded)

    records = []
    for field, question in rows:
        if qa_chain is None:
            response = not_loaded
        else:
            try:
                response = qa_chain.run(question)
            except Exception as e:  # report per question; keep answering the rest
                response = f"Error: {str(e)}"
                update_log(response)
        records.append((field, question, response))
    return pd.DataFrame(records, columns=columns)
def summarize_contents():
    """Ask the loaded QA chain for a short summary plus up to three example questions.

    Returns the chain's answer. If no PDF has been indexed (no global ``pdf_qa``)
    or the chain call fails, returns an error string instead; failures are also
    logged.
    """
    if "pdf_qa" in globals():
        summary_prompt = "Generate a short summary of the contents along with up to 3 example questions."
        try:
            answer = pdf_qa.run(summary_prompt)
            update_log("Summarization successful.")
            return answer
        except Exception as exc:
            failure = f"Error in summarization: {str(exc)}"
            update_log(failure)
            return failure
    return "Error: PDF embeddings not generated. Load a PDF first."
def answer_query(query):
    """Answer a free-form *query* with the loaded QA chain.

    Returns the chain's answer. If no PDF has been indexed (no global ``pdf_qa``)
    or the chain call fails, returns an error string instead; failures are also
    logged.
    """
    if "pdf_qa" in globals():
        try:
            answer = pdf_qa.run(query)
            update_log(f"Query answered: {query}")
            return answer
        except Exception as exc:
            failure = f"Error in answering query: {str(exc)}"
            update_log(failure)
            return failure
    return "Error: PDF embeddings not generated. Load a PDF first."
def get_log():
    """Return all messages collected by update_log so far, each newline-terminated."""
    collected = log_messages
    return collected
# ------------------------------
# Gradio Interface
# ------------------------------
# Page CSS: centre the column with elem_id "col-container" and cap its width.
css = "\n" "#col-container {max-width: 700px; margin: auto;}" "\n"

# HTML banner rendered at the top of the app.
title = "\n".join(
    [
        "",
        '<div style="text-align: center;">',
        "<h1>AskMoli - Chatbot for PDFs</h1>",
        "<p>Upload a PDF and generate embeddings. Then ask questions or use a predefined set.</p>",
        "</div>",
        "",
    ]
)
# Gradio UI: three tabs (PDF chatbot, plain text extraction, question-set upload)
# plus a shared log refresh button. Event handlers are wired at the end.
with gr.Blocks(css=css, theme=gr.themes.Monochrome()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        with gr.Tab("Chatbot"):
            with gr.Column():
                gemini_api_key = gr.Textbox(label="Your Gemini API Key", type="password")
                pdf_doc = gr.File(label="Load a PDF", file_types=['.pdf'], type='filepath')
                relevant_pages = gr.Textbox(label="Optional: Comma separated page numbers")
                with gr.Row():
                    status = gr.Textbox(label="Status", interactive=False)
                    load_pdf_btn = gr.Button("Upload PDF & Generate Embeddings")
                with gr.Row():
                    summary = gr.Textbox(label="Summary")
                    summarize_pdf_btn = gr.Button("Summarize Contents")
                with gr.Row():
                    input_query = gr.Textbox(label="Your Question")
                    output_answer = gr.Textbox(label="Answer")
                    submit_query_btn = gr.Button("Submit Question")
                with gr.Row():
                    questionsets = gr.Dropdown(label="Pre-defined Question Sets", choices=[])
                    load_questionsets_btn = gr.Button("Retrieve Sets")
                fields_and_questions = gr.Dataframe(label="Fields & Questions")
                load_fields_btn = gr.Button("Retrieve Questions")
                with gr.Row():
                    answers_df = gr.Dataframe(label="Pre-defined Answers")
                    answer_predefined_btn = gr.Button("Get Answers")
                log_window = gr.Textbox(label="Log Window", interactive=False, lines=10)

        with gr.Tab("Text Extractor"):
            with gr.Column():
                image_pdf = gr.File(label="Load PDF for Text Extraction", file_types=['.pdf'], type='filepath')
                with gr.Row():
                    extracted_text = gr.Textbox(label="Extracted Text", lines=10)
                    extract_btn = gr.Button("Extract Text")

            def extract_text(pdf_file):
                """Return the concatenated text of an uploaded PDF, or an error message."""
                if pdf_file is None:
                    return "Error extracting text: no PDF uploaded."
                try:
                    # gr.File(type='filepath') yields a str; other modes yield an
                    # object with .name. Accept either.
                    loader = OnlinePDFLoader(getattr(pdf_file, "name", pdf_file))
                    docs = loader.load_and_split()
                    text = "\n".join(doc.page_content for doc in docs)
                    update_log(f"Extracted text from {len(docs)} pages.")
                    return text
                except Exception as e:
                    err = f"Error extracting text: {str(e)}"
                    update_log(err)
                    return err

            extract_btn.click(extract_text, inputs=image_pdf, outputs=extracted_text)

        with gr.Tab("Upload Question Set"):
            with gr.Column():
                document_type_for_questionset = gr.Dropdown(choices=["DOC_A", "DOC_B"], label="Select Document Type")
                tag_for_questionset = gr.Textbox(label="Name for Question Set (e.g., basic-set)")
                # The loader reads columns named exactly "field" and "question".
                csv_file = gr.File(label="Load CSV (columns: field,question)", file_types=['.csv'], type='filepath')
                with gr.Row():
                    status_for_csv = gr.Textbox(label="Status", interactive=False)
                    load_csv_btn = gr.Button("Upload CSV into DB")

        refresh_log_btn = gr.Button("Refresh Log")

    refresh_log_btn.click(get_log, outputs=log_window)
    load_pdf_btn.click(load_pdf_and_generate_embeddings, inputs=[pdf_doc, gemini_api_key, relevant_pages], outputs=status)
    summarize_pdf_btn.click(summarize_contents, outputs=summary)
    submit_query_btn.click(answer_query, inputs=input_query, outputs=output_answer)
    load_questionsets_btn.click(retrieve_document_type_and_questionsettag_from_sqlite, outputs=questionsets)
    load_fields_btn.click(retrieve_fields_and_questions, inputs=questionsets, outputs=fields_and_questions)
    answer_predefined_btn.click(answer_predefined_questions, inputs=questionsets, outputs=answers_df)
    load_csv_btn.click(load_csv_and_store_questionset_into_sqlite, inputs=[csv_file, document_type_for_questionset, tag_for_questionset], outputs=status_for_csv)

# Launch only when run as a script, so importing this module (tests, `gradio app.py`
# reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)