Spaces:
Runtime error
Runtime error
add new language
Browse files
app.py
CHANGED
@@ -5,8 +5,19 @@ import os
|
|
5 |
|
6 |
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
12 |
img_short = Image.open("images/sVL-NLP-short.png")
|
@@ -15,27 +26,49 @@ max_length: int = 1000
|
|
15 |
cache_size: int = 100
|
16 |
|
17 |
st.set_page_config(
|
18 |
-
page_title=
|
19 |
page_icon=img_favicon,
|
20 |
initial_sidebar_state="expanded",
|
21 |
)
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
return predicted_kw
|
30 |
|
|
|
31 |
def trim_length():
    """Clamp the text stored under the "input" session key to max_length.

    Registered as the text area's on_change callback; it rewrites the
    session value in place and returns nothing.
    """
    current = st.session_state["input"]
    if len(current) <= max_length:
        # Already within the limit: leave the stored value untouched.
        return
    st.session_state["input"] = current[:max_length]
|
34 |
|
35 |
|
36 |
if __name__ == "__main__":
|
|
|
37 |
st.image(img_full)
|
38 |
-
st.title(
|
39 |
|
40 |
generated_keywords = ""
|
41 |
user_input = st.text_area(
|
@@ -45,9 +78,17 @@ if __name__ == "__main__":
|
|
45 |
on_change=trim_length,
|
46 |
key="input",
|
47 |
)
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
result = st.button("Generate keywords")
|
50 |
if result:
|
51 |
-
generated_keywords = get_predictions(text=user_input)
|
52 |
st.text_area("Generated keywords", generated_keywords)
|
53 |
-
|
|
|
5 |
|
6 |
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
|
7 |
|
8 |
+
# Checkpoint identifiers handed to from_pretrained; each language gets its
# own tokenizer/model pair, all loaded with the same auth token.
_CHECKPOINT_EN = "Voicelab/vlt5-base-keywords-v4_3-en"
_CHECKPOINT_PL = "Voicelab/vlt5-base-keywords-v4_3"

# English pair.
tokenizer_en = T5Tokenizer.from_pretrained(
    _CHECKPOINT_EN,
    use_auth_token=auth_token,
)
model_en = T5ForConditionalGeneration.from_pretrained(
    _CHECKPOINT_EN,
    use_auth_token=auth_token,
)

# Polish pair.
tokenizer_pl = T5Tokenizer.from_pretrained(
    _CHECKPOINT_PL,
    use_auth_token=auth_token,
)
model_pl = T5ForConditionalGeneration.from_pretrained(
    _CHECKPOINT_PL,
    use_auth_token=auth_token,
)
|
21 |
|
22 |
img_full = Image.open("images/vl-logo-nlp-blue.png")
|
23 |
img_short = Image.open("images/sVL-NLP-short.png")
|
|
|
26 |
cache_size: int = 100
|
27 |
|
28 |
st.set_page_config(
|
29 |
+
page_title="DEMO - keywords generation",
|
30 |
page_icon=img_favicon,
|
31 |
initial_sidebar_state="expanded",
|
32 |
)
|
33 |
|
34 |
+
|
35 |
+
def get_predictions(text, language):
    """Generate keywords for ``text`` with the model matching ``language``.

    Args:
        text: Input text passed to the tokenizer (truncated by the tokenizer
            when it is too long).
        language: Either "Polish" or "English"; selects which
            tokenizer/model pair is used.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens skipped.

    Raises:
        ValueError: If ``language`` is not one of the supported values.
            Previously this case fell through both branches and failed with
            an UnboundLocalError on the return statement.
    """
    # Both languages share exactly the same generation settings; only the
    # tokenizer/model pair differs, so select the pair and run one code path.
    pairs = {
        "Polish": (tokenizer_pl, model_pl),
        "English": (tokenizer_en, model_en),
    }
    try:
        tokenizer, model = pairs[language]
    except KeyError:
        supported = ", ".join(pairs)
        raise ValueError(
            f"Unsupported language {language!r}; expected one of: {supported}"
        ) from None

    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    output = model.generate(
        input_ids,
        no_repeat_ngram_size=2,
        num_beams=3,
        num_beam_groups=3,
        repetition_penalty=1.5,
        diversity_penalty=2.0,
        length_penalty=2.0,
    )
    # Only the first returned sequence is decoded and shown to the user.
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
61 |
|
62 |
+
|
63 |
def trim_length():
    """Clamp the text stored under the "input" session key to max_length.

    Registered as the text area's on_change callback; it rewrites the
    session value in place and returns nothing.
    """
    current = st.session_state["input"]
    if len(current) <= max_length:
        # Already within the limit: leave the stored value untouched.
        return
    st.session_state["input"] = current[:max_length]
|
66 |
|
67 |
|
68 |
if __name__ == "__main__":
|
69 |
+
st.sidebar.image(img_short)
|
70 |
st.image(img_full)
|
71 |
+
st.title("VLT5 - keywords generation")
|
72 |
|
73 |
generated_keywords = ""
|
74 |
user_input = st.text_area(
|
|
|
78 |
on_change=trim_length,
|
79 |
key="input",
|
80 |
)
|
81 |
+
|
82 |
+
language = st.sidebar.title("Model settings")
|
83 |
+
language = st.sidebar.radio(
|
84 |
+
"Select model to test",
|
85 |
+
[
|
86 |
+
"Polish",
|
87 |
+
"English",
|
88 |
+
],
|
89 |
+
)
|
90 |
+
|
91 |
result = st.button("Generate keywords")
|
92 |
if result:
|
93 |
+
generated_keywords = get_predictions(text=user_input, language=language)
|
94 |
st.text_area("Generated keywords", generated_keywords)
|
|