AgaMiko committed on
Commit
a71360e
·
1 Parent(s): bf0a67a

add new language

Browse files
Files changed (1) hide show
  1. app.py +54 -13
app.py CHANGED
@@ -5,8 +5,19 @@ import os
5
 
6
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
7
 
8
- tokenizer= T5Tokenizer.from_pretrained("Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token)
9
- model = T5ForConditionalGeneration.from_pretrained("Voicelab/vlt5-base-keywords-v4_3-en", use_auth_token=auth_token)
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  img_full = Image.open("images/vl-logo-nlp-blue.png")
12
  img_short = Image.open("images/sVL-NLP-short.png")
@@ -15,27 +26,49 @@ max_length: int = 1000
15
  cache_size: int = 100
16
 
17
  st.set_page_config(
18
- page_title='DEMO - keywords generation',
19
  page_icon=img_favicon,
20
  initial_sidebar_state="expanded",
21
  )
22
 
23
- def get_predictions(text):
24
- input_ids = tokenizer(
25
- text, return_tensors="pt", truncation=True
26
- ).input_ids
27
- output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)
28
- predicted_kw = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  return predicted_kw
30
 
 
31
  def trim_length():
32
  if len(st.session_state["input"]) > max_length:
33
  st.session_state["input"] = st.session_state["input"][:max_length]
34
 
35
 
36
  if __name__ == "__main__":
 
37
  st.image(img_full)
38
- st.title('VLT5 - keywords generation')
39
 
40
  generated_keywords = ""
41
  user_input = st.text_area(
@@ -45,9 +78,17 @@ if __name__ == "__main__":
45
  on_change=trim_length,
46
  key="input",
47
  )
48
-
 
 
 
 
 
 
 
 
 
49
  result = st.button("Generate keywords")
50
  if result:
51
- generated_keywords = get_predictions(text=user_input)
52
  st.text_area("Generated keywords", generated_keywords)
53
-
 
5
 
6
  auth_token = os.environ.get("TOKEN_FROM_SECRET") or True
7
 
8
# Hugging Face Hub ids of the English and Polish keyword-generation checkpoints.
_EN_CHECKPOINT = "Voicelab/vlt5-base-keywords-v4_3-en"
_PL_CHECKPOINT = "Voicelab/vlt5-base-keywords-v4_3"

# NOTE(review): these loads run at module level; Streamlit re-executes the
# script on user interaction, so consider caching them — confirm against the
# Streamlit version in use.

# English tokenizer/model pair.
tokenizer_en = T5Tokenizer.from_pretrained(_EN_CHECKPOINT, use_auth_token=auth_token)
model_en = T5ForConditionalGeneration.from_pretrained(
    _EN_CHECKPOINT, use_auth_token=auth_token
)

# Polish tokenizer/model pair.
tokenizer_pl = T5Tokenizer.from_pretrained(_PL_CHECKPOINT, use_auth_token=auth_token)
model_pl = T5ForConditionalGeneration.from_pretrained(
    _PL_CHECKPOINT, use_auth_token=auth_token
)
21
 
22
  img_full = Image.open("images/vl-logo-nlp-blue.png")
23
  img_short = Image.open("images/sVL-NLP-short.png")
 
26
  cache_size: int = 100
27
 
28
  st.set_page_config(
29
+ page_title="DEMO - keywords generation",
30
  page_icon=img_favicon,
31
  initial_sidebar_state="expanded",
32
  )
33
 
34
+
35
+ def get_predictions(text, language):
36
+ if language == "Polish":
37
+ input_ids = tokenizer_pl(text, return_tensors="pt", truncation=True).input_ids
38
+ output = model_pl.generate(
39
+ input_ids,
40
+ no_repeat_ngram_size=2,
41
+ num_beams=3,
42
+ num_beam_groups=3,
43
+ repetition_penalty=1.5,
44
+ diversity_penalty=2.0,
45
+ length_penalty=2.0,
46
+ )
47
+ predicted_kw = tokenizer_pl.decode(output[0], skip_special_tokens=True)
48
+ elif language == "English":
49
+ input_ids = tokenizer_en(text, return_tensors="pt", truncation=True).input_ids
50
+ output = model_en.generate(
51
+ input_ids,
52
+ no_repeat_ngram_size=2,
53
+ num_beams=3,
54
+ num_beam_groups=3,
55
+ repetition_penalty=1.5,
56
+ diversity_penalty=2.0,
57
+ length_penalty=2.0,
58
+ )
59
+ predicted_kw = tokenizer_en.decode(output[0], skip_special_tokens=True)
60
  return predicted_kw
61
 
62
+
63
def trim_length():
    """Cut the stored ``"input"`` session value down to ``max_length`` characters.

    Registered as the ``on_change`` callback of the text area keyed ``"input"``.
    """
    current_text = st.session_state["input"]
    if len(current_text) <= max_length:
        # Already within the limit; leave the session state untouched.
        return
    st.session_state["input"] = current_text[:max_length]
66
 
67
 
68
  if __name__ == "__main__":
69
+ st.sidebar.image(img_short)
70
  st.image(img_full)
71
+ st.title("VLT5 - keywords generation")
72
 
73
  generated_keywords = ""
74
  user_input = st.text_area(
 
78
  on_change=trim_length,
79
  key="input",
80
  )
81
+
82
+ language = st.sidebar.title("Model settings")
83
+ language = st.sidebar.radio(
84
+ "Select model to test",
85
+ [
86
+ "Polish",
87
+ "English",
88
+ ],
89
+ )
90
+
91
  result = st.button("Generate keywords")
92
  if result:
93
+ generated_keywords = get_predictions(text=user_input, language=language)
94
  st.text_area("Generated keywords", generated_keywords)