File size: 3,171 Bytes
6511960
 
 
bf0a67a
6511960
3c003b1
 
7a004a6
 
a71360e
7a004a6
 
 
 
 
 
 
 
 
 
 
a71360e
7a004a6
 
6511960
 
e85d00a
6511960
 
 
 
 
a71360e
6511960
 
 
 
d32d77d
a71360e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6511960
 
a71360e
6511960
 
 
 
 
 
a71360e
6511960
a71360e
6511960
 
 
 
 
 
 
 
 
a71360e
 
 
 
 
 
 
 
 
 
6511960
 
a71360e
6511960
7a004a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os

@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load the English and Polish VLT5 keyword tokenizers and models.

    Decorated with ``st.cache`` so that script reruns can reuse the loaded
    objects instead of fetching them again.

    Returns:
        A 4-tuple ``(tokenizer_en, model_en, tokenizer_pl, model_pl)``.
    """
    # Use the secret token when provided; otherwise pass ``True`` so that
    # transformers falls back to the locally stored Hugging Face login token.
    # NOTE(review): newer transformers releases deprecate ``use_auth_token``
    # in favour of ``token``, and newer Streamlit releases deprecate
    # ``st.cache`` in favour of ``st.cache_resource`` — consider migrating.
    credentials = os.environ.get("TOKEN_FROM_SECRET") or True

    def _fetch_pair(repo_id):
        # Tokenizer first, then model, matching the original load order.
        tokenizer = T5Tokenizer.from_pretrained(repo_id, use_auth_token=credentials)
        model = T5ForConditionalGeneration.from_pretrained(
            repo_id, use_auth_token=credentials
        )
        return tokenizer, model

    english_pair = _fetch_pair("Voicelab/vlt5-base-keywords-v4_3-en")
    polish_pair = _fetch_pair("Voicelab/vlt5-base-keywords-v4_3")
    return (*english_pair, *polish_pair)


# Branding images: full logo above the title, short logo in the sidebar, and
# the favicon for the browser tab. The paths are resolved against the
# process's current working directory.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Maximum number of characters kept in the input text area (enforced by trim_length).
max_length: int = 1000
# NOTE(review): cache_size is not referenced anywhere in the code visible here.
cache_size: int = 100

# Page-level settings. Keep this call ahead of any other rendering st.* call;
# Streamlit restricts where set_page_config may be used.
st.set_page_config(
    page_title="DEMO - keywords generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

# Load both language models once; the st.cache decorator lets reruns reuse them.
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()

# Decoding settings shared by both language models: diverse beam search with
# 3 beams split into 3 groups (one beam per group).
_GENERATION_KWARGS = {
    "no_repeat_ngram_size": 2,
    "num_beams": 3,
    "num_beam_groups": 3,
    "repetition_penalty": 1.5,
    "diversity_penalty": 2.0,
    "length_penalty": 2.0,
}

# Languages offered by the sidebar model selector.
_SUPPORTED_LANGUAGES = ("Polish", "English")


def get_predictions(text, language):
    """Generate a keyword string for *text* using the model for *language*.

    Args:
        text: Input text to extract keywords from.
        language: ``"Polish"`` or ``"English"``; selects the tokenizer/model pair.

    Returns:
        The decoded keywords of the top generated sequence, or ``""`` when
        *text* is empty or whitespace-only.

    Raises:
        ValueError: If *language* is not one of the supported languages.
    """
    if language not in _SUPPORTED_LANGUAGES:
        # Fail fast with a clear message instead of an obscure error later on.
        raise ValueError(
            f"Unsupported language: {language!r}; expected one of {_SUPPORTED_LANGUAGES}"
        )
    if not text or not text.strip():
        # Nothing to extract keywords from; don't run generation on an empty prompt.
        return ""

    if language == "Polish":
        tokenizer, model = tokenizer_pl, model_pl
    else:
        tokenizer, model = tokenizer_en, model_en

    # Over-long input is truncated by the tokenizer rather than raising.
    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    output = model.generate(input_ids, **_GENERATION_KWARGS)
    return tokenizer.decode(output[0], skip_special_tokens=True)


def trim_length():
    """Clip the text stored under session key ``"input"`` to ``max_length`` characters.

    Used as the ``on_change`` callback of the input text area. Session state is
    written back only when the text actually exceeds the limit.
    """
    current_text = st.session_state["input"]
    if len(current_text) <= max_length:
        return
    st.session_state["input"] = current_text[:max_length]


if __name__ == "__main__":
    # Branding: compact logo in the sidebar, full logo above the page title.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - keywords generation")

    # Input area; trim_length runs on change to enforce the max_length cap.
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )

    # Sidebar model selector. st.sidebar.title only renders a heading, so its
    # return value is not needed.
    st.sidebar.title("Model settings")
    language = st.sidebar.radio(
        "Select model to test",
        [
            "Polish",
            "English",
        ],
    )

    if st.button("Generate keywords"):
        if not user_input.strip():
            # Avoid running the model on an empty prompt.
            st.warning("Please enter some text to generate keywords from.")
        else:
            generated_keywords = get_predictions(text=user_input, language=language)
            st.text_area("Generated keywords", generated_keywords)
            # Echo the request/response to the server's stdout.
            print(f"Input: {user_input}---> Keywords: {generated_keywords}")