File size: 3,333 Bytes
6511960
 
 
bf0a67a
6511960
3c003b1
 
7a004a6
 
a71360e
7a004a6
 
 
 
 
 
 
 
 
 
 
a71360e
7a004a6
 
6511960
 
e85d00a
6511960
 
 
 
 
a71360e
6511960
 
 
 
d32d77d
a71360e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6511960
 
a71360e
6511960
 
 
 
 
 
a71360e
6511960
a71360e
62c3489
6511960
 
 
 
 
 
 
 
 
a71360e
 
 
 
 
 
 
 
 
 
6511960
 
a71360e
6511960
7a004a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os

@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load and cache the English and Polish vlT5 keyword-generation models.

    Returns:
        tuple: ``(tokenizer_en, model_en, tokenizer_pl, model_pl)``.
    """
    # Fall back to True so from_pretrained uses the locally stored
    # huggingface-cli login token when no secret is configured.
    auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

    def _load_pair(checkpoint):
        # Load one tokenizer/model pair for the given HF checkpoint name.
        tokenizer = T5Tokenizer.from_pretrained(checkpoint, use_auth_token=auth_token)
        model = T5ForConditionalGeneration.from_pretrained(
            checkpoint, use_auth_token=auth_token
        )
        return tokenizer, model

    tokenizer_en, model_en = _load_pair("Voicelab/vlt5-base-keywords-v4_3-en")
    tokenizer_pl, model_pl = _load_pair("Voicelab/vlt5-base-keywords-v4_3")

    return tokenizer_en, model_en, tokenizer_pl, model_pl


# Static image assets for the page header, sidebar logo and browser-tab icon.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Hard cap on the number of characters accepted in the input text area
# (enforced by the trim_length on_change callback).
max_length: int = 1000
# NOTE(review): defined but never read anywhere in this file — presumably a
# leftover; confirm before removing.
cache_size: int = 100

# Streamlit requires set_page_config to be the first st.* command executed
# in a script run.
st.set_page_config(
    page_title="DEMO - keywords generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

# Load all four model artifacts up front (served from Streamlit's cache on
# reruns, so this is only expensive on the first execution).
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()

def get_predictions(text, language):
    """Generate keywords for *text* using the model matching *language*.

    Args:
        text: Input document (abstract-length texts work best).
        language: ``"Polish"`` or ``"English"`` — selects the model pair.

    Returns:
        str: The decoded keyword string produced by the model.

    Raises:
        ValueError: If *language* is not one of the supported options.
    """
    # Select the tokenizer/model pair once; both branches previously
    # duplicated the whole generation pipeline.
    if language == "Polish":
        tokenizer, model = tokenizer_pl, model_pl
    elif language == "English":
        tokenizer, model = tokenizer_en, model_en
    else:
        # Previously an unsupported language fell through and crashed with
        # an opaque UnboundLocalError; fail with a clear message instead.
        raise ValueError(f"Unsupported language: {language!r}")

    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    output = model.generate(
        input_ids,
        no_repeat_ngram_size=2,
        num_beams=3,
        num_beam_groups=3,
        repetition_penalty=1.5,
        diversity_penalty=2.0,
        length_penalty=2.0,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def trim_length():
    """Clamp the text-area contents to at most ``max_length`` characters.

    Registered as the ``on_change`` callback of the input widget; mutates
    ``st.session_state["input"]`` in place.
    """
    text = st.session_state["input"]
    if len(text) > max_length:
        st.session_state["input"] = text[:max_length]


if __name__ == "__main__":
    # --- Page header -------------------------------------------------
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - keywords generation")
    # Fixed user-facing typo: "lentgh" -> "length".
    st.markdown("**Input**: Use abstract length-like text for best results. Providing very short or very long texts will result in significantly worse results.")

    generated_keywords = ""
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,  # clamps the input to max_length as the user types
        key="input",
    )

    # --- Sidebar: model selection ------------------------------------
    # st.sidebar.title's return value was previously bound to `language`
    # and immediately overwritten by the radio widget below.
    st.sidebar.title("Model settings")
    language = st.sidebar.radio(
        "Select model to test",
        [
            "Polish",
            "English",
        ],
    )

    # --- Generation --------------------------------------------------
    result = st.button("Generate keywords")
    if result:
        generated_keywords = get_predictions(text=user_input, language=language)
        st.text_area("Generated keywords", generated_keywords)
        # Plain print so input/output pairs show up in the container logs.
        print(f"Input: {user_input}---> Keywords: {generated_keywords}")