Update app.py
Browse files
app.py
CHANGED
@@ -6,62 +6,42 @@ import os
|
|
6 |
@st.cache(allow_output_mutation=True)
|
7 |
def load_model_cache():
|
8 |
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
|
9 |
-
tokenizer_en = T5Tokenizer.from_pretrained(
|
10 |
-
"Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token
|
11 |
-
)
|
12 |
-
model_en = T5ForConditionalGeneration.from_pretrained(
|
13 |
-
"Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token
|
14 |
-
)
|
15 |
|
16 |
tokenizer_pl = T5Tokenizer.from_pretrained(
|
17 |
-
"Voicelab/vlt5-base-
|
18 |
)
|
19 |
model_pl = T5ForConditionalGeneration.from_pretrained(
|
20 |
-
"Voicelab/vlt5-base-
|
21 |
)
|
22 |
|
23 |
-
return
|
24 |
|
25 |
|
26 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
27 |
img_short = Image.open("images/sVL-NLP-short.png")
|
28 |
img_favicon = Image.open("images/favicon_vl.png")
|
29 |
-
max_length: int =
|
30 |
cache_size: int = 100
|
31 |
|
32 |
st.set_page_config(
|
33 |
-
page_title="DEMO -
|
34 |
page_icon=img_favicon,
|
35 |
initial_sidebar_state="expanded",
|
36 |
)
|
37 |
|
38 |
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
|
39 |
|
40 |
-
def get_predictions(text
|
41 |
-
if language == "Polish":
|
42 |
input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
|
43 |
output = model_pl.generate(
|
44 |
input_ids,
|
45 |
-
no_repeat_ngram_size=
|
46 |
num_beams=3,
|
47 |
num_beam_groups=3,
|
48 |
-
|
49 |
-
|
50 |
-
length_penalty=2.0,
|
51 |
)
|
52 |
predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
|
53 |
-
elif language == "English":
|
54 |
-
input_ids = tokenizer_en(text, return_tensors="pt", truncation=True).input_ids
|
55 |
-
output = model_en.generate(
|
56 |
-
input_ids,
|
57 |
-
no_repeat_ngram_size=2,
|
58 |
-
num_beams=3,
|
59 |
-
num_beam_groups=3,
|
60 |
-
repetition_penalty=1.5,
|
61 |
-
diversity_penalty=2.0,
|
62 |
-
length_penalty=2.0,
|
63 |
-
)
|
64 |
-
predicted_kw = tokenizer_en.decode(output[0], skip_special_tokens=True)
|
65 |
return predicted_kw
|
66 |
|
67 |
|
@@ -73,7 +53,7 @@ def trim_length():
|
|
73 |
if __name__ == "__main__":
|
74 |
st.sidebar.image(img_short)
|
75 |
st.image(img_full)
|
76 |
-
st.title("VLT5 -
|
77 |
|
78 |
generated_keywords = ""
|
79 |
user_input = st.text_area(
|
@@ -89,12 +69,11 @@ if __name__ == "__main__":
|
|
89 |
"Select model to test",
|
90 |
[
|
91 |
"Polish",
|
92 |
-
"English",
|
93 |
],
|
94 |
)
|
95 |
|
96 |
-
result = st.button("
|
97 |
if result:
|
98 |
-
|
99 |
-
st.text_area("
|
100 |
-
print(f"Input: {user_input}--->
|
|
|
6 |
@st.cache(allow_output_mutation=True)
|
7 |
def load_model_cache():
|
8 |
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
tokenizer_pl = T5Tokenizer.from_pretrained(
|
11 |
+
"Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
|
12 |
)
|
13 |
model_pl = T5ForConditionalGeneration.from_pretrained(
|
14 |
+
"Voicelab/vlt5-base-rfc-v1_2", use_auth_token=auth_token
|
15 |
)
|
16 |
|
17 |
+
return tokenizer_pl, model_pl
|
18 |
|
19 |
|
20 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
21 |
img_short = Image.open("images/sVL-NLP-short.png")
|
22 |
img_favicon = Image.open("images/favicon_vl.png")
|
23 |
+
max_length: int = 5000
|
24 |
cache_size: int = 100
|
25 |
|
26 |
st.set_page_config(
|
27 |
+
page_title="DEMO - Reason for Contact detection",
|
28 |
page_icon=img_favicon,
|
29 |
initial_sidebar_state="expanded",
|
30 |
)
|
31 |
|
32 |
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
|
33 |
|
34 |
+
def get_predictions(text):
|
|
|
35 |
input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
|
36 |
output = model_pl.generate(
|
37 |
input_ids,
|
38 |
+
no_repeat_ngram_size=1,
|
39 |
num_beams=3,
|
40 |
num_beam_groups=3,
|
41 |
+
min_length=10,
|
42 |
+
max_length=100,
|
|
|
43 |
)
|
44 |
predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
return predicted_kw
|
46 |
|
47 |
|
|
|
53 |
if __name__ == "__main__":
|
54 |
st.sidebar.image(img_short)
|
55 |
st.image(img_full)
|
56 |
+
st.title("VLT5 - RfC generation")
|
57 |
|
58 |
generated_keywords = ""
|
59 |
user_input = st.text_area(
|
|
|
69 |
"Select model to test",
|
70 |
[
|
71 |
"Polish",
|
|
|
72 |
],
|
73 |
)
|
74 |
|
75 |
+
result = st.button("Find reason for contact")
|
76 |
if result:
|
77 |
+
generated_rfc = get_predictions(text=user_input, language=language)
|
78 |
+
st.text_area("Reason", generated_rfc)
|
79 |
+
print(f"Input: {user_input} ---> Reason for contact: {generated_rfc}")
|