|
import os |
|
import gradio as gr |
|
from dotenv import load_dotenv |
|
from rag_system import load_retrieval_qa_chain, get_answer, update_embeddings |
|
import json |
|
import re |
|
from PyPDF2 import PdfReader |
|
from PIL import Image |
|
import io |
|
from pydantic_settings import BaseSettings |
|
|
|
|
|
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# Fail fast with a clear message when the key is missing. Assigning ``None`` to
# os.environ would otherwise raise an opaque ``TypeError`` here. When the key is
# present, os.getenv read it from os.environ, so there is nothing to write back.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment or in a .env file."
    )

# Ensure the static asset directory exists (no-op if it already does; avoids the
# check-then-create race of os.path.exists + os.makedirs).
static_directory = "static"
os.makedirs(static_directory, exist_ok=True)
|
|
|
|
|
def get_pdf_page_count(file_path):
    """Return how many pages the PDF at *file_path* contains.

    The file is opened in binary mode, and the page count is taken while the
    handle is still open.
    """
    with open(file_path, 'rb') as pdf_stream:
        page_total = len(PdfReader(pdf_stream).pages)
    return page_total
|
|
|
def render_pdf_page(file_path, page_num, zoom=1.0):
    """Render one page of a PDF as a PIL image.

    Args:
        file_path: Path to the PDF file.
        page_num: 1-based page number to render.
        zoom: Scale factor applied when rasterizing. 1.0 (the default) keeps
            PyMuPDF's default resolution; larger values give bigger, sharper images.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        ValueError: If *page_num* is outside ``1..page_count``. Without this
            check, ``page_num <= 0`` becomes a negative index, which PyMuPDF
            resolves from the end of the document and so shows the wrong page.
    """
    import fitz  # PyMuPDF, imported inside the function

    # The document is a context manager, so its file handle is released as soon
    # as the page has been rasterized.
    with fitz.open(file_path) as doc:
        if not 1 <= page_num <= doc.page_count:
            raise ValueError(
                f"page {page_num} is out of range for {file_path!r} "
                f"({doc.page_count} pages)"
            )
        page = doc.load_page(page_num - 1)  # PyMuPDF pages are 0-based
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        # get_pixmap() yields samples without alpha by default, matching "RGB".
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    return img
|
|
|
|
|
def load_pdf_data(directory="./documents"):
    """Index the PDF files in *directory* for the viewer.

    Args:
        directory: Folder scanned (non-recursively) for regular files whose
            names end in ``.pdf``. Defaults to ``./documents``.

    Returns:
        A dict mapping each PDF file name to
        ``{'path': <file path>, 'num_pages': <page count>}``. Keys are inserted
        in sorted file-name order, so "the first PDF" is deterministic instead of
        depending on the arbitrary order of ``os.listdir``.
    """
    pdf_files = sorted(
        name
        for name in os.listdir(directory)
        # Skip e.g. a directory that happens to be named "something.pdf".
        if name.endswith('.pdf') and os.path.isfile(os.path.join(directory, name))
    )

    pdf_data = {}
    for pdf_file in pdf_files:
        file_path = os.path.join(directory, pdf_file)
        pdf_data[pdf_file] = {
            'path': file_path,
            'num_pages': get_pdf_page_count(file_path),
        }
    return pdf_data
|
|
|
|
|
# Module start-up. The rag_system internals are not visible here, so the
# ordering requirement below is an assumption (TODO confirm):
#   1. update_embeddings(): presumably (re)indexes the documents for retrieval,
#      so it must run before the chain is built.
#   2. load_retrieval_qa_chain(): builds the chain that chat_interface passes
#      to get_answer().
#   3. load_pdf_data(): maps file name -> {'path', 'num_pages'} for the viewer.
update_embeddings()

qa_chain = load_retrieval_qa_chain()

pdf_data = load_pdf_data()
|
|
|
def pdf_viewer_interface(pdf_state, page_number, action=None, page_input=None):
    """Render the selected PDF at the requested page.

    Args:
        pdf_state: Session dict with ``'selected_pdf'`` (a key of ``pdf_data``)
            and ``'page_number'``. ``'page_number'`` is updated in place.
        page_number: The 1-based page currently shown.
        action: ``"prev"`` or ``"next"`` steps one page; any other value is ignored.
        page_input: A page typed by the user. It is used only when *action* is
            not a step. Values that cannot be converted to an int keep the
            current page.

    Returns:
        ``(image, page, page_as_str)`` for the viewer image, the page-number
        state and the page-number box. If the selected PDF is not known (for
        example, the dropdown was cleared), the image is ``None`` and the page
        is returned unchanged.
    """
    selected_pdf = pdf_state.get('selected_pdf')
    if selected_pdf not in pdf_data:
        return None, page_number, str(page_number)

    max_pages = pdf_data[selected_pdf]['num_pages']
    current_page = page_number

    if action == "prev":
        current_page = current_page - 1
    elif action == "next":
        current_page = current_page + 1
    elif page_input is not None:
        try:
            current_page = int(page_input)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric, NaN or infinite input: deliberately keep the current page.
            pass

    # Clamp on every path, not only typed input, so a stale or out-of-range
    # value can never reach the renderer.
    current_page = max(1, min(int(current_page), max_pages))

    pdf_state['page_number'] = current_page
    img = render_pdf_page(pdf_data[selected_pdf]['path'], current_page)
    return img, current_page, str(current_page)
|
|
|
def chat_interface(user_input, chat_history, pdf_state):
    """Answer *user_input* with the RAG chain and record the exchange.

    Args:
        user_input: The question typed by the user.
        chat_history: A list of ``(question, answer)`` pairs, appended to in
            place. The event wiring does not list this state as an output, so
            in-place mutation is presumably what carries history across turns.
        pdf_state: Unused; accepted so the event wiring can pass it.

    Returns:
        ``(chat_history, sources)``: the updated history for the chatbot and
        the source strings returned by ``get_answer``.
    """
    # get_answer receives the history flattened to [q1, a1, q2, a2, ...].
    flat_history = []
    for turn in chat_history:
        flat_history.extend(turn)

    result = get_answer(qa_chain, user_input, flat_history)
    answer_text, answer_sources = result["answer"], result["sources"]

    chat_history.append((user_input, answer_text))
    return chat_history, answer_sources
|
|
|
def handle_source_click(evt: gr.SelectData, sources, pdf_state, page_number):
    """Open the PDF page referenced by the clicked source.

    Args:
        evt: The Gradio selection event. ``evt.index`` is either an int or a
            ``[row, col]`` pair; the row is used to index *sources*.
        sources: Source strings of the form ``"<file name> (Page <n>)"``.
        pdf_state: Session dict. ``'selected_pdf'`` and ``'page_number'`` are
            updated in place when a valid source is clicked.
        page_number: The page currently shown.

    Returns:
        ``(image, pdf_state, page, page_as_str)``. If the click does not map to
        a usable source (out-of-range index, malformed string or unknown file),
        the image is ``None`` and the state and page are returned unchanged.
    """
    index = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
    # No-op result. The last element feeds a numeric gr.Number box, so it is the
    # current page as a string rather than "" (which is not a number).
    unchanged = (None, pdf_state, page_number, str(page_number))

    if not isinstance(index, int) or not 0 <= index < len(sources):
        return unchanged

    # rpartition splits on the LAST " (Page ", so file names containing that
    # text still parse. Strings without the marker are rejected instead of
    # crashing the tuple unpacking.
    file_name, sep, page_str = sources[index].rpartition(' (Page ')
    if not sep:
        return unchanged
    try:
        page = int(page_str.rstrip(')'))
    except ValueError:
        return unchanged

    if file_name not in pdf_data:
        return unchanged

    # Keep a page reported by retrieval inside the document so rendering cannot fail.
    page = max(1, min(page, pdf_data[file_name]['num_pages']))

    pdf_state['selected_pdf'] = file_name
    pdf_state['page_number'] = page
    img = render_pdf_page(pdf_data[file_name]['path'], page)
    return img, pdf_state, page, str(page)
|
|
|
with gr.Blocks() as demo:
    # Per-session state. next(iter(...), None) avoids an IndexError at import
    # time when no PDFs were found.
    initial_pdf = next(iter(pdf_data), None)
    pdf_state = gr.State({'selected_pdf': initial_pdf, 'page_number': 1})
    sources = gr.State([])  # source strings from the last answer: "<file> (Page <n>)"
    page_number = gr.State(1)

    with gr.Row():
        # Left column: chat, question box and the table of cited sources.
        with gr.Column(scale=3):
            chat_history = gr.State([])
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Enter your question...")
            source_list = gr.Dataframe(
                headers=["Source", "Page"],
                datatype=["str", "number"],
                row_count=4,
                col_count=2,
                interactive=False,
                label="Sources"
            )

        # Right column: PDF picker, rendered page and page navigation.
        with gr.Column(scale=2):
            pdf_dropdown = gr.Dropdown(choices=list(pdf_data.keys()), label="Select PDF", value=initial_pdf)
            pdf_viewer = gr.Image(label="PDF Viewer", height=600)
            pdf_page = gr.Number(label="Page Number", value=1)
            with gr.Row():
                prev_button = gr.Button("Previous Page")
                next_button = gr.Button("Next Page")

    def _sources_to_rows(source_strings):
        """Turn ``"<file> (Page <n>)"`` strings into ``[file, page]`` table rows.

        Exactly one row is produced per source, and a malformed entry gets a
        ``None`` page. This keeps a clicked row index aligned with ``sources``.
        """
        rows = []
        for src in source_strings:
            file_name, sep, page_str = src.rpartition(' (Page ')
            if not sep:
                rows.append([src, None])
                continue
            try:
                rows.append([file_name, int(page_str.rstrip(')'))])
            except ValueError:
                rows.append([file_name, None])
        return rows

    # Ask the question, then refresh the sources table from the returned sources.
    user_input.submit(chat_interface, [user_input, chat_history, pdf_state], [chatbot, sources]).then(
        _sources_to_rows,
        inputs=[sources],
        outputs=[source_list]
    )

    # Clicking a source row jumps the viewer to that file and page.
    source_list.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])

    # Switching PDFs resets the state to page 1, then re-renders.
    pdf_dropdown.change(
        lambda x: {'selected_pdf': x, 'page_number': 1},
        inputs=[pdf_dropdown],
        outputs=[pdf_state]
    ).then(
        pdf_viewer_interface,
        inputs=[pdf_state, gr.State(1)],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    prev_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("prev")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    next_button.click(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State("next")],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    # Typing a page number and pressing Enter jumps to that page.
    pdf_page.submit(
        pdf_viewer_interface,
        inputs=[pdf_state, page_number, gr.State(None), pdf_page],
        outputs=[pdf_viewer, page_number, pdf_page]
    )

    # NOTE(review): for a Chatbot, evt.index appears to identify the clicked
    # message rather than a row of `sources`, so handle_source_click would pick
    # sources[<message index>] here. Confirm this wiring is intended.
    chatbot.select(handle_source_click, [sources, pdf_state, page_number], [pdf_viewer, pdf_state, page_number, pdf_page])

    # Render the initially selected PDF when the app loads; otherwise the viewer
    # stays empty until the user interacts with it.
    if initial_pdf is not None:
        demo.load(
            pdf_viewer_interface,
            inputs=[pdf_state, page_number],
            outputs=[pdf_viewer, page_number, pdf_page]
        )
|
|
|
def main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    main()