# Hugging Face Spaces status: Running
import logging
import os
import tempfile
from typing import List, Optional

import streamlit as st
import torch
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# --- Model identifiers and tuning knobs -----------------------------------
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_MODEL = "distilgpt2"
# Share of the input's word count used as the summary length budget.
MAX_LENGTH_FRACTION = 0.2

# --- Logging: INFO-level root config plus a module logger ------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Compute device: prefer CUDA when torch can see a GPU -------------------
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.sidebar.write(f"Using device: {device}")
def load_embeddings(model_name: str) -> Optional[HuggingFaceEmbeddings]:
    """Construct the sentence-embedding model used to build the vector store.

    Args:
        model_name: Model id handed to ``HuggingFaceEmbeddings``.

    Returns:
        The embeddings object, or ``None`` when construction raised
        (the error message is logged).
    """
    embeddings: Optional[HuggingFaceEmbeddings] = None
    try:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
    except Exception as exc:
        logger.error(f"Failed to load embeddings: {exc}")
    return embeddings
def load_llm(model_name: str, max_length: int) -> Optional[HuggingFacePipeline]:
    """Load a causal language model and wrap it as a LangChain LLM.

    Args:
        model_name: Model id used for both the tokenizer and the model.
        max_length: Maximum number of tokens to *generate*. It is passed to
            the pipeline as ``max_new_tokens``, not ``max_length``. For
            text generation, ``max_length`` counts the prompt tokens as
            well. A budget derived from the input size (e.g. a fraction of
            it) would then be smaller than the prompt itself and leave no
            room for output. Values below 1 are clamped to 1.

    Returns:
        A ``HuggingFacePipeline`` on success, or ``None`` if loading failed
        (the error is logged together with its traceback).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,  # module-level "cuda"/"cpu" selection
            # Bound only the generated continuation, not prompt + continuation.
            max_new_tokens=max(1, max_length),
        )
        return HuggingFacePipeline(pipeline=pipe)
    except Exception as e:
        logger.exception(f"Failed to load LLM: {e}")
        return None
def process_pdf(
    file, chunk_size: int = 2000, chunk_overlap: int = 100
) -> Optional[List[Document]]:
    """Load an uploaded PDF and split it into text chunks.

    ``PyPDFLoader`` reads from a filesystem path, so the upload is first
    written to a temporary ``.pdf`` file. That file is always removed
    before returning; previously it was left behind on every call.

    Args:
        file: Uploaded file object whose ``getvalue()`` returns the raw PDF
            bytes (e.g. the value of a Streamlit file uploader).
        chunk_size: Maximum characters per chunk for the text splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The chunked ``Document`` list, or ``None`` if loading or splitting
        failed (the error is logged together with its traceback).
    """
    temp_file_path: Optional[str] = None
    try:
        # delete=False keeps the file after the handle is closed so the
        # loader can reopen it by path; cleanup happens in ``finally``.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_file_path = temp_file.name
            temp_file.write(file.getvalue())
        loader = PyPDFLoader(file_path=temp_file_path)
        pages = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        return text_splitter.split_documents(pages)
    except Exception as e:
        logger.exception(f"Error processing PDF: {e}")
        return None
    finally:
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                # Best-effort cleanup: never mask the real result.
                logger.warning("Could not remove temporary file %s", temp_file_path)
def create_vector_store(documents: List[Document], embeddings: HuggingFaceEmbeddings) -> Optional[FAISS]:
    """Index the document chunks in a FAISS vector store.

    Args:
        documents: Chunks to embed and index.
        embeddings: Embedding model used by ``FAISS.from_documents``.

    Returns:
        The populated store, or ``None`` when indexing raised (the error
        message is logged).
    """
    store: Optional[FAISS] = None
    try:
        store = FAISS.from_documents(documents, embeddings)
    except Exception as exc:
        logger.error(f"Error creating vector store: {exc}")
    return store
def summarize_report(documents: List[Document], llm: HuggingFacePipeline, max_length: int, summary_style: str) -> Optional[str]:
    """Summarize the document chunks with a LangChain "stuff" summarize chain.

    The "stuff" chain puts every chunk into a single prompt, so the combined
    text has to fit in the model's context window.

    Args:
        documents: Text chunks to summarize.
        llm: Loaded LLM wrapper. ``None`` is rejected up front.
        max_length: Accepted for backward compatibility but not forwarded.
            ``Chain.run`` raises ``ValueError`` when given both a positional
            input and keyword arguments, which previously made every call
            fail. The generation limit is the one configured on ``llm``'s
            pipeline.
        summary_style: Phrase inserted into the prompt, e.g. ``"formal"``.

    Returns:
        The summary text, or ``None`` on failure (logged with traceback).
    """
    if llm is None:
        logger.error("Error summarizing report: no language model loaded")
        return None
    prompt_template = (
        "\n"
        f"Summarize the following text in a {summary_style} manner. "
        "Focus on the main points and key details:\n"
        # Single braces on purpose: PromptTemplate fills {text} with the chunks.
        "{text}\n"
        "Summary:\n"
    )
    try:
        prompt = PromptTemplate(template=prompt_template, input_variables=["text"])
        chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt)
        # A single positional input only; see the ``max_length`` note above.
        return chain.run(documents)
    except Exception as e:
        logger.exception(f"Error summarizing report: {e}")
        return None
def main():
    """Streamlit entry point: upload a PDF report and summarize it with an LLM."""
    # Floor for the generation budget. Without it, 20% of a very short
    # document's word count rounds down to 0 tokens.
    min_summary_tokens = 32

    st.title("Report Summarizer")
    model_option = st.sidebar.text_input("Enter model name", value=DEFAULT_MODEL)
    summary_style = st.sidebar.selectbox("Summary style", options=["clear and concise", "formal", "informal", "bullet points"])
    uploaded_file = st.sidebar.file_uploader("Upload your Report", type="pdf")

    # Load once up front so a bad model name is reported before any upload.
    llm = load_llm(model_option, 1024)
    if llm is None:
        st.error(f"Failed to load the model {model_option}. Please try another model.")
        return

    embeddings = load_embeddings(EMBEDDING_MODEL)
    if embeddings is None:
        st.error("Failed to load embeddings. Please try again later.")
        return

    if not uploaded_file:
        return

    with st.spinner("Processing PDF..."):
        documents = process_pdf(uploaded_file)
    if documents is None:
        st.error("Failed to process the PDF. Please check the file and try again.")
        return
    if not documents:
        st.warning("No text could be extracted from the PDF.")
        return

    # The store is only used here as a readiness check; summarization reads
    # ``documents`` directly.
    with st.spinner("Creating vector store..."):
        db = create_vector_store(documents, embeddings)
    if db is None:
        st.error("Failed to create the vector store.")
        return

    if not st.button("Summarize"):
        return

    # Budget the summary as a fraction of the input. The word count is only
    # a rough proxy for the token count.
    input_length = sum(len(doc.page_content.split()) for doc in documents)
    max_length = max(int(input_length * MAX_LENGTH_FRACTION), min_summary_tokens)

    # Rebuild the pipeline so it carries this document-specific budget.
    llm = load_llm(model_option, max_length)
    if llm is None:
        st.error(f"Failed to load the model {model_option}. Please try another model.")
        return

    with st.spinner(f"Generating summary using {model_option}..."):
        summary = summarize_report(documents, llm, max_length, summary_style)
    if summary:
        st.subheader("Summary:")
        st.write(summary)
    else:
        st.warning("Failed to generate summary. Please try again.")


if __name__ == "__main__":
    main()