Spaces:
Runtime error
Runtime error
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
import streamlit as st | |
from PIL import Image | |
# Limits used by the UI: `max_length` caps the number of characters kept from
# the input box (see `trim_length`). `cache_size` is not referenced in this
# part of the file.
max_length: int = 1000
cache_size: int = 100

# Single pretrained checkpoint name shared by the tokenizer and the model.
_MODEL_NAME = "Voicelab/vlt5-base-keywords"

# Both objects are created at module level, i.e. every time the script runs.
tokenizer = T5Tokenizer.from_pretrained(_MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(_MODEL_NAME)

# Logo assets, opened in the same order as before: full, short, favicon.
img_full, img_short, img_favicon = map(
    Image.open,
    (
        "images/vl-logo-nlp-blue.png",
        "images/vl-logo-nlp-short.png",
        "images/favicon_vl.png",
    ),
)

# Page-level configuration; kept ahead of every other `st.*` call in the file.
st.set_page_config(
    page_title="DEMO - keywords generation",
    page_icon=img_favicon,
    initial_sidebar_state="expanded",
)
def get_predictions(text: str) -> str:
    """Generate a keyword string for ``text`` using the module-level model.

    Args:
        text: Raw input text. It is tokenized with ``truncation=True``, so
            input longer than the tokenizer's configured limit is cut off
            rather than rejected.

    Returns:
        The first generated sequence decoded to text with special tokens
        skipped, or an empty string when ``text`` is empty, ``None`` or
        consists only of whitespace.
    """
    # Nothing to extract keywords from: skip tokenization and beam search
    # instead of asking the model to generate output for blank input.
    if not text or not text.strip():
        return ""

    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Beam search with 4 beams; `no_repeat_ngram_size=3` is passed through to
    # `generate` to suppress repeated 3-grams in the output.
    output = model.generate(
        encoded.input_ids,
        no_repeat_ngram_size=3,
        num_beams=4,
    )
    # Only the first returned sequence is decoded.
    return tokenizer.decode(output[0], skip_special_tokens=True)
def trim_length():
    """Cut the ``"input"`` session-state value down to ``max_length`` characters.

    Used as the ``on_change`` callback of the input text area. The stored
    value is only rewritten when it is longer than ``max_length``.
    """
    current = st.session_state["input"]
    if len(current) <= max_length:
        # Already within the limit; leave session state untouched.
        return
    st.session_state["input"] = current[:max_length]
if __name__ == "__main__":
    # Header: short logo in the sidebar, full logo and title in the main area.
    st.sidebar.image(img_short)
    st.image(img_full)
    st.title("VLT5 - keywords generation")

    # Input box; `trim_length` runs as its on_change callback and reads the
    # value back through the "input" session-state key.
    user_input = st.text_area(
        key="input",
        label=f"Input text (max {max_length} characters)",
        value="",
        height=300,
        on_change=trim_length,
    )

    # The model is only invoked on runs where the button reports a click;
    # otherwise the output box is rendered empty.
    generated_keywords = (
        get_predictions(text=user_input)
        if st.button("Generate keywords")
        else ""
    )
    st.text_area("Generated keywords", generated_keywords)