Spaces:
Runtime error
Runtime error
#Import Library | |
from unstructured.partition.pdf import partition_pdf | |
from langchain_openai import ChatOpenAI | |
from langchain_core.messages import SystemMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain.schema.runnable import RunnablePassthrough,RunnableLambda | |
from langchain_postgres.vectorstores import PGVector | |
from database import COLLECTION_NAME, CONNECTION_STRING | |
from langchain_community.storage import RedisStore | |
from langchain.schema.document import Document | |
from langchain_openai import OpenAIEmbeddings | |
from langchain.retrievers.multi_vector import MultiVectorRetriever | |
from pathlib import Path | |
from IPython.display import display, HTML | |
from base64 import b64decode | |
import os, hashlib, shutil, uuid, json, time | |
import torch, redis, streamlit as st | |
import logging | |
import openai | |
# from dotenv import load_dotenv | |
# load_dotenv() | |
openai_api_key = os.getenv("OPENAI_API_KEY") | |
# Ensure PyTorch module path is correctly set | |
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)] | |
# Configure logging | |
logging.basicConfig(level=logging.INFO) | |
# Initialize Redis client | |
# client = redis.Redis(host="localhost", port=6379, db=0) | |
redis_host = os.getenv("REDIS_HOST", "redis-stack") | |
redis_port = int(os.getenv("REDIS_PORT", 6379)) | |
client = redis.Redis(host=redis_host, port=redis_port, db=0) | |
#Data Loading | |
def load_pdf_data(file_path): | |
logging.info(f"Data ready to be partitioned and loaded ") | |
raw_pdf_elements = partition_pdf( | |
filename=file_path, | |
infer_table_structure=True, | |
strategy = "hi_res", | |
extract_image_block_types = ["Image"], | |
extract_image_block_to_payload = True, | |
chunking_strategy="by_title", | |
mode='elements', | |
max_characters=10000, | |
new_after_n_chars=5000, | |
combine_text_under_n_chars=2000, | |
image_output_dir_path="data/", | |
) | |
logging.info(f"Pdf data finish loading, chunks now available!") | |
return raw_pdf_elements | |
# Generate a unique hash for a PDF file | |
def get_pdf_hash(pdf_path): | |
"""Generate a SHA-256 hash of the PDF file content.""" | |
with open(pdf_path, "rb") as f: | |
pdf_bytes = f.read() | |
return hashlib.sha256(pdf_bytes).hexdigest() | |
# Summarize extracted text and tables using LLM | |
def summarize_text_and_tables(text, tables): | |
logging.info("Ready to summarize data with LLM") | |
prompt_text = """You are an assistant tasked with summarizing text and tables. \ | |
You are to give a concise summary of the table or text and do nothing else. | |
Table or text chunk: {element} """ | |
prompt = ChatPromptTemplate.from_template(prompt_text) | |
model = ChatOpenAI(temperature=0.6, model="gpt-4o-mini", openai_api_key=openai_api_key) | |
summarize_chain = {"element": RunnablePassthrough()}| prompt | model | StrOutputParser() | |
logging.info(f"{model} done with summarization") | |
return { | |
"text": summarize_chain.batch(text, {"max_concurrency": 5}), | |
"table": summarize_chain.batch(tables, {"max_concurrency": 5}) | |
} | |
#Initialize a pgvector and retriever for storing and searching documents | |
def initialize_retriever(): | |
store = RedisStore(client=client) | |
id_key = "doc_id" | |
vectorstore = PGVector( | |
embeddings=OpenAIEmbeddings(), | |
collection_name=COLLECTION_NAME, | |
connection=CONNECTION_STRING, | |
use_jsonb=True, | |
) | |
retrieval_loader = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key="doc_id") | |
return retrieval_loader | |
# Store text, tables, and their summaries in the retriever | |
def store_docs_in_retriever(text, text_summary, table, table_summary, retriever): | |
"""Store text and table documents along with their summaries in the retriever.""" | |
def add_documents_to_retriever(documents, summaries, retriever, id_key = "doc_id"): | |
"""Helper function to add documents and their summaries to the retriever.""" | |
if not summaries: | |
return None, [] | |
doc_ids = [str(uuid.uuid4()) for _ in documents] | |
summary_docs = [ | |
Document(page_content=summary, metadata={id_key: doc_ids[i]}) | |
for i, summary in enumerate(summaries) | |
] | |
retriever.vectorstore.add_documents(summary_docs, ids=doc_ids) | |
retriever.docstore.mset(list(zip(doc_ids, documents))) | |
# Add text, table, and image summaries to the retriever | |
add_documents_to_retriever(text, text_summary, retriever) | |
add_documents_to_retriever(table, table_summary, retriever) | |
return retriever | |
# Parse the retriever output | |
def parse_retriver_output(data): | |
parsed_elements = [] | |
for element in data: | |
# Decode bytes to string if necessary | |
if isinstance(element, bytes): | |
element = element.decode("utf-8") | |
parsed_elements.append(element) | |
return parsed_elements | |
# Chat with the LLM using retrieved context | |
def chat_with_llm(retriever): | |
logging.info(f"Context ready to send to LLM ") | |
prompt_text = """ | |
You are an AI Assistant tasked with understanding detailed | |
information from text and tables. You are to answer the question based on the | |
context provided to you. You must not go beyond the context given to you. | |
Context: | |
{context} | |
Question: | |
{question} | |
""" | |
prompt = ChatPromptTemplate.from_template(prompt_text) | |
model = ChatOpenAI(temperature=0.6, model="gpt-4o-mini", openai_api_key=openai_api_key) | |
rag_chain = ({ | |
"context": retriever | RunnableLambda(parse_retriver_output), "question": RunnablePassthrough(), | |
} | |
| prompt | |
| model | |
| StrOutputParser() | |
) | |
logging.info(f"Completed! ") | |
return rag_chain | |
# Generate temporary file path of uploaded docs | |
def _get_file_path(file_upload): | |
temp_dir = "temp" | |
os.makedirs(temp_dir, exist_ok=True) # Ensure the directory exists | |
if isinstance(file_upload, str): | |
file_path = file_upload # Already a string path | |
else: | |
file_path = os.path.join(temp_dir, file_upload.name) | |
with open(file_path, "wb") as f: | |
f.write(file_upload.getbuffer()) | |
return file_path | |
# Process uploaded PDF file | |
def process_pdf(file_upload): | |
print('Processing PDF hash info...') | |
file_path = _get_file_path(file_upload) | |
pdf_hash = get_pdf_hash(file_path) | |
load_retriever = initialize_retriever() | |
existing = client.exists(f"pdf:{pdf_hash}") | |
print(f"Checking Redis for hash {pdf_hash}: {'Exists' if existing else 'Not found'}") | |
if existing: | |
print(f"PDF already exists with hash {pdf_hash}. Skipping upload.") | |
return load_retriever | |
print(f"New PDF detected. Processing... {pdf_hash}") | |
pdf_elements = load_pdf_data(file_path) | |
tables = [element.metadata.text_as_html for element in | |
pdf_elements if 'Table' in str(type(element))] | |
text = [element.text for element in pdf_elements if | |
'CompositeElement' in str(type(element))] | |
summaries = summarize_text_and_tables(text, tables) | |
retriever = store_docs_in_retriever(text, summaries['text'], tables, summaries['table'], load_retriever) | |
# Store the PDF hash in Redis | |
client.set(f"pdf:{pdf_hash}", json.dumps({"text": "PDF processed"})) | |
# Debug: Check if Redis stored the key | |
stored = client.exists(f"pdf:{pdf_hash}") | |
# #remove temp directory | |
# shutil.rmtree("dir") | |
print(f"Stored PDF hash in Redis: {'Success' if stored else 'Failed'}") | |
return retriever | |
#Invoke chat with LLM based on uploaded PDF and user query | |
def invoke_chat(file_upload, message): | |
retriever =process_pdf(file_upload) | |
rag_chain = chat_with_llm(retriever) | |
response = rag_chain.invoke(message) | |
response_placeholder = st.empty() | |
response_placeholder.write(response) | |
return response | |
# Main application interface using Streamlit | |
def main(): | |
st.title("PDF Chat Assistant ") | |
logging.info("App started") | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
file_upload = st.sidebar.file_uploader( | |
label="Upload", type=["pdf"], | |
accept_multiple_files=False, | |
key="pdf_uploader" | |
) | |
if file_upload: | |
st.success("File uploaded successfully! You can now ask your question.") | |
# Prompt for user input | |
if prompt := st.chat_input("Your question"): | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
# Display chat history | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.write(message["content"]) | |
# Generate response if last message is not from assistant | |
if st.session_state.messages and st.session_state.messages[-1]["role"] != "assistant": | |
with st.chat_message("assistant"): | |
start_time = time.time() | |
logging.info("Generating response...") | |
with st.spinner("Processing..."): | |
user_message = " ".join([msg["content"] for msg in st.session_state.messages if msg]) | |
response_message = invoke_chat(file_upload, user_message) | |
duration = time.time() - start_time | |
response_msg_with_duration = f"{response_message}\n\nDuration: {duration:.2f} seconds" | |
st.session_state.messages.append({"role": "assistant", "content": response_msg_with_duration}) | |
st.write(f"Duration: {duration:.2f} seconds") | |
logging.info(f"Response: {response_message}, Duration: {duration:.2f} s") | |
if __name__ == "__main__": | |
main() |