# Source: Hugging Face Space file view — author AgaMiko,
# commit 6511960 ("init app"), 1.57 kB.
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
def _cache_across_reruns(loader):
    """Wrap *loader* in ``st.cache_resource`` when the installed Streamlit has it.

    Streamlit re-executes this script from the top on every widget
    interaction, so without a resource cache the tokenizer and model weights
    would be reloaded on each rerun. Streamlit releases that lack
    ``st.cache_resource`` fall back to calling *loader* directly (uncached).
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        return loader
    return cache_resource(loader)


@_cache_across_reruns
def _load_tokenizer(name):
    """Load the T5 tokenizer for the Hugging Face checkpoint *name*."""
    return T5Tokenizer.from_pretrained(name)


@_cache_across_reruns
def _load_model(name):
    """Load the T5 seq2seq model for the Hugging Face checkpoint *name*."""
    return T5ForConditionalGeneration.from_pretrained(name)


# Checkpoint shared by the tokenizer and the model.
MODEL_NAME = "Voicelab/vlt5-base-keywords"

# Branding images for the page header, sidebar and browser tab.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/vl-logo-nlp-short.png")
img_favicon = Image.open("images/favicon_vl.png")

# Maximum number of characters kept from the input text box.
max_length: int = 1000
# NOTE(review): not referenced anywhere in this file; presumably intended as a
# cache-size limit — confirm its purpose or remove it.
cache_size: int = 100

# Configure the page before any model loading, so no Streamlit element
# (e.g. a cache spinner) can precede it.
st.set_page_config(
    page_title='DEMO - keywords generation',
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)

tokenizer = _load_tokenizer(MODEL_NAME)
model = _load_model(MODEL_NAME)
def get_predictions(
    text,
    *,
    num_beams: int = 4,
    no_repeat_ngram_size: int = 3,
    task_prefix: str = "",
) -> str:
    """Generate a keyword string for *text* with the module-level VLT5 model.

    Args:
        text: Input document. The tokenizer is called with ``truncation=True``,
            so inputs longer than its configured maximum are cut off.
        num_beams: Beam width passed to ``model.generate``.
        no_repeat_ngram_size: Size of n-grams that may not repeat in the output.
        task_prefix: String prepended to *text* before tokenization.
            NOTE(review): the model card's usage example prepends "Keywords: ";
            the default "" keeps this app's original behaviour — confirm which
            input format the model expects.

    Returns:
        The decoded generation with special tokens removed, or "" when *text*
        is empty/whitespace-only (in which case the model is not called).
    """
    # Nothing to summarise: skip the (slow) beam search entirely.
    if not text or not text.strip():
        return ""
    input_ids = tokenizer(
        task_prefix + text, return_tensors="pt", truncation=True
    ).input_ids
    output = model.generate(
        input_ids,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_beams=num_beams,
    )
    # A single input was encoded, so the generated sequence is row 0.
    return tokenizer.decode(output[0], skip_special_tokens=True)
def trim_length():
    """Clip the text stored under session-state key "input" to max_length chars.

    Registered as the ``on_change`` callback of the input text area; leaves
    the stored value untouched when it is already short enough.
    """
    state = st.session_state
    current = state["input"]
    if len(current) <= max_length:
        return
    state["input"] = current[:max_length]
if __name__ == "__main__":
    # Branding: compact logo in the sidebar, full logo above the title.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title('VLT5 - keywords generation')

    generated_keywords = ""
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        # Enforce the limit in the browser; trim_length stays as a safety net
        # for values that reach session state by other means.
        max_chars=max_length,
        on_change=trim_length,
        key="input",
    )

    result = st.button("Generate keywords")
    if result:
        if user_input.strip():
            # Beam-search generation can take a while on CPU.
            with st.spinner("Generating keywords..."):
                generated_keywords = get_predictions(text=user_input)
        else:
            st.warning("Please enter some text to generate keywords from.")

    st.text_area("Generated keywords", generated_keywords)