# NOTE: lines of Hugging Face Spaces page metadata (runtime status, file size,
# commit hashes, line-number gutter) were stripped here — they were web-scrape
# residue, not part of the program.
from transformers import T5ForConditionalGeneration, T5Tokenizer
import streamlit as st
from PIL import Image
import os
@st.cache(allow_output_mutation=True)
def load_model_cache():
    """Load and cache both vlT5 keyword models (English and Polish).

    Returns:
        tuple: ``(tokenizer_en, model_en, tokenizer_pl, model_pl)``.
    """
    # Prefer a Hugging Face token from the environment; fall back to True so
    # transformers uses the locally configured credentials instead.
    token = os.environ.get("TOKEN_FROM_SECRET") or True
    repo_en = "Voicelab/vlt5-base-keywords-v4_3-en"
    repo_pl = "Voicelab/vlt5-base-keywords-v4_3"
    return (
        T5Tokenizer.from_pretrained(repo_en, use_auth_token=token),
        T5ForConditionalGeneration.from_pretrained(repo_en, use_auth_token=token),
        T5Tokenizer.from_pretrained(repo_pl, use_auth_token=token),
        T5ForConditionalGeneration.from_pretrained(repo_pl, use_auth_token=token),
    )
# Static assets for the page layout (paths relative to the app's working dir).
img_full = Image.open("images/vl-logo-nlp-blue.png")
img_short = Image.open("images/sVL-NLP-short.png")
img_favicon = Image.open("images/favicon_vl.png")
# Maximum number of characters accepted in the input text area.
max_length: int = 1000
# NOTE(review): declared but never read in the visible code — possibly a
# leftover from an earlier caching scheme; confirm before removing.
cache_size: int = 100
# Must run before any other Streamlit command that renders output.
st.set_page_config(
    page_title="DEMO - keywords generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)
# Loads (or fetches from cache) all four tokenizer/model pairs at startup.
tokenizer_en, model_en, tokenizer_pl, model_pl = load_model_cache()
def get_predictions(text, language):
    """Generate keywords for *text* using the model matching *language*.

    Args:
        text: Input text to extract keywords from.
        language: Either ``"Polish"`` or ``"English"``.

    Returns:
        str: Decoded keyword string produced by diverse beam search.

    Raises:
        ValueError: If *language* is not one of the supported options.
    """
    # Select the tokenizer/model pair once; the generation settings are
    # identical for both languages, so the rest of the body is shared.
    if language == "Polish":
        tokenizer, model = tokenizer_pl, model_pl
    elif language == "English":
        tokenizer, model = tokenizer_en, model_en
    else:
        # Previously an unrecognized language fell through both branches and
        # raised UnboundLocalError at the return; fail with a clear message.
        raise ValueError(f"Unsupported language: {language!r}")
    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids
    # Diverse beam search: 3 beams in 3 groups with a diversity penalty.
    output = model.generate(
        input_ids,
        no_repeat_ngram_size=2,
        num_beams=3,
        num_beam_groups=3,
        repetition_penalty=1.5,
        diversity_penalty=2.0,
        length_penalty=2.0,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
def trim_length():
    """Clamp the "input" text-area contents to at most ``max_length`` chars."""
    text = st.session_state["input"]
    if len(text) > max_length:
        st.session_state["input"] = text[:max_length]
if __name__ == "__main__":
    # Page header: short logo in the sidebar, full logo in the main column.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - keywords generation")
    # Fixed user-facing typos ("lentgh-like", "fo best results").
    st.markdown(
        "**Input**: Use abstract-length text for best results. Providing very "
        "short or very long texts will result in significantly worse results."
    )
    generated_keywords = ""
    user_input = st.text_area(
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
        key="input",
    )
    # st.sidebar.title returns None, so the previous ``language = ...``
    # assignment here was dead code (immediately overwritten by the radio).
    st.sidebar.title("Model settings")
    language = st.sidebar.radio(
        "Select model to test",
        [
            "Polish",
            "English",
        ],
    )
    result = st.button("Generate keywords")
    if result:
        generated_keywords = get_predictions(text=user_input, language=language)
        st.text_area("Generated keywords", generated_keywords)
        # Server-side log of each request for debugging.
        print(f"Input: {user_input}---> Keywords: {generated_keywords}")