import streamlit as st
from transformers import pipeline

st.set_page_config(page_title="Common NLP Tasks")
st.title("Common NLP Tasks")
st.subheader(":point_left: Use the menu on the left to select an NLP task (click on > if closed).")
# Streamlit "magic": a bare string literal on its own line is written to the app
# (rendered as markdown), which displays the author's GitHub and Twitter badges.
"""
[![](https://img.shields.io/github/followers/OOlajide?label=OOlajide&style=social)](https://gitHub.com/OOlajide)
[![](https://img.shields.io/twitter/follow/sageOlamide?label=@sageOlamide&style=social)](https://twitter.com/sageOlamide)
"""
# Sidebar: a collapsible "About" panel, then the task picker.
expander = st.sidebar.expander("About")
expander.write("This web app allows you to perform common Natural Language Processing tasks. Select a task below to get started.")

st.sidebar.header("What would you like to do?")
# The sidebar header above acts as the visible prompt, so the radio has an empty label.
# `option` holds the selected task name and decides which task page is rendered below.
option = st.sidebar.radio("", ["Text summarization", "Extractive question answering", "Text generation"])

@st.cache(show_spinner=False, allow_output_mutation=True)
def question_model():
    """Build the extractive question-answering pipeline once and reuse it across reruns.

    The model weights and the tokenizer are both loaded from the
    ``deepset/tinyroberta-squad2`` checkpoint. ``allow_output_mutation=True``
    tells ``st.cache`` not to check the returned pipeline for mutation.
    """
    checkpoint = "deepset/tinyroberta-squad2"
    return pipeline(task="question-answering", model=checkpoint, tokenizer=checkpoint)

@st.cache(show_spinner=False, allow_output_mutation=True)
def summarization_model():
    """Build the abstractive summarization pipeline once and reuse it across reruns.

    Weights and tokenizer both come from the ``google/pegasus-xsum`` checkpoint.
    """
    repo_id = "google/pegasus-xsum"
    # The task name is pipeline()'s first positional parameter.
    return pipeline("summarization", model=repo_id, tokenizer=repo_id)

@st.cache(show_spinner=False, allow_output_mutation=True)
def generation_model():
    """Build the text-generation pipeline once and reuse it across reruns.

    Uses the ``distilgpt2`` checkpoint for both the model and the tokenizer.
    """
    name = "distilgpt2"
    text_generator = pipeline(tokenizer=name, model=name, task="text-generation")
    return text_generator

# Input-source choices shared by the question-answering and summarization pages.
_INPUT_TEXT = "I want to input some text"
_UPLOAD_FILE = "I want to upload a file"


def _task_header(title, question, description):
    """Render a task page's centered title, its "What is ... about?" heading and intro text.

    Args:
        title: Page heading, shown centered in grey.
        question: Sub-heading, shown bold in pink.
        description: Plain explanation written under the sub-heading.
    """
    st.markdown(f"<h2 style='text-align: center; color:grey;'>{title}</h2>", unsafe_allow_html=True)
    # Close the bold element with </b>; a second <b> would leave it unclosed.
    st.markdown(
        f"<h3 style='text-align: left; color:#F63366; font-size:18px;'><b>{question}</b></h3>",
        unsafe_allow_html=True,
    )
    st.write(description)
    st.markdown('___')


def _choose_source():
    """Ask whether the user will type text or upload a .txt file; return the chosen label."""
    return st.radio("How would you like to start? Choose an option below", [_INPUT_TEXT, _UPLOAD_FILE])


def _read_sample_text():
    """Return the contents of ``sample.txt``, used to pre-fill the text areas."""
    # Explicit encoding so the result does not depend on the platform's locale default.
    with open("sample.txt", "r", encoding="utf-8") as text_file:
        return text_file.read()


def _read_uploaded_text(uploaded_file):
    """Return the uploaded .txt file decoded as UTF-8, or ``None`` if it cannot be decoded.

    Unlike ``read()``, ``getvalue()`` returns the whole buffer regardless of the
    current stream position, so the text is not lost if the object was already read.
    An error message is shown to the user when decoding fails.
    """
    try:
        return uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        st.error("Could not read this file. Please upload a UTF-8 encoded .txt file.")
        return None


def _answer_question(context, question):
    """Run extractive QA over *context* for *question* and display the answer span.

    Shows a warning instead of invoking the model when either input is blank.
    """
    if not context.strip() or not question.strip():
        st.warning("Please provide both a text and a question.")
        return
    with st.spinner(text="Loading question model..."):
        question_answerer = question_model()
    with st.spinner(text="Getting answer..."):
        answer = question_answerer(context=context, question=question)
        st.text(answer["answer"])


def _summarize(text):
    """Summarize *text* and display the summary; warn instead when *text* is blank."""
    if not text.strip():
        st.warning("Please provide some text to summarize.")
        return
    with st.spinner(text="Loading summarization model..."):
        summarizer = summarization_model()
    with st.spinner(text="Summarizing text..."):
        # Uploaded files have no character limit; truncation=True clips inputs that
        # exceed the model's maximum input length instead of failing.
        summary = summarizer(text, max_length=130, min_length=30, truncation=True)
        st.text(summary[0]["summary_text"])


def _question_answering_page():
    """Render the extractive question-answering page."""
    _task_header(
        "Extractive Question Answering",
        "What is extractive question answering about?",
        "Extractive question answering is a Natural Language Processing task where text is provided for a model so that the model can refer to it and make predictions about where the answer to a question is.",
    )
    source = _choose_source()
    sample_question = "What did the shepherd boy do to amuse himself?"
    if source == _INPUT_TEXT:
        context = st.text_area(
            "Use the example below or input your own text in English (10,000 characters max)",
            value=_read_sample_text(),
            max_chars=10000,
            height=330,
        )
        question = st.text_input(label="Use the question below or enter your own question", value=sample_question)
        if st.button("Get answer"):
            _answer_question(context, question)
    elif source == _UPLOAD_FILE:
        uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
        if uploaded_file is None:
            return
        raw_text = _read_uploaded_text(uploaded_file)
        if raw_text is None:
            return
        context = st.text_area("", value=raw_text, height=330)
        question = st.text_input(label="Enter your question", value=sample_question)
        if st.button("Get answer"):
            _answer_question(context, question)


def _summarization_page():
    """Render the text summarization page."""
    _task_header(
        "Text Summarization",
        "What is text summarization about?",
        "Text summarization is producing a shorter version of a given text while preserving its important information.",
    )
    source = _choose_source()
    if source == _INPUT_TEXT:
        text = st.text_area(
            "Input a text in English (10,000 characters max) or use the example below",
            value=_read_sample_text(),
            max_chars=10000,
            height=330,
        )
        if st.button("Get summary"):
            _summarize(text)
    elif source == _UPLOAD_FILE:
        uploaded_file = st.file_uploader("Choose a .txt file to upload", type=["txt"])
        if uploaded_file is None:
            return
        raw_text = _read_uploaded_text(uploaded_file)
        if raw_text is None:
            return
        text = st.text_area("", value=raw_text, height=330)
        if st.button("Get summary"):
            _summarize(text)


def _text_generation_page():
    """Render the text generation page: continue a one-line prompt with the generation model."""
    _task_header(
        "Text Generation",
        "What is text generation about?",
        "Text generation is the task of generating text with the goal of appearing indistinguishable from human-written text.",
    )
    text = st.text_input(label="Enter one line of text and let the NLP model generate the rest for you")
    if st.button("Generate text"):
        with st.spinner(text="Loading text generation model..."):
            generator = generation_model()
        with st.spinner(text="Generating text..."):
            generated_text = generator(text, max_length=50)
            st.text(generated_text[0]["generated_text"])


# Render the page for the task picked in the sidebar.
if option == "Extractive question answering":
    _question_answering_page()
elif option == "Text summarization":
    _summarization_page()
elif option == "Text generation":
    _text_generation_page()