# PDF Reading Assistant -- Streamlit app.
#
# Purpose: upload a PDF, then (a) summarize it, (b) answer free-text questions
# about it, or (c) extract figure/table crops; results are shown as a chat
# (streamlit_chat.message) backed by st.session_state.chat_history.
#
# Overview of the code below:
#   * Module level: runs `python -m spacy download en_core_web_sm` through
#     os.system, loads a YOLO detector from "best.pt", and reads the OpenAI
#     key from the environment variable "openai_api_key".
#   * clean_text(text): collapses every whitespace run to one space and strips.
#   * remove_references(text): truncates the text at the first line matching
#     any reference-section keyword (case-insensitive); returns it unchanged
#     if nothing matches.
#   * save_uploaded_file(uploaded_file): writes the upload into a
#     NamedTemporaryFile(delete=False) and returns its path.
#   * summarize_pdf(pdf_file_path, num_clusters=10): splits the cleaned text
#     with SpacyTextSplitter(chunk_size=500), embeds the chunks with
#     text-embedding-3-small, KMeans-clusters them, keeps the chunk nearest
#     each centroid, summarizes those chunks with gpt-4o-mini, and passes the
#     result through generate_citations.
#   * qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
#     ranks chunks by cosine similarity to the query embedding and answers
#     from the top `num_clusters` chunks (despite its name, this parameter is
#     a top-k count, not a clustering setting).
#   * generate_citations(text, contents, similarity_threshold=0.6): appends
#     "([Source N](#source-N))" markers to sentences matched to a source chunk
#     and then a "## Sources:" list.
#   * infer_image_and_get_boxes / crop_images_from_boxes / process_pdf:
#     render every page at 50 DPI for YOLO detection; only for pages with
#     detections, re-render at 300 DPI and crop the boxes scaled by 300/50.
#     Class 4 = figure, class 3 = table; detections need confidence > 0.8.
#   * image_to_base64(img): encodes a numpy image array as a base64 PNG str.
#   * on_btn_click(): empties the chat history ("Clear message" button).
#
# NOTE(review): this file appears to have been collapsed onto a few physical
#   lines: statements and `#` comments that were on separate lines are now
#   joined by spaces. As stored it is therefore not valid Python (the first
#   real `#` on a line comments out the rest of that line). Restore the
#   original line structure (e.g. from version control) before running.
# NOTE(review): some content also looks lost, most likely whatever was
#   inside `<...>` (HTML tags / regex lookbehinds):
#     - generate_citations: `re.split(r'(?= similarity_threshold: ...` is
#       truncated. The sentence-split pattern, the embedding and
#       similarity_matrix computation, and the loop header are missing.
#       The initialisation of cited_text, source_mapping, relevant_sources
#       and sentence_to_source is missing too. The f-string that builds each
#       source entry has probably lost its anchor markup as well (the
#       citations link to #source-N).
#     - `uploadercss` and the final st.markdown("""...""") payload are empty.
#     - `result_html = f'Figure {idx+1}'` / `f'Table {idx+1}'` never use
#       figure_base64 / table_base64; presumably an <img> tag was stripped.
#   Do not reconstruct these by guesswork -- recover them from the original.
# NOTE(review): other issues to address once the file is restored:
#     - Streamlit re-executes the whole script on every interaction, and
#       st.rerun() is called after each action. As a result, several steps
#       repeat on each rerun:
#         * the spacy download;
#         * the YOLO load;
#         * save_uploaded_file, which leaves a new, never-deleted temp file
#           each time.
#       Consider st.cache_resource, and clean up the temp files.
#     - remove_references matches the bare words "reference"/"references"
#       anywhere in a line. Ordinary body text such as "with reference to"
#       therefore truncates the document early. The upper/lower-case
#       duplicate patterns are redundant under re.IGNORECASE.
#     - summarize_pdf: scikit-learn KMeans raises ValueError when the
#       document yields fewer chunks than num_clusters (10 by default).
#     - process_pdf never closes the fitz document it opens.
import os os.system("python -m spacy download en_core_web_sm") import io import base64 import streamlit as st import numpy as np import fitz # PyMuPDF import tempfile from ultralytics import YOLO from sklearn.cluster import KMeans from sklearn.metrics.pairwise import cosine_similarity from langchain_core.output_parsers import StrOutputParser from langchain_community.document_loaders import PyMuPDFLoader from langchain_openai import OpenAIEmbeddings from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import SpacyTextSplitter from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI import re from PIL import Image from streamlit_chat import message # Load the trained model model = YOLO("best.pt") openai_api_key = os.environ.get("openai_api_key") # Define the class indices for figures, tables, and text figure_class_index = 4 table_class_index = 3 # Utility functions def clean_text(text): return re.sub(r'\s+', ' ', text).strip() def remove_references(text): reference_patterns = [ r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b', r'\bWorks Cited\b', r'\bReference\b', r'\breference\b' ] lines = text.split('\n') for i, line in enumerate(lines): if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns): return '\n'.join(lines[:i]) return text def save_uploaded_file(uploaded_file): temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file.write(uploaded_file.getbuffer()) temp_file.close() return temp_file.name def summarize_pdf(pdf_file_path, num_clusters=10): embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key) llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3) prompt = ChatPromptTemplate.from_template( """Could you please provide a concise and comprehensive summary of the given Contexts? 
The summary should capture the main points and key details of the text while conveying the author's intended meaning accurately. Please ensure that the summary is well-organized and easy to read, with clear headings and subheadings to guide the reader through each section. The length of the summary should be appropriate to capture the main points and key details of the text, without including unnecessary information or becoming overly long. example of summary: ## Summary: ## Key points: Contexts: {topic}""" ) output_parser = StrOutputParser() chain = prompt | llm | output_parser loader = PyMuPDFLoader(pdf_file_path) docs = loader.load() full_text = "\n".join(doc.page_content for doc in docs) cleaned_full_text = clean_text(remove_references(full_text)) text_splitter = SpacyTextSplitter(chunk_size=500) #text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "]) split_contents = text_splitter.split_text(cleaned_full_text) embeddings = embeddings_model.embed_documents(split_contents) kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings) closest_point_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_] extracted_contents = [split_contents[idx] for idx in closest_point_indices] results = chain.invoke({"topic": ' '.join(extracted_contents)}) return generate_citations(results, extracted_contents) def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6): embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key) llm = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key, temperature=0.3) prompt = ChatPromptTemplate.from_template( """Please provide a detailed and accurate answer to the given question based on the provided contexts. Ensure that the answer is comprehensive and directly addresses the query. If necessary, include relevant examples or details from the text. 
Question: {question} Contexts: {contexts}""" ) output_parser = StrOutputParser() chain = prompt | llm | output_parser loader = PyMuPDFLoader(pdf_file_path) docs = loader.load() full_text = "\n".join(doc.page_content for doc in docs) cleaned_full_text = clean_text(remove_references(full_text)) text_splitter = SpacyTextSplitter(chunk_size=500) #text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "]) split_contents = text_splitter.split_text(cleaned_full_text) embeddings = embeddings_model.embed_documents(split_contents) query_embedding = embeddings_model.embed_query(query) similarity_scores = cosine_similarity([query_embedding], embeddings)[0] top_indices = np.argsort(similarity_scores)[-num_clusters:] relevant_contents = [split_contents[i] for i in top_indices] results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)}) return generate_citations(results, relevant_contents, similarity_threshold) def generate_citations(text, contents, similarity_threshold=0.6): embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key) text_sentences = re.split(r'(?= similarity_threshold: most_similar_idx = np.argmax(similarity_matrix[i]) if most_similar_idx not in source_mapping: source_mapping[most_similar_idx] = len(relevant_sources) + 1 relevant_sources.append((most_similar_idx, contents[most_similar_idx])) citation_idx = source_mapping[most_similar_idx] citation = f"([Source {citation_idx}](#source-{citation_idx}))" cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence) sentence_to_source[sentence] = citation_idx cited_text = cited_text.replace(sentence, cited_sentence) sources_list = "\n\n## Sources:\n" for idx, (original_idx, content) in enumerate(relevant_sources): sources_list += f"""
Source {idx + 1}
{content}
""" # Add dummy blanks after the last source dummy_blanks = """
""" cited_text += sources_list + dummy_blanks return cited_text def infer_image_and_get_boxes(image, confidence_threshold=0.8): results = model.predict(image) return [ (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0])) for result in results for box in result.boxes if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold ] def crop_images_from_boxes(image, boxes, scale_factor): figures = [] tables = [] for (x1, y1, x2, y2, cls) in boxes: cropped_img = image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)] if cls == figure_class_index: figures.append(cropped_img) elif cls == table_class_index: tables.append(cropped_img) return figures, tables def process_pdf(pdf_file_path): doc = fitz.open(pdf_file_path) all_figures = [] all_tables = [] low_dpi = 50 high_dpi = 300 scale_factor = high_dpi / low_dpi low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc] for page_num, low_res_pix in enumerate(low_res_pixmaps): low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3) boxes = infer_image_and_get_boxes(low_res_img) if boxes: high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi) high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3) figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor) all_figures.extend(figures) all_tables.extend(tables) return all_figures, all_tables def image_to_base64(img): buffered = io.BytesIO() img = Image.fromarray(img) img.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode() def on_btn_click(): del st.session_state.chat_history[:] # Streamlit interface # Custom CSS for the file uploader uploadercss=''' ''' st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄") # Initialize chat history in session state if not 
already present if 'chat_history' not in st.session_state: st.session_state.chat_history = [] st.title("📄 PDF Reading Assistant") st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.") chat_placeholder = st.empty() # File uploader for PDF uploaded_file = st.file_uploader("Upload a PDF", type="pdf") st.markdown(uploadercss, unsafe_allow_html=True) if uploaded_file: file_path = save_uploaded_file(uploaded_file) # Chat container where all messages will be displayed chat_container = st.container() user_input = st.chat_input("Ask a question about the pdf......", key="user_input") with chat_container: # Scrollable chat messages for idx, chat in enumerate(st.session_state.chat_history): if chat.get("user"): message(chat["user"], is_user=True, allow_html=True, key=f"user_{idx}", avatar_style="initials", seed="user") if chat.get("bot"): message(chat["bot"], is_user=False, allow_html=True, key=f"bot_{idx}",seed="bot") # Input area and buttons for user interaction with st.form(key="chat_form", clear_on_submit=True,border=False): col1, col2, col3 = st.columns([1, 1, 1]) with col1: summary_button = st.form_submit_button("Generate Summary") with col2: extract_button = st.form_submit_button("Extract Tables and Figures") with col3: st.form_submit_button("Clear message", on_click=on_btn_click) # Handle responses based on user input and button presses if summary_button: with st.spinner("Generating summary..."): summary = summarize_pdf(file_path) st.session_state.chat_history.append({"user": "Generate Summary", "bot": summary}) st.rerun() if extract_button: with st.spinner("Extracting tables and figures..."): figures, tables = process_pdf(file_path) if figures: st.session_state.chat_history.append({"user": "Figures"}) for idx, figure in enumerate(figures): figure_base64 = image_to_base64(figure) result_html = f'Figure {idx+1}' st.session_state.chat_history.append({"bot": f"Figure {idx+1} {result_html}"}) if tables: 
st.session_state.chat_history.append({"user": "Tables"}) for idx, table in enumerate(tables): table_base64 = image_to_base64(table) result_html = f'Table {idx+1}' st.session_state.chat_history.append({"bot": f"Table {idx+1} {result_html}"}) st.rerun() if user_input: st.session_state.chat_history.append({"user": user_input, "bot": None}) with st.spinner("Processing..."): answer = qa_pdf(file_path, user_input) st.session_state.chat_history[-1]["bot"] = answer st.rerun() # Additional CSS and JavaScript to ensure the chat container is scrollable and scrolls to the bottom st.markdown(""" """, unsafe_allow_html=True)