# Scraped from a Hugging Face Spaces page; the Space's status was shown as "Runtime error".
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def _resource_cache():
    """Return a decorator that caches a loaded resource across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would reload the model on every slider change. Uses
    ``st.cache_resource`` when available (Streamlit >= 1.18), otherwise the
    legacy ``st.cache(allow_output_mutation=True)``; if neither exists, returns
    an identity decorator so the app still runs, just without caching.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    if hasattr(st, "cache"):
        # allow_output_mutation skips re-hashing the returned model on each hit.
        return st.cache(allow_output_mutation=True)
    return lambda func: func


@_resource_cache()
def get_model(model_name="SoLID/sgd-response-generator"):
    """Load the seq2seq response-generation model and its tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) passed to both
            ``from_pretrained`` calls; defaults to the SoLID SGD response
            generator used by this demo.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return (model, tokenizer)
def lexicalize_plan(
    model, tokenizer, output_plan, temperature=1.0, num_beams=1
):
    """Generate an English response from a structured output plan.

    Args:
        model: A seq2seq model exposing ``generate``.
        tokenizer: The matching tokenizer (callable, with ``decode``,
            ``pad_token_id`` and ``eos_token_id``).
        output_plan: The dialogue-act plan string to lexicalize.
        temperature: Sampling temperature forwarded to ``generate``.
        num_beams: Beam count; coerced to ``int`` because it may arrive
            from a UI widget as another numeric type.

    Returns:
        The decoded first generated sequence, with special tokens removed
        and surrounding whitespace stripped.
    """
    encoded = tokenizer(output_plan, return_tensors="pt")
    num_beams = int(num_beams)
    output = model.generate(
        encoded["input_ids"],
        # Pass the mask explicitly instead of letting generate infer it.
        attention_mask=encoded.get("attention_mask"),
        max_length=512,
        do_sample=True,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # early_stopping only affects beam search; setting it with a single
        # beam is ignored and triggers a warning in newer transformers.
        early_stopping=num_beams > 1,
        temperature=temperature,
        num_beams=num_beams,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()
def run():
    """Build the demo page: sidebar controls, a plan input box, and the response.

    Clicking "Generate Response" lexicalizes the entered output plan with the
    cached model using the sidebar's temperature and beam settings.
    """
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")

    # sidebar
    st.sidebar.title("SGD Response Generator Demo")
    st.sidebar.image(
        "https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1628568174585-6049d8edbaa99e90d94ee67c.png",
        caption="SoLID at UNCC Logo",
    )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )

    # main body
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value="[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]",
        help="Type in the output plan used by the system to generate a response in English.",
    )

    if st.button("Generate Response"):
        if not output_plan.strip():
            st.warning("Please enter an output plan before generating a response.")
            return
        # st.spinner removes its message even if generation raises, unlike a
        # st.text placeholder that is only cleared on the success path.
        with st.spinner("Generating Response..."):
            response = lexicalize_plan(
                model, tokenizer, output_plan, temperature, num_beams
            )
        st.write("Generated Response: " + str(response))


if __name__ == "__main__":
    run()