import os

import streamlit as st
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer
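
# Streamlit demo for Voicelab's VLT5 Reason-for-Contact (RfC) models: one
# checkpoint generates the RfC for a whole conversation, the other detects
# whether a single conversation turn contains an RfC.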

@st.cache(allow_output_mutation=True)
def load_model_cache():
    # Use a token from the Space secrets if set; `True` falls back to the
    # locally stored Hugging Face credentials.
    auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
    tokenizer_pl = T5Tokenizer.from_pretrained(
        "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
    )
    model_pl = T5ForConditionalGeneration.from_pretrained(
        "Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
    )
    model_det_pl = T5ForConditionalGeneration.from_pretrained(
        "Voicelab/vlt5-base-rfc-detector-1.0", use_auth_token=auth_token
    )
    return tokenizer_pl, model_pl, model_det_pl
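
# NOTE: st.cache(allow_output_mutation=True) is deprecated in recent Streamlit
# releases. A minimal sketch of the modern equivalent (assuming Streamlit
# >= 1.18) would decorate the same loader with @st.cache_resource instead:
#
#     @st.cache_resource
#     def load_model_cache():
#         ...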
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
max_length: int = 5000
cache_size: int = 100
st.set_page_config(
    page_title="DEMO - Reason for Contact generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)
tokenizer_pl, model_pl, model_det_pl = load_model_cache()

def get_predictions(text, mode):
    input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
    if mode == "Polish - RfC Generation":
        # Diverse beam search: num_beams must be divisible by num_beam_groups,
        # and diversity_penalty only takes effect when num_beam_groups > 1.
        output = model_pl.generate(
            input_ids,
            no_repeat_ngram_size=1,
            num_beams=3,
            num_beam_groups=3,
            min_length=10,
            max_length=100,
            diversity_penalty=1.0,
        )
    elif mode == "Polish - RfC Detection":
        output = model_det_pl.generate(
            input_ids,
            no_repeat_ngram_size=2,
            num_beams=3,
            num_beam_groups=3,
            repetition_penalty=1.5,
            diversity_penalty=2.0,
            length_penalty=2.0,
        )
    else:
        # Guard against an undefined `output` if an unknown mode is passed.
        raise ValueError(f"Unknown mode: {mode}")
    predicted_rfc = tokenizer_pl.decode(output[0], skip_special_tokens=True)
    return predicted_rfc
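
# Hypothetical direct call, for illustration only (the mode strings mirror the
# radio options defined below; assumes load_model_cache() succeeded):
#
#     rfc = get_predictions(
#         "AGENT: Hello, how can I help you? "
#         "CLIENT: Hi, I would like to report a stolen card.",
#         mode="Polish - RfC Generation",
#     )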

def trim_length():
    # on_change callback for the text area below: hard-truncates the input
    # stored in session state to max_length characters.
    if len(st.session_state["input"]) > max_length:
        st.session_state["input"] = st.session_state["input"][:max_length]

if __name__ == "__main__":
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - Reason for Contact generator")
st.markdown("#### RfC Generation model.")
st.markdown("**Input***: Whole conversation. Should specify roles (e.g. *AGENT: Hello, how can I help you? CLIENT: Hi, I would like to report a stolen card.* Put a whole conversation or full e-mail here.")
st.markdown("**Output**: Reason for calling for the whole conversation.")
st.markdown("#### RfC Detection model.")
st.markdown("**Input**: A single turn from the conversation e.g. *'Hello, how can I help you?'* or *'Hi, I would like to report a stolen card.'. Put a single turn or a few sentences here.*")
st.markdown("**Output**: Model will return an empty string if a turn possibly does not includes Reason for Calling, or a sentence if the RfC is detected.")
    generated_rfc = ""
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )
    st.sidebar.title("Model settings")
    mode = st.sidebar.radio(
        "Select model to test",
        [
            "Polish - RfC Generation",
            "Polish - RfC Detection",
        ],
    )
    result = st.button("Find reason for contact")
    if result:
        generated_rfc = get_predictions(text=user_input, mode=mode)
        st.text_area("Reason for contact", generated_rfc)
        print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")