File size: 2,668 Bytes
bd61da5 9f4fb0e 233c774 9f4fb0e 233c774 ff3389c 076a158 bd61da5 9f4fb0e bd61da5 233c774 bd61da5 6b329ae 233c774 9f4fb0e 233c774 ff3389c 56e76a4 ff3389c b7137a2 233c774 9f4fb0e b7137a2 bcfdb10 9f4fb0e 076a158 bcfdb10 9f4fb0e c374e34 9f4fb0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification
# MODEL TO CALL
# Streamlit re-executes this whole script on every widget interaction, so the
# models are loaded through st.cache_resource: each loader runs once per server
# process and the returned objects are reused on every subsequent rerun.
generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
classifier_name = "Alirani/overview_classifier_final"


# show_spinner=False keeps the loaders from rendering a spinner element ahead of
# st.set_page_config() further down, which some Streamlit versions require to be
# the first Streamlit command of the script.
@st.cache_resource(show_spinner=False)
def _load_generator(name):
    """Return ``(tokenizer, model)`` for the causal-LM synopsis generator *name*."""
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = TFAutoModelForCausalLM.from_pretrained(name)
    return tokenizer, model


@st.cache_resource(show_spinner=False)
def _load_classifier(name):
    """Return ``(tokenizer, model)`` for the sequence-classification model *name*."""
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = TFAutoModelForSequenceClassification.from_pretrained(name)
    return tokenizer, model


tokenizer_gen, model_gen = _load_generator(generator_name)
tokenizer_clf, model_clf = _load_classifier(classifier_name)
def _extract_synopsis(decoded):
    """Pull the synopsis out of a decoded ``"<title> | <genre> | <synopsis>"`` string.

    Takes the third ``|``-separated field, falling back to the last field when
    fewer separators are present. It then drops any unfinished trailing sentence
    after the last period. Text containing no period is returned whole instead
    of being reduced to an empty string.
    """
    fields = decoded.split('|')
    # The prompt is "<title> | <genre> | ", so generated text normally lands in
    # the third field; don't raise IndexError if the separators went missing.
    body = fields[2] if len(fields) > 2 else fields[-1]
    head, period, _ = body.rpartition('.')
    # rpartition returns ('', '', body) when there is no '.', which would
    # otherwise throw away the entire synopsis.
    trimmed = head + period if period else body
    return trimmed.strip()


def generate_synopsis(model, tokenizer, title):
    """Generate a synopsis continuation for the prompt *title* and clean it up.

    Args:
        model: causal-LM with a ``generate`` method (TF backend).
        tokenizer: matching tokenizer; called with ``return_tensors="tf"``.
        title: full prompt text, expected as ``"<title> | <genre> | "``.

    Returns:
        The synopsis field of the decoded output, cut back to its last complete
        sentence when one exists.
    """
    encoded = tokenizer(title, return_tensors="tf")
    output = model.generate(
        encoded['input_ids'],
        max_length=150,
        num_beams=5,
        no_repeat_ngram_size=2,
        top_k=50,
        attention_mask=encoded['attention_mask'],
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    return _extract_synopsis(decoded)
def generate_classification(model, tokenizer, title, overview):
    """Score a ``"<title> | <overview>"`` prompt with the genre classifier.

    The prompt is tokenized with padding and truncation into TF tensors and fed
    to *model*. The raw model output is returned unchanged; any mapping from
    logits to labels is left to the caller.
    """
    prompt = "{} | {}".format(title, overview)
    encoded = tokenizer(
        prompt,
        padding=True,
        truncation=True,
        return_tensors="tf",
    )
    return model(**encoded)
# Page configuration, headings and the input widgets read by the handlers below.
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
st.set_page_config(
    page_title="LoreFinder-demo",
    page_icon=favicon,
    layout='wide',
    initial_sidebar_state='auto',
)
st.title('Demo LoreFinder')
st.header('Generate a story')

prod_title = st.text_input('Type a title to generate a synopsis')
prod_synopsis = st.text_input('Type a synopsis to guess the genre')

# Genres offered in the selectbox; nothing is preselected (index=None).
_GENRE_CHOICES = (
    'Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure',
    'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History',
    'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality',
)
option_genres = st.selectbox(
    'Select a genre to tailor your synopsis',
    options=_GENRE_CHOICES,
    index=None,
    placeholder="Select genres...",
)

button_synopsis = st.button('Get synopsis')
button_genres = st.button('Get genres')
# Synopsis generation, triggered by the "Get synopsis" button.
if button_synopsis:
    # A length check on prod_title.split(' ') can never fail (''.split(' ') is
    # ['']), so test for non-blank text instead.
    if prod_title.strip():
        # The selectbox starts with no choice (index=None); don't send the
        # literal string "None" to the model as a genre.
        # NOTE(review): confirm what the fine-tuned generator expects when no
        # genre is given.
        genre = option_genres if option_genres is not None else ""
        gen_synopsis = generate_synopsis(
            model_gen, tokenizer_gen, f"{prod_title} | {genre} | "
        )
        st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
    else:
        st.write('Write a title for the generator to work !')
# Genre classification, triggered by the "Get genres" button.
if button_genres:
    # Both fields feed the classifier prompt, as the error message below says;
    # a split(' ')-length check would always pass, even for empty input.
    if prod_title.strip() and prod_synopsis.strip():
        classifier_output = generate_classification(
            model_clf, tokenizer_clf, prod_title, prod_synopsis
        )
        # generate_classification returns the raw model output; show the
        # highest-scoring label name rather than the output object's repr.
        # NOTE(review): argmax assumes single-label classification — if the
        # model is multi-label, threshold sigmoid scores instead.
        logits = classifier_output.logits.numpy()[0]
        top_id = int(logits.argmax())
        guessed_genre = model_clf.config.id2label.get(top_id, str(top_id))
        st.write(f"Guessed genre : {guessed_genre}")
    else:
        st.write('Write a title and a synopsis for the classifier to work !')