import os os.system("python -m spacy download en_core_web_sm") import io import base64 import streamlit as st import numpy as np import fitz # PyMuPDF import tempfile from ultralytics import YOLO from sklearn.cluster import KMeans from sklearn.metrics.pairwise import cosine_similarity from langchain_core.output_parsers import StrOutputParser from langchain_community.document_loaders import PyMuPDFLoader from langchain_openai import OpenAIEmbeddings from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI import re from PIL import Image openai_api_key = os.environ.get("openai_api_key") # Cached resources @st.cache_resource def load_models(): return { "yolo": YOLO("best.pt"), "embeddings": OpenAIEmbeddings(model="text-embedding-3-small",api_key=openai_api_key), "llm": ChatOpenAI(model="gpt-4-turbo", temperature=0.3,api_key=openai_api_key) } models = load_models() # Constants FIGURE_CLASS_INDEX = 4 TABLE_CLASS_INDEX = 3 CHUNK_SIZE = 1000 CHUNK_OVERLAP = 200 NUM_CLUSTERS = 8 # Utility functions def clean_text(text): return re.sub(r'\s+', ' ', text).strip() def remove_references(text): reference_patterns = [ r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b', r'\bWorks Cited\b' ] return re.sub('|'.join(reference_patterns), '', text, flags=re.IGNORECASE) @st.cache_data def process_pdf(file_path): """Process PDF once and cache results""" loader = PyMuPDFLoader(file_path) docs = loader.load() full_text = "\n".join(doc.page_content for doc in docs) cleaned_text = clean_text(remove_references(full_text)) text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, separators=["\n\n", "\n", ". ", "! ", "? ", " "] ) split_contents = text_splitter.split_text(cleaned_text) return { "text": cleaned_text, "chunks": split_contents, "embeddings": models["embeddings"].embed_documents(split_contents) } @st.cache_data def extract_visuals(file_path): """Extract figures and tables with caching""" doc = fitz.open(file_path) all_figures = [] all_tables = [] for page in doc: low_res_pix = page.get_pixmap(dpi=50) low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3) results = models["yolo"].predict(low_res_img) boxes = [ (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0])) for result in results for box in result.boxes if box.conf[0] > 0.8 and int(box.cls[0]) in {FIGURE_CLASS_INDEX, TABLE_CLASS_INDEX} ] if boxes: high_res_pix = page.get_pixmap(dpi=300) high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3) for x1, y1, x2, y2, cls in boxes: img = high_res_img[int(y1*6):int(y2*6), int(x1*6):int(x2*6)] if cls == FIGURE_CLASS_INDEX: all_figures.append(img) else: all_tables.append(img) return {"figures": all_figures, "tables": all_tables} def generate_summary(chunks, embeddings): """Generate summary using clustered chunks""" kmeans = KMeans(n_clusters=NUM_CLUSTERS, init='k-means++').fit(embeddings) cluster_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_] selected_chunks = [chunks[i] for i in cluster_indices] prompt = ChatPromptTemplate.from_template( """Create a structured summary with key points from these context sections: {contexts} Format: ## Summary [concise overview] ## Key Points - [main point 1] - [main point 2] ...""" ) chain = prompt | models["llm"] | StrOutputParser() return chain.invoke({"contexts": '\n\n'.join(selected_chunks)}) def answer_question(question, chunks, embeddings): """Answer question using semantic search""" query_embedding = models["embeddings"].embed_query(question) similarities = cosine_similarity([query_embedding], embeddings)[0] top_indices = np.argsort(similarities)[-5:][::-1] context = '\n'.join([chunks[i] for i in top_indices if similarities[i] > 0.6]) prompt = ChatPromptTemplate.from_template( """Answer this question: {question} Using only this context: {context} - Be precise and include relevant details - Cite sources as [Source 1], [Source 2], etc.""" ) chain = prompt | models["llm"] | StrOutputParser() return chain.invoke({"question": question, "context": context}) # Streamlit UI #st.set_page_config(page_title="PDF Assistant", layout="wide") st.title("📄 Smart PDF Assistant") if "chat" not in st.session_state: st.session_state.chat = [] if "processed_data" not in st.session_state: st.session_state.processed_data = None # File upload section with st.sidebar: uploaded_file = st.file_uploader("Upload PDF", type="pdf") if uploaded_file: with tempfile.NamedTemporaryFile(delete=False) as tmp: tmp.write(uploaded_file.getbuffer()) st.session_state.processed_data = process_pdf(tmp.name) visuals = extract_visuals(tmp.name) # Chat interface col1, col2 = st.columns([3, 1]) with col1: st.subheader("Document Interaction") for msg in st.session_state.chat: with st.chat_message(msg["role"]): if "image" in msg: st.image(msg["image"], caption=msg.get("caption")) else: st.markdown(msg["content"]) if prompt := st.chat_input("Ask about the document..."): st.session_state.chat.append({"role": "user", "content": prompt}) with st.spinner("Analyzing..."): response = answer_question( prompt, st.session_state.processed_data["chunks"], st.session_state.processed_data["embeddings"] ) st.session_state.chat.append({"role": "assistant", "content": response}) st.rerun() with col2: st.subheader("Document Insights") if st.button("Generate Summary"): with st.spinner("Summarizing..."): summary = generate_summary( st.session_state.processed_data["chunks"], st.session_state.processed_data["embeddings"] ) st.session_state.chat.append({ "role": "assistant", "content": f"## Document Summary\n{summary}" }) st.rerun() if visuals["figures"]: with st.expander(f"📷 Figures ({len(visuals['figures'])})"): for idx, fig in enumerate(visuals["figures"], 1): st.image(fig, caption=f"Figure {idx}") if visuals["tables"]: with st.expander(f"📊 Tables ({len(visuals['tables'])})"): for idx, tbl in enumerate(visuals["tables"], 1): st.image(tbl, caption=f"Table {idx}") # Custom styling st.markdown(""" """, unsafe_allow_html=True)