# Author: Nihal D'Souza
# Commit: a804ced — "Custom textrank, changes to UI"
# (file-viewer metadata: raw | history blame | 2.48 kB)
import os
import nltk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import streamlit as st
from src.doc2vec import inference
from src.abstractive_sum import summarize_text_with_model
from src.textrank import custom_textrank_summarizer
from src.clean import clean_license_text
CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"

# Turn off the Hugging Face tokenizers' internal parallelism; this must be set
# before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _cache_resource(func):
    """Memoize ``func`` across Streamlit script reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without a cache the NLTK download check and the model/tokenizer load
    would be repeated each time a control changes. ``st.cache_resource`` is
    used when the installed Streamlit provides it; older releases fall back
    to ``st.cache(allow_output_mutation=True)``.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is not None:
        return cache(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_resources(model_name, device_name):
    """Fetch NLTK 'punkt' data and load the seq2seq model and its tokenizer.

    Args:
        model_name: Hugging Face model identifier to load.
        device_name: ``str`` form of the target ``torch.device``. A plain
            string is passed, rather than the device object, so the cache
            key is trivially hashable.

    Returns:
        ``(model, tokenizer)``, with the model moved to ``device_name``.
    """
    nltk.download('punkt')
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(torch.device(device_name))
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


with st.spinner('Loading...'):
    model, tokenizer = _load_resources(CUSTOM_MODEL_NAME, str(device))
# --- Sidebar: choose the summarization mode and its options ---------------
summarization_type = st.sidebar.selectbox(
    "Select summarization type.",
    ("Abstractive", "Extractive", "Both"),
)

# Explanatory caption shown under the selector for each mode.
_MODE_CAPTIONS = {
    'Abstractive': 'Summary will be generated by the T5 Transformer Model',
    'Extractive': 'Summary will be generated by a custom TextRank Algorithm',
    'Both': 'The License text will be first passed through the custom TextRank algorithm and then passed on to the T5 Transformer Model to generate a summary.',
}
st.sidebar.caption(_MODE_CAPTIONS[summarization_type])

# Only the extractive mode offers a length control (integer 1-10, default 3).
if summarization_type == 'Extractive':
    summary_len = st.sidebar.slider('Summary length percentage', 1, 10, 3)

# Always shown, regardless of mode.
clean_text = st.sidebar.checkbox('Show cleaned license text')
# --- Main page: read a license, summarize it, and show the results ---------
st.title('Clearly Defined: License Summarizer')
license_text = st.text_area('Enter contents of the license')

# Skip empty or whitespace-only input so the models are never run on blank text.
if license_text.strip():
    with st.spinner('Loading...'):
        if summarization_type == 'Abstractive':
            summary, definitions = summarize_text_with_model(license_text, model, tokenizer)
        elif summarization_type == 'Extractive':
            # The slider yields 1-10; the summarizer is given that value / 10.
            summary, definitions = custom_textrank_summarizer(license_text, summary_len=summary_len / 10)
        else:  # 'Both' -- the selectbox offers only these three options
            # NOTE(review): the sidebar caption for 'Both' says TextRank runs
            # first and the T5 model second, but this runs the model first and
            # then TextRank (summary_len=1) over the model's output. Confirm
            # which order is intended.
            summary, definitions = summarize_text_with_model(license_text, model, tokenizer)
            summary, _ = custom_textrank_summarizer(summary, summary_len=1)

    if clean_text:
        st.header('Cleaned License Text')
        st.write(clean_license_text(license_text)[0])

    st.header('Summary')
    st.write(summary)

    prediction_scores = inference(license_text)
    st.header('Similarity Index')
    st.dataframe(prediction_scores)

    # Only render the section when the summarizer returned any definitions.
    if definitions:
        st.header('Definitions')
        st.write(definitions)