File size: 2,668 Bytes
bd61da5
9f4fb0e
233c774
 
 
9f4fb0e
 
 
 
 
 
 
233c774
 
 
 
 
ff3389c
076a158
bd61da5
9f4fb0e
 
 
 
 
bd61da5
 
233c774
bd61da5
6b329ae
 
 
 
233c774
9f4fb0e
233c774
ff3389c
 
56e76a4
 
 
ff3389c
 
b7137a2
233c774
9f4fb0e
 
b7137a2
bcfdb10
9f4fb0e
076a158
bcfdb10
 
9f4fb0e
 
 
 
c374e34
9f4fb0e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification

# MODEL TO CALL
#
# Streamlit re-executes this script from top to bottom on every widget
# interaction. The checkpoints are therefore loaded through
# ``st.cache_resource`` so each tokenizer/model pair is created once per
# server process and reused on later reruns instead of being reloaded on
# every button click.
#
# ``show_spinner=False`` keeps the cached loaders from rendering anything,
# since they run before ``st.set_page_config`` is called further down.

generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
classifier_name = "Alirani/overview_classifier_final"


@st.cache_resource(show_spinner=False)
def _load_generator():
    """Return the ``(tokenizer, model)`` pair for the synopsis generator checkpoint."""
    return (
        AutoTokenizer.from_pretrained(generator_name),
        TFAutoModelForCausalLM.from_pretrained(generator_name),
    )


@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Return the ``(tokenizer, model)`` pair for the genre classifier checkpoint."""
    return (
        AutoTokenizer.from_pretrained(classifier_name),
        TFAutoModelForSequenceClassification.from_pretrained(classifier_name),
    )


tokenizer_gen, model_gen = _load_generator()
tokenizer_clf, model_clf = _load_classifier()

def generate_synopsis(model, tokenizer, title):
    """Generate a synopsis that continues the prompt ``title``.

    The prompt is tokenized, extended by ``model.generate`` (beam search,
    at most 150 tokens), decoded, and then trimmed: the third
    ``|``-separated segment of the decoded text is kept and cut after its
    last full stop so the result does not end mid-sentence.

    Args:
        model: Object exposing ``generate(input_ids, ...)``.
        tokenizer: Callable tokenizer that returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` and provides ``decode``.
        title: Prompt text. The caller in this file passes
            ``"<title> | <genre> | "``.

    Returns:
        The stripped synopsis text. If the decoded text has fewer than three
        ``|``-separated segments, the last segment is used. If the chosen
        segment contains no full stop, it is returned whole rather than
        being discarded.
    """
    encoded = tokenizer(title, return_tensors="tf")
    output = model.generate(
        encoded['input_ids'],
        max_length=150,
        num_beams=5,
        no_repeat_ngram_size=2,
        top_k=50,
        attention_mask=encoded['attention_mask'],
    )
    synopsis = tokenizer.decode(output[0], skip_special_tokens=True)

    # Decoded text is expected to look like "<title> | <genre> | <synopsis>";
    # fall back to the last segment instead of raising IndexError when the
    # separators are missing.
    segments = synopsis.split('|')
    body = segments[2] if len(segments) > 2 else segments[-1]

    # Drop the trailing unfinished sentence, but only when there is at least
    # one full stop: rpartition() yields an empty head otherwise, which would
    # throw the whole synopsis away.
    head, dot, _ = body.rpartition('.')
    if dot:
        body = head + dot
    return body.strip()

def generate_classification(model, tokenizer, title, overview):
    """Run the classifier on a title/overview pair.

    The two texts are joined as ``"<title> | <overview>"``, tokenized with
    padding and truncation enabled (TensorFlow tensors requested), and the
    encoded fields are passed to ``model`` as keyword arguments.

    Args:
        model: Callable accepting the tokenizer's fields as keyword arguments.
        tokenizer: Callable tokenizer returning a mapping of model inputs.
        title: Title text.
        overview: Overview/synopsis text.

    Returns:
        Whatever ``model(**encoded)`` returns, unmodified.
    """
    classifier_input = f"{title} | {overview}"
    encoded = tokenizer(
        classifier_input,
        padding=True,
        truncation=True,
        return_tensors="tf",
    )
    return model(**encoded)

# ---------------------------------------------------------------------------
# Streamlit page: widgets first, then the two button handlers.
# ---------------------------------------------------------------------------

favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"

st.set_page_config(page_title="LoreFinder-demo", page_icon=favicon, layout='wide', initial_sidebar_state='auto')

st.title('Demo LoreFinder')

st.header('Generate a story')

prod_title = st.text_input('Type a title to generate a synopsis')
prod_synopsis = st.text_input('Type a synopsis to guess the genre')

# index=None means nothing is preselected, so option_genres is None until the
# user picks an entry.
option_genres = st.selectbox(
    'Select a genre to tailor your synopsis',
    ('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality'),
    index=None,
    placeholder="Select genres..."
    )

button_synopsis = st.button('Get synopsis')

button_genres = st.button('Get genres')

# An untouched text input yields "". Note that "".split(' ') == [''], so a
# length check on the split result can never detect empty input; test the
# stripped text instead (this also rejects whitespace-only input).
has_title = bool(prod_title.strip())
has_synopsis = bool(prod_synopsis.strip())

if button_synopsis:
    if not has_title:
        st.write('Write a title for the generator to work !')
    elif option_genres is None:
        # Without this guard the literal text "None" would be embedded in the
        # "<title> | <genre> | " prompt.
        st.write('Select a genre for the generator to work !')
    else:
        gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | ")
        st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)

if button_genres:
    # The classifier input is "<title> | <synopsis>", so both fields are needed.
    if has_title and has_synopsis:
        classifier_output = generate_classification(model_clf, tokenizer_clf, prod_title, prod_synopsis)
        print(classifier_output)
        # Show a readable label rather than the raw model output object:
        # take the highest-scoring class of the single input row and map it
        # through the checkpoint's id2label table.
        # NOTE(review): assumes a single-label classification head; if the
        # checkpoint is multi-label, a threshold over per-class scores would
        # be needed instead of argmax - confirm.
        best_id = int(classifier_output.logits.numpy()[0].argmax())
        guessed_genre = model_clf.config.id2label.get(best_id, best_id)
        st.write(f"Guessed genre : {guessed_genre}")
    else:
        st.write('Write a title and a synopsis for the classifier to work !')