# NOTE(review): the lines below are Hugging Face Spaces page-scrape residue
# (runtime status, file size, commit hashes, line-number gutter). Kept here
# commented out so the file parses as valid Python.
# Spaces:
# Runtime error
# Runtime error
# File size: 3,333 Bytes
# 6511960 bf0a67a 6511960 3c003b1 7a004a6 a71360e 7a004a6 a71360e 7a004a6 6511960 e85d00a 6511960 a71360e 6511960 d32d77d a71360e 6511960 a71360e 6511960 a71360e 6511960 a71360e 62c3489 6511960 a71360e 6511960 a71360e 6511960 7a004a6 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os
@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load and cache the English and Polish vlT5 keyword models.

    Returns:
        tuple: (tokenizer_en, model_en, tokenizer_pl, model_pl).
    """
    # Prefer a token injected via the Space secret; `True` falls back to the
    # locally configured Hugging Face credentials.
    auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

    def _load_pair(repo_id):
        # One tokenizer/model pair from the given Hub repository.
        tokenizer = T5Tokenizer.from_pretrained(repo_id, use_auth_token=auth_token)
        model = T5ForConditionalGeneration.from_pretrained(
            repo_id, use_auth_token=auth_token
        )
        return tokenizer, model

    tokenizer_en, model_en = _load_pair("Voicelab/vlt5-base-keywords-v4_3-en")
    tokenizer_pl, model_pl = _load_pair("Voicelab/vlt5-base-keywords-v4_3")
    return tokenizer_en, model_en, tokenizer_pl, model_pl
# Branding assets used in the page body, sidebar, and browser tab icon.
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Maximum number of input characters accepted in the text area (enforced by
# the trim_length callback below).
max_length: int = 1000
# NOTE(review): cache_size is never read anywhere in this file — presumably
# a leftover; confirm before removing.
cache_size: int = 100
# Must be the first Streamlit call that configures the page.
st.set_page_config(
    page_title="DEMO - keywords generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)
# Load (or fetch from the Streamlit cache) all four tokenizer/model objects
# once at import time; get_predictions reads these module-level globals.
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
def get_predictions(text, language):
    """Generate keywords for *text* using the model matching *language*.

    Args:
        text: Input document to extract keywords from.
        language: Either "Polish" or "English"; selects the tokenizer/model
            pair loaded at module level.

    Returns:
        str: Keywords decoded from the model output (special tokens removed).

    Raises:
        ValueError: If *language* is not a supported option. (The original
            code fell through both branches and crashed with an
            UnboundLocalError at the return statement instead.)
    """
    # The two original branches were identical except for which globals they
    # used — select the pair once and run a single generation path.
    if language == "Polish":
        tokenizer, model = tokenizer_pl, model_pl
    elif language == "English":
        tokenizer, model = tokenizer_en, model_en
    else:
        raise ValueError(f"Unsupported language: {language!r}")

    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    # Diverse beam search: 3 beams split into 3 groups with a diversity
    # penalty, plus repetition/length penalties to avoid duplicate keywords.
    output = model.generate(
        input_ids,
        no_repeat_ngram_size=2,
        num_beams=3,
        num_beam_groups=3,
        repetition_penalty=1.5,
        diversity_penalty=2.0,
        length_penalty=2.0,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def trim_length():
    """Streamlit on_change callback: clamp the text area to max_length chars."""
    text = st.session_state["input"]
    if len(text) > max_length:
        # Rewrite the session value in place so the widget reflects the cut.
        st.session_state["input"] = text[:max_length]
if __name__ == "__main__":
    # --- Page layout -------------------------------------------------------
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - keywords generation")
    # Typo fix: "lentgh" -> "length" in the user-facing hint.
    st.markdown(
        "**Input**: Use abstract length-like text for best results. "
        "Providing very short or very long texts will result in "
        "significantly worse results."
    )
    generated_keywords = ""
    # Input text area; trim_length clamps the session value to max_length.
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )
    # Dead-assignment fix: st.sidebar.title always returns None and the
    # original bound that to `language` only to overwrite it immediately.
    st.sidebar.title("Model settings")
    language = st.sidebar.radio(
        "Select model to test",
        [
            "Polish",
            "English",
        ],
    )
    result = st.button("Generate keywords")
    if result:
        generated_keywords = get_predictions(text=user_input, language=language)
        st.text_area("Generated keywords", generated_keywords)
        # Server-side log of the request/response pair.
        print(f"Input: {user_input}---> Keywords: {generated_keywords}")