# Page header captured with this file: "Spaces: Sleeping" (not part of the program).
import os | |
import logging | |
import tempfile | |
from typing import List | |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer | |
from sentence_transformers import SentenceTransformer | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain | |
# Module-wide logging: INFO level on the root logger, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Location where create_vector_store() saves the FAISS index.
DB_FAISS_PATH = "vectorstore/db_faiss"
# Model identifier handed to SentenceTransformer by load_embeddings().
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Model used by main() when the caller does not name one.
DEFAULT_MODEL = "facebook/bart-large-cnn"

# Keyword arguments load_llm() forwards to the summarization pipeline
# when no custom parameters are supplied.
DEFAULT_PARAMS = dict(
    temperature=0.7,
    max_length=1024,
    num_beams=4,
    top_p=0.95,
    repetition_penalty=1.2,
)
def get_default_value(param_name: str, default: float) -> float:
    """Look up ``param_name`` in DEFAULT_PARAMS and coerce the result to float.

    A non-list value is converted with ``float()``. A non-empty list yields
    its first element converted with ``float()``. An empty list yields
    ``default`` exactly as passed in (no conversion).
    """
    raw = DEFAULT_PARAMS.get(param_name, default)
    if not isinstance(raw, list):
        return float(raw)
    if raw:
        return float(raw[0])
    return default
def load_embeddings():
    """Instantiate the SentenceTransformer named by EMBEDDING_MODEL.

    Any failure is logged and then re-raised unchanged.
    """
    try:
        embedding_model = SentenceTransformer(EMBEDDING_MODEL)
    except Exception as exc:
        logger.error(f"Failed to load embeddings: {exc}")
        raise
    else:
        return embedding_model
def load_llm(model_name, custom_params=None):
    """Build a "summarization" pipeline around ``model_name``.

    ``custom_params``, when truthy, replaces DEFAULT_PARAMS entirely as the
    keyword arguments forwarded to ``pipeline``; a falsy value (``None`` or an
    empty mapping) selects DEFAULT_PARAMS. Any failure is logged and
    re-raised unchanged.
    """
    try:
        pipeline_kwargs = custom_params if custom_params else DEFAULT_PARAMS
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        model_tokenizer = AutoTokenizer.from_pretrained(model_name)
        summarizer = pipeline(
            "summarization",
            model=seq2seq_model,
            tokenizer=model_tokenizer,
            **pipeline_kwargs,
        )
        return summarizer
    except Exception as exc:
        logger.error(f"Failed to load LLM: {exc}")
        raise
def process_pdf(file) -> List[Document]:
    """Load the PDF at ``file`` with PyPDFLoader and return its Documents.

    Any failure is logged and re-raised unchanged.
    """
    try:
        # The loader yields one Document per page.
        return PyPDFLoader(file_path=file).load()
    except Exception as exc:
        logger.error(f"Error processing PDF: {exc}")
        raise
class _SentenceTransformerEmbeddings:
    """Expose an object with ``encode`` through the LangChain embeddings methods."""

    def __init__(self, model):
        self._model = model

    def embed_documents(self, texts):
        """Return one embedding (list of floats) per input text."""
        return self._model.encode(list(texts)).tolist()

    def embed_query(self, text):
        """Return the embedding (list of floats) of a single query string."""
        return self._model.encode([text])[0].tolist()

    # NOTE(review): the FAISS wrapper appears to fall back to calling the
    # embedding object directly for queries when it is not an `Embeddings`
    # instance; keep the adapter callable so that path works too — confirm
    # against the installed langchain_community version.
    __call__ = embed_query


def create_vector_store(documents: List[Document], embeddings):
    """Build a FAISS vector store from ``documents`` and save it to DB_FAISS_PATH.

    Args:
        documents: Documents to index; must not be empty.
        embeddings: Either an object that already provides ``embed_documents``
            (a LangChain embeddings object), or a model exposing ``encode``
            such as the SentenceTransformer returned by load_embeddings().
            The latter is wrapped in an adapter, because
            ``FAISS.from_documents`` calls ``embed_documents`` on it.

    Returns:
        The FAISS vector store that was created and saved.

    Raises:
        ValueError: If ``documents`` is empty.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        if not documents:
            # FAISS needs at least one vector to size the index.
            raise ValueError("Cannot create a vector store from an empty document list")

        if not hasattr(embeddings, "embed_documents") and hasattr(embeddings, "encode"):
            embeddings = _SentenceTransformerEmbeddings(embeddings)

        db = FAISS.from_documents(documents, embeddings)
        db.save_local(DB_FAISS_PATH)
        return db
    except Exception as e:
        logger.error(f"Error creating vector store: {e}")
        raise
def _run_summarizer(llm, text: str) -> str:
    """Run the summarization pipeline on ``text`` and return the summary string."""
    # A summarization pipeline takes its input positionally and returns a list
    # of dicts keyed by "summary_text". truncation=True clips inputs that
    # exceed the model's maximum length instead of failing.
    outputs = llm(text, truncation=True)
    return outputs[0]["summary_text"].strip()


def summarize_report(documents: List[Document], llm, max_chunks: int = 50) -> str:
    """Summarize the report using a map-reduce approach.

    Each document is summarized on its own (map), then the partial summaries
    are combined and summarized once more (reduce). The steps are run directly
    on ``llm`` because the LangChain map-reduce chain classes require chain
    objects rather than plain callables, and ``llm`` is the Hugging Face
    pipeline built by load_llm().

    Args:
        documents: Documents to summarize, e.g. one per PDF page.
        llm: A summarization pipeline: callable with a text and returning
            ``[{"summary_text": ...}]``.
        max_chunks: Maximum number of documents to process; the rest are
            dropped with a warning.

    Returns:
        The final summary, or an empty string when there is no text to
        summarize.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        # Limit the number of chunks to process
        if len(documents) > max_chunks:
            logger.warning(f"Document is very large. Summarizing first {max_chunks} chunks only.")
            documents = documents[:max_chunks]

        # Map prompt
        map_template = """Summarize the following text:\n\n{text}\n\nSummary:"""
        map_prompt = PromptTemplate.from_template(map_template)
        # Reduce prompt
        reduce_template = """Combine these summaries into a final summary:\n\nSummary:\n{doc_summaries}\n\nFinal Summary:"""
        reduce_prompt = PromptTemplate.from_template(reduce_template)

        # Map: summarize every non-blank document independently. Blank pages
        # (e.g. scanned images without a text layer) are skipped.
        partial_summaries = [
            _run_summarizer(llm, map_prompt.format(text=doc.page_content))
            for doc in documents
            if doc.page_content and doc.page_content.strip()
        ]

        if not partial_summaries:
            logger.warning("No text found to summarize.")
            return ""
        if len(partial_summaries) == 1:
            # Nothing to combine.
            return partial_summaries[0]

        # Reduce: combine the partial summaries into one.
        # NOTE: if the joined summaries exceed the model's input limit, the
        # tail is truncated by the pipeline.
        combined = "\n".join(partial_summaries)
        return _run_summarizer(llm, reduce_prompt.format(doc_summaries=combined))
    except Exception as e:
        logger.error(f"Error summarizing report: {e}")
        raise
def main(pdf_path: str, model_name: str = DEFAULT_MODEL):
    """Summarize the PDF at ``pdf_path`` and print the result.

    Steps, in order: load the embedding model, load the summarization model
    ``model_name``, read the PDF, build and save the vector store, then
    summarize and print. Any exception is logged and not propagated.
    """
    try:
        # Models first, so a bad model name fails before the PDF is read.
        embedding_model = load_embeddings()
        summarizer = load_llm(model_name)

        pages = process_pdf(pdf_path)

        # The store is saved as a side effect; its return value is not used here.
        create_vector_store(pages, embedding_model)

        report_summary = summarize_report(pages, summarizer)
        print("Structured Summary:\n", report_summary)
    except Exception as err:
        logger.error(f"Failed to summarize the report: {err}")
if __name__ == "__main__":
    import sys

    # Use the first command-line argument as the PDF path when given;
    # otherwise fall back to the placeholder path (replace it with a real file).
    pdf_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/your/report.pdf"
    main(pdf_path)