File size: 10,442 Bytes
6fe569d
 
 
 
 
 
6901ce4
6fe569d
6901ce4
052edd8
6fe569d
 
6901ce4
6fe569d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6901ce4
 
 
 
 
 
 
 
 
6fe569d
 
 
 
 
 
 
6901ce4
 
 
a95a714
 
 
 
 
 
6fe569d
 
6901ce4
 
6fe569d
 
a95a714
6fe569d
6901ce4
 
 
a95a714
6901ce4
 
 
 
 
 
 
 
6fe569d
 
 
d124ecd
6fe569d
 
 
 
6901ce4
 
 
6fe569d
6901ce4
990e81e
6fe569d
e828745
6901ce4
 
 
 
e828745
6fe569d
 
a95a714
621da38
6901ce4
a95a714
621da38
 
 
6901ce4
 
 
 
 
 
 
 
 
 
 
 
6fe569d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d124ecd
6fe569d
6901ce4
 
 
 
 
 
d124ecd
6fe569d
 
 
 
6901ce4
 
 
6ffbf05
6901ce4
6fe569d
 
 
 
 
 
6901ce4
6fe569d
b8f16a6
d124ecd
6901ce4
 
 
d124ecd
6901ce4
d124ecd
 
6fe569d
 
 
 
 
6901ce4
6fe569d
b8f16a6
d124ecd
6901ce4
 
d124ecd
 
 
 
6fe569d
6901ce4
 
 
6fe569d
d124ecd
6fe569d
 
948faf9
6901ce4
6fe569d
 
948faf9
6901ce4
6fe569d
 
 
 
 
 
 
6901ce4
 
 
6fe569d
 
bb2f7b8
6fe569d
 
 
7edcbdb
2d3e634
6901ce4
 
 
bb2f7b8
7edcbdb
 
a95a714
6fe569d
 
bb2f7b8
7edcbdb
 
6fe569d
6901ce4
 
 
 
 
 
 
 
 
 
6ffbf05
6901ce4
 
 
 
 
 
 
 
 
 
 
b2d65e0
6901ce4
b2d65e0
6901ce4
b2d65e0
6901ce4
4bd0367
6fe569d
 
 
 
 
 
 
 
 
 
6901ce4
6fe569d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import base64
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
import re
import streamlit as st
import sys
import time
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers import pipeline

# notes
# https://huggingface.co/docs/transformers/pad_truncation


# file loader and preprocessor
def file_preprocessing(file, skipfirst, skiplast):
    """Load a PDF, optionally drop its first/last page, and chunk the text.

    Args:
        file: Path of the PDF handed to ``PyMuPDFLoader``.
        skipfirst: Truthy (``True``/``1``) to drop the first loaded entry.
        skiplast: Truthy (``True``/``1``) to drop the last loaded entry.

    Returns:
        A tuple ``(texts, final_texts)``: the list of chunks produced by
        ``RecursiveCharacterTextSplitter`` and the same chunks concatenated
        into one string.
    """
    loader = PyMuPDFLoader(file)
    # NOTE(review): load_and_split() may split long pages further, in which
    # case an entry here is not a whole page -- confirm against the
    # LangChain loader docs whether loader.load() is the better call.
    pages = loader.load_and_split()
    # skip page(s); each delete is guarded so that empty or one-page PDFs
    # no longer raise IndexError
    if skipfirst and pages:
        del pages[0]
    if skiplast and pages:
        del pages[-1]
    # https://stackoverflow.com/questions/76431655/langchain-pypdfloader
    content = "".join(page.page_content for page in pages)
    # re-join words that were hyphenated across a line break
    content = re.sub("-\n", "", content)
    print("\n###### New article ######\n")
    print("Input text:\n")
    print(content)
    print("\nChunking...")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,  # number of characters
        chunk_overlap=100,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],  # default list
    )
    # https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846
    texts = text_splitter.split_text(content)
    print("Number of tokens: " + str(len(texts)))
    print("\nFirst three tokens:\n")
    # slice instead of indexing: short documents may yield fewer than three
    # chunks (or none at all)
    for chunk in texts[:3]:
        print(chunk)
        print("")
    # NOTE(review): chunks are created with chunk_overlap=100, so the joined
    # text may repeat the overlapping parts -- confirm this is intended.
    final_texts = "".join(texts)
    return texts, final_texts


# function to count words in the input
def preproc_count(filepath, skipfirst, skiplast):
    """Preprocess the PDF and count the words of its extracted text.

    Returns:
        A tuple ``(texts, input_text, text_length)``: the chunk list from
        ``file_preprocessing``, the joined text with every ``-`` removed,
        and the number of ``\\w+`` matches in that text.
    """
    chunks, joined = file_preprocessing(filepath, skipfirst, skiplast)
    # hyphens are dropped before counting, and the stripped text is what
    # gets returned to the caller
    dehyphenated = joined.replace("-", "")
    word_total = len(re.findall(r"\w+", dehyphenated))
    print(f"Input word count: {word_total:,}")
    return chunks, dehyphenated, word_total


# function to convert (bart) summary to sentence case
def convert_to_sentence_case(text):
    """Return *text* with each sentence capitalized and single-space joined.

    Sentences are split on whitespace that follows ``.``, ``!`` or ``?``.
    ``str.capitalize`` upper-cases the first character of each sentence and
    lower-cases the rest.
    """
    sentence_boundary = re.compile(r"(?<=[.!?])\s+")
    return " ".join(
        piece.capitalize() for piece in sentence_boundary.split(text)
    )


# llm pipeline
def llm_pipeline(tokenizer, base_model, input_text, model_source):
    """Summarize *input_text* with a transformers summarization pipeline.

    Args:
        tokenizer: Tokenizer passed through to ``pipeline``.
        base_model: Model passed through to ``pipeline``.
        input_text: Text to summarize.
        model_source: Label that is only printed for logging.

    Returns:
        The ``summary_text`` field of the first pipeline result.
    """
    summarizer_options = {
        "model": base_model,
        "tokenizer": tokenizer,
        "max_length": 300,
        "min_length": 200,
        "truncation": True,
    }
    summarizer = pipeline("summarization", **summarizer_options)
    print("Model source: %s" % (model_source))
    print("Summarizing...")
    outputs = summarizer(input_text)
    summary_text = outputs[0]["summary_text"]
    print("Summarization finished\n")
    print("Summary text:\n")
    print(summary_text)
    print("")
    return summary_text


# function to count words in the summary
def postproc_count(summary):
    """Print and return the number of ``\\w+`` matches in *summary*."""
    word_total = sum(1 for _ in re.finditer(r"\w+", summary))
    print(f"Summary word count: {word_total:,}")
    return word_total


# function to clean summary text
def clean_summary_text(summary):
    """Tidy a raw summary: trim it, fix punctuation spacing, sentence-case it.

    Returns the result of ``convert_to_sentence_case`` applied to the
    trimmed, punctuation-tightened text.
    """
    # a whitespace character directly before punctuation that is itself
    # followed by whitespace or end-of-text (seen in bart output)
    space_before_punct = r'\s([,.():;?!"](?:\s|$))'
    trimmed = summary.strip()
    tightened = re.sub(space_before_punct, r"\1", trimmed)
    return convert_to_sentence_case(tightened)


@st.cache_data(ttl=60 * 60)
# function to display the PDF
def displayPDF(file):
    """Embed the PDF at path *file* in the page as a base64 data-URI iframe."""
    with open(file, "rb") as pdf_handle:
        raw_bytes = pdf_handle.read()
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    # inline the whole document in the iframe source
    iframe_html = (
        f'<iframe src="data:application/pdf;base64,{encoded}" '
        'width="100%" height="600" type="application/pdf"></iframe>'
    )
    # raw HTML must be explicitly allowed for the iframe to render
    st.markdown(iframe_html, unsafe_allow_html=True)


# streamlit code
# Use the wide page layout for the whole app.
# NOTE(review): Streamlit expects set_page_config to run before other
# Streamlit commands -- keep this call ahead of the module-level st.markdown
# further down; confirm against the Streamlit docs.
st.set_page_config(layout="wide")


def _load_model(selected_model, model_source):
    """Return ``(tokenizer, base_model)`` for the chosen model and source.

    ``base_model`` is a loaded model when *model_source* is
    ``"Download model"``, otherwise the path of the locally cached snapshot.
    """
    if selected_model == "BART":
        checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            truncation=True,
            model_max_length=1000,
            trust_remote_code=True,
        )
        if model_source == "Download model":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
        else:
            base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
    else:
        checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint,
            truncation=True,
            legacy=False,
            model_max_length=1000,
        )
        if model_source == "Download model":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                torch_dtype=torch.float32,
            )
        else:
            base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
    return tokenizer, base_model


def main():
    """Render the app: upload a PDF, choose options, summarize, show results.

    The uploaded file is written to ``data/`` and shown next to its summary,
    followed by details about the text chunking that was applied.
    """
    # local import: only needed here for safe handling of the upload path
    import os

    st.title("RASA: Research Article Summarization App")
    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is None:
        return

    st.subheader("Options")
    col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
    with col1:
        model_source_names = ["Cached model", "Download model"]
        model_source = st.radio(
            "For development:",
            model_source_names,
            help="Defaults to a cached model; downloading will take longer",
        )
    with col2:
        model_names = [
            "T5-Small",
            "BART",
        ]
        selected_model = st.radio(
            "Select a model to use:",
            model_names,
            help="Defaults to T5-Small; for most articles it summarizes better than BART",
        )
    with col3:
        st.write("Skip any pages?")
        skipfirst = st.checkbox(
            "Skip first page", help="Select if your PDF has a cover page"
        )
        skiplast = st.checkbox("Skip last page")
    with col4:
        st.write("Background information (links open in a new window)")
        st.write(
            "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
            "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
        )
        st.write(
            "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
            "&nbsp;&nbsp;|&nbsp;&nbsp;Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
        )

    if not st.button("Summarize"):
        return

    # Load the tokenizer/model only when a summary is requested, so that
    # ordinary widget-triggered reruns do not reload them every time.
    tokenizer, base_model = _load_model(selected_model, model_source)

    col1, col2 = st.columns(2)
    # The target directory may not exist yet, and the client-supplied name
    # is reduced to its final component so it cannot point outside data/.
    os.makedirs("data", exist_ok=True)
    filepath = "data/" + os.path.basename(uploaded_file.name)
    with open(filepath, "wb") as temp_file:
        # getvalue() returns the full content regardless of stream position
        temp_file.write(uploaded_file.getvalue())
    with col1:
        texts, input_text, preproc_text_length = preproc_count(
            filepath, skipfirst, skiplast
        )
        st.info(
            "Uploaded PDF&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
            f"{preproc_text_length:,}"
        )
        displayPDF(filepath)
    with col2:
        start = time.time()
        with st.spinner("Summarizing..."):
            summary = llm_pipeline(
                tokenizer, base_model, input_text, model_source
            )
            postproc_text_length = postproc_count(summary)
        duration = time.time() - start
        print(f"Duration: {duration:.0f} seconds")
        st.info(
            "PDF Summary&nbsp;&nbsp;|&nbsp;&nbsp;Number of words: "
            f"{postproc_text_length:,}"
            "&nbsp;&nbsp;|&nbsp;&nbsp;Summarization time: "
            f"{duration:.0f} seconds"
        )
        if selected_model == "BART":
            # bart output needs spacing/case clean-up; keep the raw text
            # available for comparison
            summary_cleaned = clean_summary_text(summary)
            st.success(summary_cleaned)
            with st.expander("Raw output"):
                st.write(summary)
        else:
            st.success(summary)
    # NOTE(review): this one-column container is never written to; the call
    # is kept only so the page layout stays as before -- confirm and remove.
    st.columns(1)
    url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
    st.info("Additional information")
    st.write("\n[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
    st.write("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_size=1000")
    st.write(
        "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;chunk_overlap=100"
    )
    st.write(
        "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;length_function=len"
    )
    st.write("")
    st.write("Number of tokens generated: " + str(len(texts)))
    st.write("")
    st.write("First three tokens:")
    # slice instead of indexing: short documents may have fewer than three
    # chunks, which previously raised IndexError
    for chunk in texts[:3]:
        st.write("----")
        st.write(chunk)


# Global CSS overrides, injected as raw HTML when the module is executed:
# - radio-group label text: 1rem font size, normal (400) weight
# - markdown paragraphs and checkbox labels: -15px bottom margin
# - anchors that are direct children of <body>: underlined
st.markdown(
    """<style>
div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
    font-size: 1rem;
    font-weight: 400;
}
div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
    margin-bottom: -15px;
}
div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
    margin-bottom: -15px;
}
body > a {
    text-decoration: underline;
}
    </style>
    """,
    unsafe_allow_html=True,
)


# Build the UI only when this file is run as the main script.
if __name__ == "__main__":
    main()