# rasa / app.py — page header from the hosted file view:
# author wjjessen, commit b2d65e0 ("update code"), 10.4 kB
import base64
import os
import re
import sys
import time

import streamlit as st
import torch
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import pipeline
# notes
# https://huggingface.co/docs/transformers/pad_truncation
# file loader and preprocessor
def file_preprocessing(file, skipfirst, skiplast):
    """Load a PDF, optionally drop its first/last page, and chunk the text.

    Args:
        file: Path to the PDF file.
        skipfirst: Truthy to drop the first page (e.g. a cover page).
        skiplast: Truthy to drop the last page.

    Returns:
        A tuple ``(texts, final_texts)``:
        ``texts`` is the list of chunks produced by RecursiveCharacterTextSplitter.
        ``final_texts`` is the full extracted text (after page skipping and
        re-joining of line-break hyphenation) as a single string.
    """
    loader = PyMuPDFLoader(file)
    # load() yields one Document per PDF page. load_and_split() would instead
    # re-chunk the pages with an overlapping default splitter, so "skip first
    # page" would only drop the first chunk and concatenating the chunks would
    # repeat the overlapping text.
    pages = loader.load()
    # skip page(s); guard against PDFs with too few pages to skip both ends
    if skipfirst and pages:
        del pages[0]
    if skiplast and pages:
        del pages[-1]
    # https://stackoverflow.com/questions/76431655/langchain-pypdfloader
    content = "".join(page.page_content for page in pages)
    # re-join words hyphenated across a line break ("summar-\nization")
    content = re.sub(r"-\n", "", content)
    print("\n###### New article ######\n")
    print("Input text:\n")
    print(content)
    print("\nChunking...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,  # number of characters
        chunk_overlap=100,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],  # default list
    )
    # https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
    texts = text_splitter.split_text(content)
    print("Number of tokens: " + str(len(texts)))
    print("\nFirst three tokens:\n")
    # slice instead of texts[0..2] so short documents don't raise IndexError
    for text in texts[:3]:
        print(text)
        print("")
    # Return the un-chunked text: the chunks overlap by up to chunk_overlap
    # characters and have boundary whitespace stripped, so concatenating them
    # would duplicate text and glue words together.
    final_texts = content
    return texts, final_texts
# function to count words in the input
def preproc_count(filepath, skipfirst, skiplast):
    """Preprocess the PDF at *filepath* and count the words in its text.

    Returns:
        ``(texts, input_text, text_length)``: the chunk list from
        file_preprocessing(), the extracted text with every "-" removed, and
        the number of ``\\w+`` runs in that text (also printed to stdout).
    """
    texts, extracted = file_preprocessing(filepath, skipfirst, skiplast)
    # Presumably hyphens are dropped so hyphenated compounds count as one word.
    # NOTE(review): the de-hyphenated text is also what is returned, and main()
    # passes it to llm_pipeline(), so e.g. "COVID-19" reaches the model as
    # "COVID19" — confirm this is intended.
    dehyphenated = extracted.replace("-", "")
    word_count = len(re.findall(r"\w+", dehyphenated))
    print(f"Input word count: {word_count:,}")
    return texts, dehyphenated, word_count
# function to convert (bart) summary to sentence case
def convert_to_sentence_case(text):
    """Upper-case the first character of every sentence in *text*.

    Sentences are split after ".", "!" or "?" followed by whitespace, then
    re-joined with single spaces. Unlike ``str.capitalize()``, the remainder of
    each sentence keeps its original casing, so acronyms and proper nouns
    (e.g. "DNA") are not lowercased.

    Returns:
        The re-joined string; an empty input returns "".
    """
    sentences = re.split(r"(?<=[.!?])\s+", text)
    # sentence[:1] is safe on empty strings (e.g. trailing whitespace splits)
    return " ".join(sentence[:1].upper() + sentence[1:] for sentence in sentences)
# llm pipeline
def llm_pipeline(
    tokenizer, base_model, input_text, model_source, *, max_length=300, min_length=200
):
    """Summarize *input_text* with a Hugging Face summarization pipeline.

    Args:
        tokenizer: Tokenizer handed to ``transformers.pipeline``.
        base_model: A loaded model or a model path/id accepted by ``pipeline``.
        input_text: Text to summarize; truncated to the tokenizer's limit.
        model_source: Label that is only printed to the console log.
        max_length: Maximum generated summary length (generation tokens).
        min_length: Minimum generated summary length (generation tokens).

    Returns:
        The generated summary string.
    """
    pipe_sum = pipeline(
        "summarization",
        model=base_model,
        tokenizer=tokenizer,
        max_length=max_length,
        min_length=min_length,
        truncation=True,
    )
    print("Model source: %s" % (model_source))
    print("Summarizing...")
    result = pipe_sum(input_text)
    summary = result[0]["summary_text"]
    print("Summarization finished\n")
    print("Summary text:\n")
    print(summary)
    print("")
    return summary
# function to count words in the summary
def postproc_count(summary):
    """Count the words in *summary*, print the count, and return it.

    A word is any maximal run of letters, digits or underscores, so "It's"
    counts as two words.
    """
    word_total = sum(1 for _ in re.finditer(r"\w+", summary))
    print(f"Summary word count: {word_total:,}")
    return word_total
# function to clean summary text
def clean_summary_text(summary):
    """Tidy raw (BART) summary output for display.

    Steps:
    1. Strip surrounding whitespace.
    2. Remove the spaces the model leaves around punctuation. For example,
       "results ( p < 0.05 ) were significant ." becomes
       "results (p < 0.05) were significant."
    3. Capitalize each sentence via convert_to_sentence_case().

    Returns:
        The cleaned summary string.
    """
    text = summary.strip()
    # Space before closing punctuation: "significant ." -> "significant.".
    # The lookahead leaves the following character unconsumed, so chains like
    # " ) ." are collapsed too. "." followed by a digit (" .05") is untouched.
    text = re.sub(r'\s+([,.):;?!"])(?=[\s,.):;?!"]|$)', r"\1", text)
    # Opening parenthesis: the stray space belongs after it, "( p" -> "(p".
    text = re.sub(r"\(\s+", "(", text)
    return convert_to_sentence_case(text)
def displayPDF(file):
    """Render the PDF at path *file* inline as a base64 data-URI iframe.

    Deliberately not wrapped in ``st.cache_data``. That cache keys on the
    argument (the path ``data/<filename>``), so re-uploading a different PDF
    under the same filename would show the stale document until the TTL expired.
    """
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    # embed pdf in html
    pdf_display = (
        f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
        'width="100%" height="600" type="application/pdf"></iframe>'
    )
    # display file
    st.markdown(pdf_display, unsafe_allow_html=True)
# streamlit code
# Page-level configuration. It runs once at module import, before main()
# draws anything. NOTE(review): Streamlit expects set_page_config to precede
# other st.* element calls, so keep this above main().
st.set_page_config(layout="wide")
def _load_model(selected_model, model_source):
    """Return ``(tokenizer, base_model)`` for the chosen model and source.

    For "Download model", ``base_model`` is a loaded model object. Otherwise it
    is a local snapshot path, which llm_pipeline() passes to
    ``transformers.pipeline`` as ``model=``.
    """
    # NOTE(review): this runs on every Streamlit rerun (any widget change).
    # Consider st.cache_resource, keeping in mind that the tokenizer would then
    # be shared across concurrent sessions.
    if selected_model == "BART":
        checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            truncation=True,
            model_max_length=1000,
            trust_remote_code=True,
        )
        if model_source == "Download model":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
        else:
            base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
    else:
        checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            truncation=True,
            legacy=False,
            model_max_length=1000,
        )
        if model_source == "Download model":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
            )
        else:
            base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
    return tokenizer, base_model


def _save_upload(uploaded_file, directory="data"):
    """Write the uploaded PDF into *directory* (created if missing); return its path."""
    os.makedirs(directory, exist_ok=True)
    # The filename comes from the client: basename() drops any directory
    # components (e.g. "../../app.py") so the write stays inside *directory*.
    filepath = os.path.join(directory, os.path.basename(uploaded_file.name))
    with open(filepath, "wb") as temp_file:
        # getvalue() returns the whole buffer regardless of read position
        temp_file.write(uploaded_file.getvalue())
    return filepath


def _show_splitter_details(texts):
    """Render the text-splitter settings and up to three chunks of *texts*."""
    url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
    st.info("Additional information")
    st.write("\n[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
    # keep these values in sync with the splitter in file_preprocessing()
    st.write("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_size=1000")
    st.write("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_overlap=100")
    st.write("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;length_function=len")
    st.write("")
    st.write("Number of tokens generated: " + str(len(texts)))
    st.write("")
    st.write("First three tokens:")
    # slice instead of texts[0..2] so short documents don't raise IndexError
    for chunk in texts[:3]:
        st.write("----")
        st.write(chunk)


def main():
    """Build the RASA page: upload, options, summary, and chunking details."""
    st.title("RASA: Research Article Summarization App")
    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is not None:
        st.subheader("Options")
        col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
        with col1:
            model_source_names = ["Cached model", "Download model"]
            model_source = st.radio(
                "For development:",
                model_source_names,
                help="Defaults to a cached model; downloading will take longer",
            )
        with col2:
            model_names = [
                "T5-Small",
                "BART",
            ]
            selected_model = st.radio(
                "Select a model to use:",
                model_names,
                help="Defaults to T5-Small; for most articles it summarizes better than BART",
            )
            tokenizer, base_model = _load_model(selected_model, model_source)
        with col3:
            st.write("Skip any pages?")
            skipfirst = st.checkbox(
                "Skip first page", help="Select if your PDF has a cover page"
            )
            skiplast = st.checkbox("Skip last page")
        with col4:
            st.write("Background information (links open in a new window)")
            st.write(
                "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
                "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
            )
            st.write(
                "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
                "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
            )
        if st.button("Summarize"):
            col1, col2 = st.columns(2)
            filepath = _save_upload(uploaded_file)
            with col1:
                texts, input_text, preproc_text_length = preproc_count(
                    filepath, skipfirst, skiplast
                )
                st.info(
                    "Uploaded PDF&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
                    f"{preproc_text_length:,}"
                )
                displayPDF(filepath)
            with col2:
                start = time.time()
                with st.spinner("Summarizing..."):
                    summary = llm_pipeline(
                        tokenizer, base_model, input_text, model_source
                    )
                    postproc_text_length = postproc_count(summary)
                end = time.time()
                duration = end - start
                print("Duration: " f"{duration:.0f}" + " seconds")
                st.info(
                    "PDF Summary&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
                    f"{postproc_text_length:,}"
                    + "&nbsp;&nbsp;|&nbsp;&nbsp;Summarization time: "
                    f"{duration:.0f}" + " seconds"
                )
                if selected_model == "BART":
                    # the BART output is post-processed; raw text stays available
                    summary_cleaned = clean_summary_text(summary)
                    st.success(summary_cleaned)
                    with st.expander("Raw output"):
                        st.write(summary)
                else:
                    st.success(summary)
            _show_splitter_details(texts)
    st.markdown(
        """<style>
div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
font-size: 1rem;
font-weight: 400;
}
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
margin-bottom: -15px;
}
div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
margin-bottom: -15px;
}
body > a {
text-decoration: underline;
}
</style>
""",
        unsafe_allow_html=True,
    )
# Script entry point: build the page.
if __name__ == "__main__":
    main()