File size: 5,858 Bytes
6bf8905
 
544b199
6bf8905
 
 
 
 
 
 
b9bfad5
6bf8905
73694e1
 
 
6bf8905
 
 
 
 
 
544b199
6bf8905
 
3092d79
6bf8905
 
 
3092d79
6bf8905
 
 
3092d79
6bf8905
 
 
 
 
cc276cb
6bf8905
cc276cb
9ee1c0c
6bf8905
 
 
 
544b199
 
 
6bf8905
 
 
544b199
 
 
 
d253047
 
4133bdb
b9bfad5
73694e1
b9bfad5
 
73694e1
 
 
62effc5
b9bfad5
4e4371b
b9bfad5
4e4371b
b9bfad5
d253047
4133bdb
6bf8905
d253047
6bf8905
b9bfad5
 
d253047
7abfa7c
d253047
7abfa7c
d253047
7abfa7c
 
2276c42
 
 
d253047
2276c42
 
 
544b199
 
 
 
 
035efeb
 
 
544b199
 
a202a75
544b199
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
import torch
from pandas import options
from transformers import BartForConditionalGeneration, BartTokenizer

# Module-level model + tokenizer handles. NOTE(review): genQuestion assigns
# plain locals named `model`/`tok` (no `global` statement), so these two names
# are never actually rebound after import — they document intent only.
model = None
tok = None

# Premade example contexts offered by the UI selectbox; index order matches
# the labels 'Elon Musk' (0), 'Fashion designer' (1), 'Young entrepreneur' (2).
context_example = ''
examples = [
    "Well, I was born in South Africa, lived there until I was 17. Came to North America of my own accord, against my parent’s wishes. And was in Canada for a few years. I started school there which is where I met my wife. Transferred down to the University of Pennsylvania and got a degree in physics, degree in business at Wharton. Came out to California with the intent of doing a PHD in the material science and physics [unintelligible] with an eye towards using that as an energy storage unit for electric vehicles. I ended up deferring that graduate work to start a couple to start a couple of area companies, one of which people have heard about, such as Pay Pal.",
    "Hi my name is Maria Sanchez, and I was born in Japan. I lived there for 20 years and moved out to the United States for college. I studied graphic design and later realized that my true passion was in fashion. It's lovely to see amazing models wearing my collection this fall, can't wait to show it to you guys soon. ",
    "I moved from Indiana to California when I was 19 to pursue my career as an young entrepreneur with a small loan of million dollars. My first start up was Blindr, where we sold blinders that auto adjusts depending on the time of the day. It was revolutionary, in only 2 years, we were able to accumulate 10 million customers and gain attraction internationally. We are planning to go further beyond this year with Blindr 2.0 where not only auto adjusts your blinders, but it also detects intruders who are violating your privacy at any time. "
]

# Per-model descriptions (currently unused; kept for future use)
# descriptions = "Interview question remake is a model that..."
# pass in Strings of model choice and input text for context
# pass in Strings of model choice and input text for context
@st.cache
def genQuestion(model_choice, context):
    """Generate four interview questions from *context* with the chosen model.

    Args:
        model_choice (str): One of "Base model", "Lengthed model", or
            "Reverse model" (the label shown in the UI selectbox;
            "Reversed model" is also accepted as a legacy alias).
        context (str): Interview context used to condition generation.

    Returns:
        str: Four generated questions, each followed by a blank line.

    Raises:
        ValueError: If ``model_choice`` is not a recognized model name.
    """
    if model_choice == "Base model":
        model = BartForConditionalGeneration.from_pretrained("hyechanjun/interview-question-remake")
        tok = BartTokenizer.from_pretrained("hyechanjun/interview-question-remake")
    elif model_choice == "Lengthed model":
        model = BartForConditionalGeneration.from_pretrained("hyechanjun/interview-length-tagged")
        tok = BartTokenizer.from_pretrained("hyechanjun/interview-length-tagged")
    elif model_choice in ("Reverse model", "Reversed model"):
        # BUG FIX: the UI selectbox offers 'Reverse model', but this branch
        # previously matched only "Reversed model", so selecting it left
        # `model`/`tok` unassigned and crashed with UnboundLocalError.
        model = BartForConditionalGeneration.from_pretrained("hyechanjun/reverse-interview-question")
        tok = BartTokenizer.from_pretrained("hyechanjun/reverse-interview-question")
    else:
        # Fail loudly with a clear message instead of UnboundLocalError below.
        raise ValueError(f"Unknown model choice: {model_choice!r}")

    inputs = tok(context, return_tensors="pt")
    # Inference only — disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model.generate(
            inputs["input_ids"],
            num_beams=4,
            max_length=64,
            min_length=9,
            num_return_sequences=4,
            diversity_penalty=1.0,  # diverse beam search: 4 groups of 1 beam
            num_beam_groups=4,
        )
    # Decode every returned beam once; join with blank lines (trailing blank
    # line kept to match the original output format).
    decoded = [
        tok.decode(beam, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for beam in output
    ]
    return "\n\n".join(decoded) + "\n\n"


# Wide page layout (instead of having a narrower, one-column page layout)
st.set_page_config(layout="wide")

# Title
st.title("Interview AI Test Website")

# Adding a Session State to store stateful variables and for saving user's labels/tags for generated questions
if 'button_sent' not in st.session_state:
    st.session_state.button_sent = False

# Two side-by-side columns: `maxl` (left) holds the model picker below,
# `minl` (right) holds the premade-context picker.
maxl, minl = st.columns(2)

context_option = minl.selectbox(
                'Feel free to choose one of our premade contexts',
                ('Select one','Elon Musk', 'Fashion designer', 'Young entrepreneur')
)

# Map the chosen label to its premade context; 'Select one' (or anything
# unrecognized) yields an empty context string.
_premade_contexts = {
    'Elon Musk': examples[0],
    'Fashion designer': examples[1],
    'Young entrepreneur': examples[2],
}
context_example = _premade_contexts.get(context_option, "")


option = maxl.selectbox(
                'Please select a model.',
                ('Base model', 'Lengthed model', 'Reverse model'))

# One-line description for each model, shown beneath the picker.
_model_blurbs = {
    'Base model': "This is the re-fine-tuned base model for our interview AI. It returns strings terminating in a question mark (?).",
    'Lengthed model': "This is a length-tagged version of our interview AI. You can specify how long its responses should be (ranges of multiples of 10)",
    'Reverse model': "This model asks a question that would have resulted in the context you provide (a.k.a. it traverses backward through the interview)",
}
if option in _model_blurbs:
    st.write(_model_blurbs[option])

# Input fields
# User-entered interview context (str), pre-filled from the premade-context
# picker. Renamed from `input`, which shadowed the builtin of the same name.
user_context = st.text_input('Context', value=context_example)

# Column layout to display generated responses alongside tags
col1, col2 = st.columns((3, 1))

# Regenerate on Submit, and also on every rerun after the first Submit
# (button_sent stays True), so the output area persists across interactions.
if st.button('Submit') or st.session_state.button_sent:
    with st.spinner('Generating a response...'):
        output = genQuestion(option, user_context)
        print(output)  # debug: mirror the generated text to the server console
    st.session_state.button_sent = True
    col1.text_area(label="Generated Responses:", value=output, height=200)


# TODO:
#   - disable multiselect widget when responses are being generated AND when a question is not selected to be tagged
#   - connect tags with an individual question
#   - save session state so tags associated with their respective questions can also be saved
#   - write/store the saved state data to some database for future use?
#   - brainstorm good names for tags/labels OR allow users to enter their own tag names if possible