import base64
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
import os
import re
import streamlit as st
import sys
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import pipeline

# notes
# https://huggingface.co/docs/transformers/pad_truncation

# RecursiveCharacterTextSplitter settings; also echoed in the UI's
# "Additional information" panel so the two cannot drift apart.
CHUNK_SIZE = 1000  # number of characters
CHUNK_OVERLAP = 100

# Hub ids and local snapshot directories for the two supported models.
BART_CHECKPOINT = "ccdv/lsg-bart-base-16384-pubmed"
BART_CACHE_PATH = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
T5_CHECKPOINT = "MBZUAI/LaMini-Flan-T5-77M"
T5_CACHE_PATH = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"

_WORD = re.compile(r"\w+")
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")
_SPACE_BEFORE_PUNCT = re.compile(r'\s([,.():;?!"](?:\s|$))')


def _count_words(text):
    """Return the number of word-character runs in *text*."""
    return len(_WORD.findall(text))


# file loader and preprocessor
def file_preprocessing(file, skipfirst, skiplast):
    """Load a PDF, optionally drop its first/last page, and chunk its text.

    Args:
        file: Path to the PDF on disk.
        skipfirst: If truthy, drop the first page (e.g. a cover page).
        skiplast: If truthy, drop the last page.

    Returns:
        ``(texts, final_texts)``: the list of chunks produced by
        ``RecursiveCharacterTextSplitter`` and those chunks joined into one
        string. Both are empty when the PDF has no extractable text (or every
        page was skipped).
    """
    loader = PyMuPDFLoader(file)
    # load() returns one Document per page. load_and_split() would first split
    # pages into chunks, so "skip first page" would only drop the first chunk.
    pages = loader.load()

    # Slicing never raises, unlike `del`, even for 0- or 1-page documents.
    if skipfirst:
        pages = pages[1:]
    if skiplast:
        pages = pages[:-1]

    # https://stackoverflow.com/questions/76431655/langchain-pypdfloader
    content = "".join(page.page_content for page in pages)
    # Re-join words hyphenated across line breaks.
    content = re.sub(r"-\n", "", content)

    print("\n###### New article ######\n")
    print("Input text:\n")
    print(content)

    print("\nChunking...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,  # number of characters
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],  # default list
    )
    # https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
    texts = text_splitter.split_text(content)

    # NOTE: these are character-based chunks, not model tokens.
    print("Number of tokens: " + str(len(texts)))
    print("\nFirst three tokens:\n")
    for text in texts[:3]:
        print(text)
        print("")

    # The splitter strips whitespace at chunk edges, so join with a space to
    # avoid gluing the last word of one chunk onto the first of the next.
    # NOTE(review): consecutive chunks overlap by CHUNK_OVERLAP characters, so
    # that text appears twice here; consider summarizing `content` directly.
    final_texts = " ".join(texts)
    return texts, final_texts


# function to count words in the input
def preproc_count(filepath, skipfirst, skiplast):
    """Preprocess the PDF at *filepath* and count the words in its text.

    Returns:
        ``(texts, input_text, text_length)``: the text chunks, the joined text
        with every hyphen removed, and the word count of that text.
    """
    texts, input_text = file_preprocessing(filepath, skipfirst, skiplast)
    input_text = input_text.replace("-", "")
    text_length = _count_words(input_text)
    print(f"Input word count: {text_length:,}")
    return texts, input_text, text_length


# function to convert (bart) summary to sentence case
def convert_to_sentence_case(text):
    """Capitalize each sentence of *text* and re-join them with single spaces.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace.
    ``str.capitalize`` also lowercases the rest of each sentence, so acronyms
    and proper nouns end up lowercase.
    """
    sentences = _SENTENCE_BOUNDARY.split(text)
    return " ".join(sentence.capitalize() for sentence in sentences)


# llm pipeline
def llm_pipeline(
    tokenizer, base_model, input_text, model_source, trust_remote_code=False
):
    """Summarize *input_text* with a Hugging Face summarization pipeline.

    Args:
        tokenizer: Tokenizer matching *base_model*.
        base_model: A loaded seq2seq model, or a local path the pipeline loads.
        input_text: Text to summarize (truncated to the tokenizer's maximum).
        model_source: Label printed to the console.
        trust_remote_code: Forwarded to ``transformers.pipeline``; must be True
            for checkpoints with custom modeling code (the LSG-BART model) when
            *base_model* is a path rather than an instantiated model.

    Returns:
        The generated summary text.
    """
    pipe_sum = pipeline(
        "summarization",
        model=base_model,
        tokenizer=tokenizer,
        max_length=300,
        min_length=200,
        truncation=True,
        trust_remote_code=trust_remote_code,
    )
    print(f"Model source: {model_source}")
    print("Summarizing...")
    result = pipe_sum(input_text)
    summary = result[0]["summary_text"]
    print("Summarization finished\n")
    print("Summary text:\n")
    print(summary)
    print("")
    return summary


# function to count words in the summary
def postproc_count(summary):
    """Print and return the word count of *summary*."""
    text_length = _count_words(summary)
    print(f"Summary word count: {text_length:,}")
    return text_length


# function to clean summary text
def clean_summary_text(summary):
    """Tidy BART output: trim, drop spaces before punctuation, sentence-case."""
    # remove whitespace
    summary = summary.strip()
    # remove spaces before punctuation (bart)
    summary = _SPACE_BEFORE_PUNCT.sub(r"\1", summary)
    # convert to sentence case
    return convert_to_sentence_case(summary)


# function to display the PDF
def displayPDF(file):
    """Render the PDF at path *file* inline as a base64 data-URI iframe.

    Deliberately not cached: a cache keyed on the path would keep showing an
    old PDF after a different file with the same name is uploaded.
    """
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    # embed pdf in html
    # NOTE(review): size is a guess; some browsers refuse to render PDFs
    # from data: URIs inside iframes.
    pdf_display = (
        f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
        'width="100%" height="600" type="application/pdf"></iframe>'
    )
    # display file
    st.markdown(pdf_display, unsafe_allow_html=True)


@st.cache_resource
def load_tokenizer_and_model(selected_model, model_source):
    """Load the tokenizer and model for the chosen summarizer.

    Cached with ``st.cache_resource`` so each (model, source) pair is loaded
    once per server process instead of on every Streamlit rerun.

    Args:
        selected_model: "BART"; any other value selects LaMini-Flan-T5.
        model_source: "Download model" instantiates the model from the Hub;
            any other value returns the local snapshot path instead.

    Returns:
        ``(tokenizer, base_model)``; *base_model* is a model or a path string.
    """
    if selected_model == "BART":
        tokenizer = AutoTokenizer.from_pretrained(
            BART_CHECKPOINT,
            truncation=True,
            model_max_length=1000,
            trust_remote_code=True,
        )
        if model_source == "Download model":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                BART_CHECKPOINT,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
        else:
            base_model = BART_CACHE_PATH
        return tokenizer, base_model

    tokenizer = AutoTokenizer.from_pretrained(
        T5_CHECKPOINT,
        truncation=True,
        legacy=False,
        model_max_length=1000,
    )
    if model_source == "Download model":
        base_model = AutoModelForSeq2SeqLM.from_pretrained(
            T5_CHECKPOINT,
            torch_dtype=torch.float32,
        )
    else:
        base_model = T5_CACHE_PATH
    return tokenizer, base_model


# streamlit code
st.set_page_config(layout="wide")


def _render_options():
    """Draw the options row; return (selected_model, model_source, skipfirst, skiplast)."""
    st.subheader("Options")
    col1, col2, col3, col4 = st.columns([1, 1, 1, 2])

    with col1:
        model_source_names = ["Cached model", "Download model"]
        model_source = st.radio(
            "For development:",
            model_source_names,
            help="Defaults to a cached model; downloading will take longer",
        )

    with col2:
        model_names = [
            "T5-Small",
            "BART",
        ]
        selected_model = st.radio(
            "Select a model to use:",
            model_names,
            help="Defaults to T5-Small; for most articles it summarizes better than BART",
        )

    with col3:
        st.write("Skip any pages?")
        skipfirst = st.checkbox(
            "Skip first page", help="Select if your PDF has a cover page"
        )
        skiplast = st.checkbox("Skip last page")

    with col4:
        st.write("Background information (links open in a new window)")
        st.write(
            "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
            "  |  Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
        )
        st.write(
            "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
            "  |  Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
        )

    return selected_model, model_source, skipfirst, skiplast


def _summarize_upload(uploaded_file, selected_model, model_source, skipfirst, skiplast):
    """Save the upload, show it beside its summary, then show chunking details."""
    tokenizer, base_model = load_tokenizer_and_model(selected_model, model_source)

    col1, col2 = st.columns(2)

    os.makedirs("data", exist_ok=True)
    # Keep only the base name so a crafted upload name cannot escape data/.
    filepath = os.path.join("data", os.path.basename(uploaded_file.name))
    with open(filepath, "wb") as temp_file:
        # getvalue() returns the whole upload regardless of stream position.
        temp_file.write(uploaded_file.getvalue())

    with col1:
        texts, input_text, preproc_text_length = preproc_count(
            filepath, skipfirst, skiplast
        )
        if not texts:
            st.error(
                "No extractable text was found in this PDF. It may be a "
                "scanned image, or every page was skipped."
            )
            return
        st.info(
            "Uploaded PDF  |  Number of words: " f"{preproc_text_length:,}"
        )
        displayPDF(filepath)

    with col2:
        start = time.perf_counter()
        with st.spinner("Summarizing..."):
            summary = llm_pipeline(
                tokenizer,
                base_model,
                input_text,
                model_source,
                # The LSG-BART checkpoint ships custom modeling code.
                trust_remote_code=(selected_model == "BART"),
            )
        postproc_text_length = postproc_count(summary)
        duration = time.perf_counter() - start
        print(f"Duration: {duration:.0f} seconds")
        st.info(
            "PDF Summary  |  Number of words: "
            f"{postproc_text_length:,}"
            "  |  Summarization time: "
            f"{duration:.0f}"
            " seconds"
        )
        if selected_model == "BART":
            st.success(clean_summary_text(summary))
            with st.expander("Raw output"):
                st.write(summary)
        else:
            st.success(summary)

    url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
    st.info("Additional information")
    st.write("\n[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
    st.write(f"        chunk_size={CHUNK_SIZE}")
    st.write(f"        chunk_overlap={CHUNK_OVERLAP}")
    st.write("        length_function=len")
    st.write("")
    st.write("Number of tokens generated: " + str(len(texts)))
    st.write("")
    st.write("First three tokens:")
    for text in texts[:3]:
        st.write("----")
        st.write(text)


def main():
    """Render the RASA app: upload a PDF, choose options, summarize on click."""
    st.title("RASA: Research Article Summarization App")

    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is not None:
        selected_model, model_source, skipfirst, skiplast = _render_options()
        if st.button("Summarize"):
            _summarize_upload(
                uploaded_file, selected_model, model_source, skipfirst, skiplast
            )

    # NOTE(review): this HTML snippet is empty and renders nothing as
    # written; confirm the intended content (e.g. a <style> block).
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()