# File size: 2,081 Bytes
# Commit: 606291c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import streamlit as st

import inference
from app_utils import get_default_texts, display_output

# Model identifiers/locations. None of these names is read elsewhere in this
# script; presumably they mirror what the `inference` module loads — verify.
summarizer_path = 'google/pegasus-large'
qa_path = 'qa_model'
entity_recognition_path = 'entity_rec'

# Per-session defaults. Each key is written only when absent, so values chosen
# on an earlier Streamlit rerun are kept rather than reset.
_session_defaults = {
    'sample_index': 0,
    'which_button': 'sample_button',
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

st.title('NLP Demo')

# Sidebar: which NLP operations to run and which dataset supplies sample texts.
with st.sidebar:
    st.header("Select your choices")
    ops_to_perform = st.multiselect(
        'Select operation to perform :',
        ['Question Answering', 'Entity Recognition', 'Text Summarization'],
        default=['Question Answering'],
    )
    chosen_dataset = st.selectbox(
        "Choose one of the datasets to get samples :",
        ['squad-qa', 'bbc-xsum-summarization', 'conll-ner'],
    )

# Sample texts for the chosen dataset. Later code indexes them by position
# 0..tot_index-1 and wraps navigation with modulo tot_index.
samples_dict = get_default_texts(chosen_dataset)
tot_index = len(samples_dict)

st.write('**Select from sample texts**')

st.write("Select one from these available samples: ")

prev_col, next_col = st.columns(2)
with prev_col:
    prev_clicked = st.button('prev_text')
with next_col:
    next_clicked = st.button('next_text')

if tot_index > 0:
    # Re-wrap the stored index first: it may have been saved while a larger
    # dataset was selected and be out of range for the current one.
    current_index = st.session_state['sample_index'] % tot_index
    if prev_clicked:
        current_index = (current_index - 1) % tot_index
    if next_clicked:
        current_index = (current_index + 1) % tot_index
    sample_text = samples_dict[current_index]
else:
    # No samples to cycle through: avoid a modulo-by-zero and start empty.
    st.warning("No samples available for the selected dataset.")
    current_index = 0
    sample_text = ""
st.session_state['sample_index'] = current_index
input_text = st.text_area("Input text to perform selected operations on : ", sample_text)

# A question is only requested when Question Answering is selected; otherwise
# None is passed through to inference.
question = None
if "Question Answering" in ops_to_perform:
    question = st.text_input("Enter a valid question here :")

predict_clicked = st.button("Submit for predictions")
if predict_clicked:
    # NOTE(review): 'which_button' is only ever initialised to 'sample_button'
    # in this script, so this check currently always passes; presumably
    # reserved for another input mode — confirm before removing.
    which_button = st.session_state['which_button']
    if which_button == 'sample_button':
        if not ops_to_perform:
            st.warning("Select at least one operation to perform.")
        elif question is not None and not question.strip():
            st.warning("Enter a question to run Question Answering.")
        else:
            all_outputs = inference.get_predictions(input_text, ops_to_perform, question)
            # Cache so the results remain visible on later reruns.
            st.session_state['prev_outputs'] = all_outputs
            display_output(all_outputs)
else:
    # Rerun without a submit (e.g. prev/next clicked): re-show the last results.
    if 'prev_outputs' in st.session_state:
        display_output(st.session_state['prev_outputs'])