Spaces:
Runtime error
Runtime error
File size: 2,233 Bytes
9390326 7743799 9390326 c767df5 9390326 87e16bf 9390326 87e16bf 9390326 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
# Choose the model-cache decorator once, at import time. ``st.cache`` has been
# deprecated since Streamlit 1.18 in favour of ``st.cache_resource`` (the cache
# meant for shared, unhashable objects such as ML models), and newer Streamlit
# releases may no longer provide it. Fall back to the legacy decorator only on
# old installations that lack ``cache_resource``.
if hasattr(st, "cache_resource"):
    _model_cache = st.cache_resource(max_entries=1)
else:  # legacy Streamlit (< 1.18)
    _model_cache = st.cache(allow_output_mutation=True, max_entries=1)


@_model_cache
def get_model():
    """Load and cache the SGD response-generation model and its tokenizer.

    Both are loaded with ``from_pretrained`` from the
    ``SoLID/sgd-response-generator`` checkpoint. The call is cached (at most
    one entry), so Streamlit reruns reuse the already-loaded objects instead
    of reloading them.

    Returns:
        tuple: ``(model, tokenizer)`` -- note that the model comes first.
    """
    tokenizer = AutoTokenizer.from_pretrained("SoLID/sgd-response-generator")
    model = AutoModelForSeq2SeqLM.from_pretrained("SoLID/sgd-response-generator")
    return (model, tokenizer)
def lexicalize_plan(
    model,
    tokenizer,
    output_plan,
    temperature=1.0,
    num_beams=1,
    max_length=512,
    top_p=0.95,
):
    """Generate a natural-language response from a structured output plan.

    Args:
        model: seq2seq model exposing ``generate`` (e.g. the model returned
            by ``get_model``).
        tokenizer: the tokenizer paired with ``model``.
        output_plan: plan string to lexicalize, e.g.
            ``"[AC:Request [IN:FindRestaurants [SL:location] ] ]"``.
        temperature: sampling temperature forwarded to ``generate``.
        num_beams: number of beams; converted to ``int`` before use.
        max_length: forwarded to ``generate`` as ``max_length``.
        top_p: nucleus-sampling probability mass forwarded to ``generate``.

    Returns:
        str: the first generated sequence, decoded with special tokens
        skipped and surrounding whitespace stripped.
    """
    num_beams = int(num_beams)
    encoded = tokenizer(output_plan, return_tensors="pt")
    output = model.generate(
        encoded.input_ids,
        # Hand over the tokenizer's mask explicitly rather than relying on
        # generate() to infer one from pad_token_id.
        attention_mask=encoded.get("attention_mask"),
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # early_stopping only affects beam search; enabling it with a single
        # beam changes nothing but makes transformers emit a warning.
        early_stopping=num_beams > 1,
        temperature=temperature,
        num_beams=num_beams,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()
def run():
    """Render the Streamlit demo page and handle response generation.

    The sidebar holds the temperature and beam-count sliders. The main body
    takes an output plan and, when "Generate Response" is pressed with a
    non-blank plan, displays the text produced by ``lexicalize_plan``.
    """
    # Called first: older Streamlit versions require set_page_config to run
    # before any other st.* command.
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")

    # sidebar
    st.sidebar.title("SGD Response Generator Demo")
    st.sidebar.image(
        "https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1628568174585-6049d8edbaa99e90d94ee67c.png",
        caption="SoLID at UNCC Logo",
    )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )

    # main body
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value="[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]",
        help="Type in the output plan used by the system to generate a response in English.",
    )
    if not st.button("Generate Response"):
        return
    # Sampling from an empty plan only yields noise; ask for input instead.
    if not output_plan.strip():
        st.warning("Please enter an output plan before generating a response.")
        return
    # The spinner replaces the manual "text placeholder, then .empty()" dance
    # and is cleared automatically when the block exits.
    with st.spinner("Generating Response..."):
        response = lexicalize_plan(model, tokenizer, output_plan, temperature, num_beams)
    st.write("Generated Response: " + str(response))
if __name__ == "__main__":
    # `streamlit run <file>` executes the script as __main__, so the page is
    # rendered on every rerun; importing the module does not render it.
    run()