Spaces:
Sleeping
Sleeping
import streamlit as st | |
import tempfile | |
import os | |
import logging | |
import subprocess | |
from typing import List | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.schema import Document | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.prompts import PromptTemplate | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.runnables import RunnableMap, RunnableLambda | |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer | |
# Module-wide logging: INFO and above, one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Where the FAISS index is persisted, and which models are used by default.
DB_FAISS_PATH = 'vectorstore/db_faiss'
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
DEFAULT_MODEL = "google/flan-t5-large"  # Replace with your preferred Hugging Face model

# Initial values for the "Advanced Model Parameters" sidebar widgets;
# main() reads them through get_default_value().
DEFAULT_PARAMS = dict(
    temperature=0.7,
    top_p=1.0,
    num_ctx=4096,
    repeat_penalty=1.1,
)
def get_default_value(param_name: str, default: float) -> float:
    """Return the DEFAULT_PARAMS entry for ``param_name`` as a float.

    Args:
        param_name: Key to look up in DEFAULT_PARAMS.
        default: Value used when the key is absent or maps to an empty list.

    Returns:
        The stored value converted with ``float()``. A list-valued entry
        contributes its first element; an empty list yields ``default``
        unchanged.
    """
    raw = DEFAULT_PARAMS.get(param_name, default)
    if not isinstance(raw, list):
        return float(raw)
    if raw:
        return float(raw[0])
    return default
@st.cache_resource(show_spinner=False)
def _cached_embeddings():
    """Build the embedding model once and reuse it across Streamlit reruns."""
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL, model_kwargs={'device': 'cpu'})


def load_embeddings():
    """Load and cache the embedding model.

    Returns:
        The HuggingFaceEmbeddings instance for EMBEDDING_MODEL (configured for
        CPU), or None if it could not be created. On failure the error is
        logged and shown in the Streamlit UI.
    """
    try:
        # Error reporting stays outside the cached helper so that a failed
        # load is never stored in the cache and the next rerun retries.
        return _cached_embeddings()
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        st.error("Failed to load the embedding model. Please try again later.")
        return None
# max_entries=1: the model name is free-form user input, so keep only the most
# recently requested pipeline alive instead of every model ever typed in.
@st.cache_resource(show_spinner=False, max_entries=1)
def _cached_summarizer(model_name: str):
    """Build the summarization pipeline for ``model_name`` once per cache entry."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return pipeline("summarization", model=model, tokenizer=tokenizer)


def load_llm(model_name: str):
    """Load and cache the Hugging Face model and tokenizer.

    Args:
        model_name: Hugging Face model identifier passed to
            ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM``.

    Returns:
        A "summarization" pipeline wrapping the model and tokenizer, or None
        if loading failed. On failure the error is logged and shown in the
        Streamlit UI.
    """
    try:
        # Error reporting stays outside the cached helper so that a failed
        # load is never stored in the cache and a corrected name is retried.
        return _cached_summarizer(model_name)
    except Exception as e:
        logger.error(f"Failed to load LLM: {e}")
        st.error(f"Failed to load the model {model_name}. Please check the model name and try again.")
        return None
def process_pdf(file) -> List[Document]:
    """Load an uploaded PDF into one Document per page.

    The upload is written to a temporary ``.pdf`` file because PyPDFLoader is
    given a filesystem path; that file is removed again whether or not loading
    succeeds.

    Args:
        file: Uploaded file object exposing ``getvalue()`` (the raw PDF bytes).

    Returns:
        The Documents produced by ``PyPDFLoader.load()``, or an empty list if
        anything fails. On failure the error is logged and shown in the
        Streamlit UI.
    """
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(file.getvalue())
        # The handle is closed here, so the bytes are flushed before the
        # loader opens the path.
        return PyPDFLoader(file_path=temp_file_path).load()  # one Document per page
    except Exception as e:
        logger.error(f"Error processing PDF: {e}")
        st.error("Failed to process the PDF. Please make sure it's a valid PDF file.")
        return []
    finally:
        # delete=False means nothing removes the file for us; do it on every path.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_file_path}: {cleanup_error}")
def create_vector_store(documents: List[Document], embeddings):
    """Index ``documents`` in FAISS and persist the index to DB_FAISS_PATH.

    Args:
        documents: Documents to embed and index.
        embeddings: Embedding model handed to ``FAISS.from_documents``.

    Returns:
        The FAISS store, or None if building or saving it raised. On failure
        the error is logged and shown in the Streamlit UI.
    """
    try:
        store = FAISS.from_documents(documents, embeddings)
        store.save_local(DB_FAISS_PATH)
    except Exception as exc:
        logger.error(f"Error creating vector store: {exc}")
        st.error("Failed to create the vector store. Please try again.")
        return None
    return store
def summarize_report(documents: List[Document], summarizer) -> str:
    """Summarize the report using a map-reduce approach.

    Map step: each document's text is summarized on its own.
    Reduce step: the partial summaries are joined with spaces and summarized
    once more to produce the final text.

    Args:
        documents: Documents whose ``page_content`` is summarized. Only the
            first 50 are used; a Streamlit warning is shown when the input is
            longer than that.
        summarizer: Callable invoked as
            ``summarizer(text, max_length=..., min_length=..., do_sample=False)``
            and returning a list whose first item has a ``'summary_text'`` key.

    Returns:
        The final summary, or an empty string when there is no text to
        summarize or when summarization fails (the failure is logged and
        shown in the Streamlit UI).
    """
    try:
        # Limit the number of chunks to process
        max_chunks = 50  # Adjust this value based on your needs
        if len(documents) > max_chunks:
            st.warning(f"Document is very large. Summarizing first {max_chunks} chunks only.")
            documents = documents[:max_chunks]

        # Skip documents with no text so the summarizer is never called on an
        # empty string.
        texts = [
            doc.page_content
            for doc in documents
            if doc.page_content and doc.page_content.strip()
        ]
        if not texts:
            logger.warning("No text content available to summarize.")
            return ""

        def map_fn(text):
            """Summarize a single chunk of text."""
            return summarizer(text, max_length=150, min_length=40, do_sample=False)[0]['summary_text']

        def reduce_fn(summaries):
            """Condense the per-chunk summaries into one final summary."""
            combined_text = " ".join(summaries)
            return summarizer(combined_text, max_length=300, min_length=100, do_sample=False)[0]['summary_text']

        with st.spinner("Generating summary..."):
            # Plain function calls are all the map-reduce needs. The earlier
            # Runnable wrappers were built with an unsupported keyword and
            # driven through a non-existent .run(), so every call ended in the
            # except branch below.
            summaries = [map_fn(text) for text in texts]
            return reduce_fn(summaries)
    except Exception as e:
        logger.error(f"Error summarizing report: {e}")
        st.error("Failed to summarize the report. Please try again.")
        return ""
def main():
    """Render the report-summarizer page for one Streamlit script run.

    Draws the title and sidebar controls (model name, advanced parameters,
    PDF upload), loads the summarizer and embedding model, and, once a PDF
    is uploaded, processes it, builds the vector store and offers a
    "Summarize" button whose result is shown as markdown. Returns early as
    soon as a required step yields nothing.
    """
    st.title("Report Summarizer ")

    sidebar = st.sidebar
    model_option = sidebar.text_input("Enter Hugging Face model name", value=DEFAULT_MODEL)

    # Advanced options. The widgets are created in the same order as the keys below.
    with sidebar.expander("Advanced Model Parameters"):
        # NOTE: these values are collected here but nothing later in this
        # function passes them on to the summarizer.
        custom_params = {
            "temperature": st.slider("Temperature", 0.0, 1.0,
                                     value=get_default_value("temperature", 0.7),
                                     step=0.01),
            "top_p": st.slider("Top P", 0.0, 1.0,
                               value=get_default_value("top_p", 1.0),
                               step=0.01),
            "num_ctx": st.number_input("Context Window", 1024, 8192,
                                       value=int(get_default_value("num_ctx", 4096))),
            "repeat_penalty": st.slider("Repeat Penalty", 1.0, 2.0,
                                        value=get_default_value("repeat_penalty", 1.1),
                                        step=0.01),
        }

    uploaded_file = sidebar.file_uploader("Upload your Report", type="pdf")

    # Both models are required; the loaders have already reported any failure.
    summarizer = load_llm(model_option)
    embeddings = load_embeddings()
    if not (summarizer and embeddings):
        return

    if not uploaded_file:
        return

    with st.spinner("Processing PDF..."):
        documents = process_pdf(uploaded_file)
    if not documents:
        return

    with st.spinner("Creating vector store..."):
        db = create_vector_store(documents, embeddings)
    # The button is only offered once the vector store exists.
    if not db:
        return
    if not st.button("Summarize"):
        return

    with st.spinner(f"Generating structured summary using {model_option}..."):
        summary = summarize_report(documents, summarizer)

    if not summary:
        st.warning("Failed to generate summary. Please try again.")
        return
    st.subheader("Structured Summary:")
    st.markdown(summary)
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()