# Streamlit demo: keyword generation with the Voicelab vlt5-base-keywords model.
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image

# Load the pretrained tokenizer and model from the Hugging Face Hub.
tokenizer = T5Tokenizer.from_pretrained("Voicelab/vlt5-base-keywords")
model = T5ForConditionalGeneration.from_pretrained("Voicelab/vlt5-base-keywords")

# Branding assets shown on the page, in the sidebar, and as the favicon.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")

# Maximum number of input characters accepted by the text area.
max_length: int = 1000
cache_size: int = 100  # not referenced elsewhere in this snippet

# Configure the Streamlit page; this must run before any other UI elements are rendered.
st.set_page_config(
    page_title='DEMO - keywords generation',
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

def get_predictions(text):
    # Tokenize the input (truncated to the model's maximum length) and generate
    # keywords with beam search (4 beams), blocking repeated 3-grams.
    input_ids = tokenizer(
        text, return_tensors="pt", truncation=True
    ).input_ids
    output = model.generate(input_ids, no_repeat_ngram_size=3, num_beams=4)
    predicted_kw = tokenizer.decode(output[0], skip_special_tokens=True)
    return predicted_kw

def trim_length():
    # on_change callback: hard-truncate the text area content to max_length characters.
    if len(st.session_state["input"]) > max_length:
        st.session_state["input"] = st.session_state["input"][:max_length]


if __name__ == "__main__":
    # Page layout: logos, title, and an input text area limited to max_length characters.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title('VLT5 - keywords generation')

    generated_keywords = ""
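    # Text area bound to st.session_state["input"]; trim_length() truncates it to max_length on each change.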
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )
    
    # Generate and display keywords only when the button is clicked.
    result = st.button("Generate keywords")
    if result:
        generated_keywords = get_predictions(text=user_input)
        st.text_area("Generated keywords", generated_keywords)