# mistreal / app.py
# Kathirsci's picture
# Update app.py
# 2976ddc verified
# raw
# history blame
# 4.64 kB
import os
import logging
import tempfile
from typing import List
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Where the FAISS index is saved, and which models are used by default.
DB_FAISS_PATH = 'vectorstore/db_faiss'
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
DEFAULT_MODEL = "facebook/bart-large-cnn"

# Keyword arguments forwarded to the summarization pipeline when the caller
# does not supply its own parameter set.
DEFAULT_PARAMS = dict(
    temperature=0.7,
    max_length=1024,
    num_beams=4,
    top_p=0.95,
    repetition_penalty=1.2,
)
def get_default_value(param_name: str, default: float) -> float:
    """Return the ``DEFAULT_PARAMS`` entry for ``param_name`` as a float.

    Args:
        param_name: Key to look up in ``DEFAULT_PARAMS``.
        default: Value used when the key is absent, when the stored value is
            an empty sequence, or when the stored value cannot be converted
            to a float.

    Returns:
        The stored value converted to ``float``. For a list or tuple the
        first element is used. The fallback is also converted, so the result
        is always a ``float``.

    Raises:
        TypeError, ValueError: Only if ``default`` itself is not convertible
            to ``float`` and has to be used.
    """
    value = DEFAULT_PARAMS.get(param_name, default)

    # A sequence-valued entry contributes its first element; an empty one
    # has nothing to offer, so the caller's default applies.
    if isinstance(value, (list, tuple)):
        if not value:
            return float(default)
        value = value[0]

    try:
        return float(value)
    except (TypeError, ValueError):
        # None, non-numeric strings, nested containers, ... -> fall back
        # instead of raising, as the "safely" contract promises.
        return float(default)
def load_embeddings():
    """Instantiate the sentence-embedding model named by ``EMBEDDING_MODEL``.

    Returns:
        A ``SentenceTransformer`` built from ``EMBEDDING_MODEL``. A new
        instance is constructed on every call; nothing is memoized here.

    Raises:
        Exception: Any error raised while constructing the model is logged
            and then re-raised unchanged.
    """
    try:
        embedder = SentenceTransformer(EMBEDDING_MODEL)
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        raise
    return embedder
def load_llm(model_name, custom_params=None):
    """Build a Hugging Face summarization pipeline for ``model_name``.

    Args:
        model_name: Model identifier handed to both ``from_pretrained`` calls.
        custom_params: Optional keyword arguments for ``pipeline``. When it
            is falsy (``None`` or empty) ``DEFAULT_PARAMS`` is used instead;
            the two dictionaries are never merged.

    Returns:
        The object returned by ``pipeline("summarization", ...)``.

    Raises:
        Exception: Any loading error is logged and re-raised unchanged.
    """
    try:
        pipeline_kwargs = custom_params if custom_params else DEFAULT_PARAMS
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        seq2seq_tokenizer = AutoTokenizer.from_pretrained(model_name)
        summarizer = pipeline(
            "summarization",
            model=seq2seq_model,
            tokenizer=seq2seq_tokenizer,
            **pipeline_kwargs,
        )
        return summarizer
    except Exception as e:
        logger.error(f"Failed to load LLM: {e}")
        raise
def process_pdf(file) -> List[Document]:
    """Read a PDF through ``PyPDFLoader`` and return what it loads.

    Args:
        file: Passed to ``PyPDFLoader`` as ``file_path``.

    Returns:
        The list produced by ``PyPDFLoader.load()`` (one ``Document`` per
        page, per the loader's behaviour).

    Raises:
        Exception: Any loader error is logged and re-raised unchanged.
    """
    try:
        return PyPDFLoader(file_path=file).load()
    except Exception as e:
        logger.error(f"Error processing PDF: {e}")
        raise
class _SentenceTransformerEmbeddings:
    """Expose an ``encode``-style model through the methods FAISS calls.

    ``FAISS.from_documents`` calls ``embed_documents`` on the embedding
    object, and query-time code calls either ``embed_query`` or the object
    itself. A bare ``SentenceTransformer`` offers none of these, only
    ``encode``, so this thin wrapper supplies them.
    """

    def __init__(self, model):
        self._model = model

    def embed_documents(self, texts):
        """Return one embedding (a list of floats) per input text."""
        vectors = self._model.encode(list(texts))
        return [[float(component) for component in vector] for vector in vectors]

    def embed_query(self, text):
        """Return the embedding of a single query string."""
        return self.embed_documents([text])[0]

    def __call__(self, text):
        # Objects that are not LangChain ``Embeddings`` instances are invoked
        # as plain callables when a query is embedded.
        return self.embed_query(text)


def create_vector_store(documents: "List[Document]", embeddings):
    """Build a FAISS index over ``documents`` and save it to ``DB_FAISS_PATH``.

    Args:
        documents: Documents to index.
        embeddings: Either an object that already provides
            ``embed_documents`` (used as-is), or a model that only provides
            ``encode`` such as the ``SentenceTransformer`` returned by
            ``load_embeddings`` (wrapped in an adapter first).

    Returns:
        The FAISS vector store that was created and saved.

    Raises:
        Exception: Any error while embedding, indexing or saving is logged
            and re-raised unchanged.
    """
    try:
        # A raw SentenceTransformer has no ``embed_documents``; passing it
        # straight to FAISS fails with AttributeError, so adapt it.
        if not hasattr(embeddings, "embed_documents") and hasattr(embeddings, "encode"):
            embeddings = _SentenceTransformerEmbeddings(embeddings)
        db = FAISS.from_documents(documents, embeddings)
        db.save_local(DB_FAISS_PATH)
        return db
    except Exception as e:
        logger.error(f"Error creating vector store: {e}")
        raise
# Prompt texts for the two phases of the map-reduce summary.
_MAP_TEMPLATE = "Summarize the following text:\n\n{text}\n\nSummary:"
_REDUCE_TEMPLATE = (
    "Combine these summaries into a final summary:\n\n"
    "Summary:\n{doc_summaries}\n\nFinal Summary:"
)


def _summary_text(result) -> str:
    """Extract the generated text from a summarizer's return value."""
    # Hugging Face pipelines return a list of dicts, e.g.
    # [{"summary_text": "..."}]; plain strings are accepted as well.
    if isinstance(result, (list, tuple)) and result:
        result = result[0]
    if isinstance(result, dict):
        for key in ("summary_text", "generated_text"):
            if key in result:
                result = result[key]
                break
    if not isinstance(result, str):
        raise TypeError(f"Unsupported summarizer output: {type(result).__name__}")
    return result.strip()


def _group_by_length(texts, budget):
    """Greedily pack consecutive texts into groups of at most ``budget`` chars."""
    groups, current, size = [], [], 0
    for text in texts:
        if current and size + len(text) > budget:
            groups.append(current)
            current, size = [], 0
        current.append(text)
        size += len(text)
    if current:
        groups.append(current)
    return groups


def _combine(llm, summaries) -> str:
    """Ask the summarizer to merge several partial summaries into one."""
    prompt = _REDUCE_TEMPLATE.format(doc_summaries="\n\n".join(summaries))
    return _summary_text(llm(prompt))


def summarize_report(
    documents: "List[Document]",
    llm,
    max_chunks: int = 50,
    reduce_char_budget: int = 3000,
) -> str:
    """Summarize the report using a map-reduce approach.

    Each document is summarized on its own (map), then the partial summaries
    are merged into one final summary (reduce). The summarizer is called
    directly with a prompt string, which is how the pipeline returned by
    ``load_llm`` is invoked.

    Args:
        documents: Documents whose ``page_content`` is summarized.
        llm: Callable taking a prompt string and returning either a string or
            a Hugging Face style ``[{"summary_text": ...}]`` result.
        max_chunks: Only the first ``max_chunks`` documents are processed.
        reduce_char_budget: Rough character limit for one reduce prompt's
            combined summaries. When the partial summaries exceed it they are
            merged group by group before the final reduce. This is a
            character count, not a token count, so it only approximates the
            model's input limit.

    Returns:
        The final summary, or ``""`` when there is no text to summarize.

    Raises:
        Exception: Any error from the summarizer or from interpreting its
            output is logged and re-raised unchanged.

    Note:
        Page texts are passed to the map step as-is; input-length limits of
        the underlying model still apply to each individual page.
    """
    try:
        if len(documents) > max_chunks:
            logger.warning(f"Document is very large. Summarizing first {max_chunks} chunks only.")
            documents = documents[:max_chunks]

        # Map: one summary per document. Blank pages (e.g. scanned images
        # with no extractable text) are skipped rather than sent to the model.
        partials = []
        for doc in documents:
            text = (doc.page_content or "").strip()
            if text:
                partials.append(_summary_text(llm(_MAP_TEMPLATE.format(text=text))))

        if not partials:
            logger.warning("No text found to summarize.")
            return ""

        # Collapse: while the partial summaries do not fit one reduce prompt,
        # merge them group by group. Stop when grouping no longer shrinks the
        # list, so the loop always terminates.
        while len(partials) > 1:
            groups = _group_by_length(partials, reduce_char_budget)
            if len(groups) == 1 or len(groups) == len(partials):
                break
            partials = [
                group[0] if len(group) == 1 else _combine(llm, group)
                for group in groups
            ]

        # Reduce: a final pass over whatever remains.
        return _combine(llm, partials)
    except Exception as e:
        logger.error(f"Error summarizing report: {e}")
        raise
def main(pdf_path: str, model_name: str = DEFAULT_MODEL):
    """Run the whole pipeline for one PDF and print its summary.

    Steps, in order: load the embedding model, load the summarization model,
    read the PDF, build and save the vector store, summarize, print.

    Args:
        pdf_path: Location of the PDF to summarize.
        model_name: Summarization model to load; defaults to ``DEFAULT_MODEL``.

    Returns:
        None. Any exception raised along the way is logged here and not
        propagated to the caller.
    """
    try:
        embedder = load_embeddings()
        summarizer = load_llm(model_name)
        pages = process_pdf(pdf_path)
        # Called for its side effect of building and saving the index; the
        # returned store is not used further in this function.
        create_vector_store(pages, embedder)
        report_summary = summarize_report(pages, summarizer)
        print("Structured Summary:\n", report_summary)
    except Exception as e:
        logger.error(f"Failed to summarize the report: {e}")
if __name__ == "__main__":
    import sys

    # Use the first command-line argument as the PDF path when one is given;
    # otherwise keep the original placeholder so a bare run behaves as before.
    pdf_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/your/report.pdf"
    main(pdf_path)