# Streamlit front end for the Clearly Defined license summarizer.
#
# The page takes license text from a text area, produces a summary with the
# selected strategy (abstractive model, extractive TextRank, or both), and then
# shows the similarity scores returned by `inference`, any extracted
# definitions, and optionally the cleaned license text.
import os

import nltk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import streamlit as st

from src.doc2vec import inference
from src.abstractive_sum import summarize_text_with_model
from src.textrank import custom_textrank_summarizer
from src.clean import clean_license_text

CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"

# Fetch the NLTK "punkt" data; presumably required by the sentence
# tokenization done inside the src.* helpers -- confirm against those modules.
nltk.download('punkt')

# Turn off parallelism in the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model and tokenizer are loaded again on each rerun. Consider wrapping
# the load in a cached function (e.g. `st.cache_resource`) if the pinned
# Streamlit version provides it.
with st.spinner('Loading...'):
    model = AutoModelForSeq2SeqLM.from_pretrained(CUSTOM_MODEL_NAME).to(device)
    tokenizer = AutoTokenizer.from_pretrained(CUSTOM_MODEL_NAME)

# --- Sidebar controls -------------------------------------------------------
summarization_type = st.sidebar.selectbox(
    "Select summarization type.",
    ("Abstractive", "Extractive", "Both"),
)

if summarization_type == 'Abstractive':
    st.sidebar.caption('Summary will be generated by the T5 Transformer Model')
elif summarization_type == 'Extractive':
    st.sidebar.caption('Summary will be generated by a custom TextRank Algorithm')
    # Percentage (1-100, default 30); only offered, and only used, in
    # Extractive mode. It is converted to a fraction before use below.
    summary_len = st.sidebar.slider('Summary length percentage', 1, 100, 30)
elif summarization_type == 'Both':
    # NOTE(review): this caption says TextRank runs first and the model
    # second, but the 'Both' branch below calls the model first and TextRank
    # second. Confirm the intended order and align the caption or the code.
    st.sidebar.caption('The License text will be first passed through the custom TextRank algorithm and then passed on to the T5 Transformer Model to generate a summary.')

show_cleaned_text = st.sidebar.checkbox('Show cleaned license text')

# --- Main page --------------------------------------------------------------
st.title('Clearly Defined: License Summarizer')

# Named `license_text` rather than `input` so the builtin is not shadowed.
license_text = st.text_area('Enter contents of the license')

# Skip empty and whitespace-only submissions: there is nothing to summarize.
if license_text.strip():
    with st.spinner('Loading...'):
        # The three modes are mutually exclusive, so a single if/elif/else
        # chain guarantees `summary` and `definitions` are always bound.
        if summarization_type == 'Abstractive':
            summary, definitions = summarize_text_with_model(
                license_text, model, tokenizer
            )
        elif summarization_type == 'Extractive':
            summary, definitions = custom_textrank_summarizer(
                license_text, summary_len=summary_len / 100
            )
        else:  # 'Both'
            summary, definitions = summarize_text_with_model(
                license_text, model, tokenizer
            )
            # Second pass over the model output; the definitions from the
            # first pass are the ones kept for display.
            summary, _ = custom_textrank_summarizer(summary, summary_len=1)

    st.header('Summary')
    st.write(summary)

    prediction_scores = inference(license_text)
    st.header('Similarity Index')
    st.dataframe(prediction_scores)

    # The definitions section is rendered only when something was extracted.
    if definitions:
        st.header('Definitions')
        st.write(definitions)

    if show_cleaned_text:
        st.header('Cleaned License Text')
        # Only the first element of the cleaner's return value is displayed.
        st.write(clean_license_text(license_text)[0])