alirani
commited on
Commit
·
9f4fb0e
1
Parent(s):
56e76a4
add classifier
Browse files
app.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import AutoTokenizer, TFAutoModelForCausalLM
|
3 |
|
4 |
# MODEL TO CALL
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def generate_synopsis(model, tokenizer, title):
|
11 |
input_ids = tokenizer(title, return_tensors="tf")
|
@@ -14,6 +18,11 @@ def generate_synopsis(model, tokenizer, title):
|
|
14 |
processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
|
15 |
return processed_synopsis
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
|
18 |
|
19 |
st.set_page_config(page_title="LoreFinder-demo", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
|
@@ -23,6 +32,7 @@ st.title('Demo LoreFinder')
|
|
23 |
st.header('Generate a story')
|
24 |
|
25 |
prod_title = st.text_input('Type a title to generate a synopsis')
|
|
|
26 |
|
27 |
option_genres = st.selectbox(
|
28 |
'Select a genre to tailor your synopsis',
|
@@ -33,9 +43,18 @@ option_genres = st.selectbox(
|
|
33 |
|
34 |
button_synopsis = st.button('Get synopsis')
|
35 |
|
|
|
|
|
36 |
if button_synopsis:
|
37 |
if len(prod_title.split(' ')) > 0:
|
38 |
-
gen_synopsis = generate_synopsis(
|
39 |
st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
|
40 |
else:
|
41 |
st.write('Write a title for the generator to work !')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification
|
3 |
|
4 |
# MODEL TO CALL
|
5 |
|
6 |
+
generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
|
7 |
+
tokenizer_gen = AutoTokenizer.from_pretrained(generator_name)
|
8 |
+
model_gen = TFAutoModelForCausalLM.from_pretrained(generator_name)
|
9 |
+
|
10 |
+
classifier_name = "Alirani/overview_classifier_final"
|
11 |
+
tokenizer_clf = AutoTokenizer.from_pretrained(classifier_name)
|
12 |
+
model_clf = TFAutoModelForSequenceClassification.from_pretrained(classifier_name)
|
13 |
|
14 |
def generate_synopsis(model, tokenizer, title):
|
15 |
input_ids = tokenizer(title, return_tensors="tf")
|
|
|
18 |
processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
|
19 |
return processed_synopsis
|
20 |
|
21 |
+
def generate_classification(model, tokenizer, title, overview):
|
22 |
+
tokens = tokenizer(f"{title} | {overview}", padding=True, truncation=True, return_tensors="tf")
|
23 |
+
output = model(**tokens)
|
24 |
+
return output
|
25 |
+
|
26 |
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
|
27 |
|
28 |
st.set_page_config(page_title="LoreFinder-demo", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
|
|
|
32 |
st.header('Generate a story')
|
33 |
|
34 |
prod_title = st.text_input('Type a title to generate a synopsis')
|
35 |
+
prod_synopsis = st.text_input('Type a synopsis to guess the genre')
|
36 |
|
37 |
option_genres = st.selectbox(
|
38 |
'Select a genre to tailor your synopsis',
|
|
|
43 |
|
44 |
button_synopsis = st.button('Get synopsis')
|
45 |
|
46 |
+
button_genres = st.button('Get genres')
|
47 |
+
|
48 |
if button_synopsis:
|
49 |
if len(prod_title.split(' ')) > 0:
|
50 |
+
gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | ")
|
51 |
st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
|
52 |
else:
|
53 |
st.write('Write a title for the generator to work !')
|
54 |
+
|
55 |
+
if button_genres:
|
56 |
+
if len(prod_title.split(' ')) > 0:
|
57 |
+
classifier_output = generate_classification(model_clf, tokenizer_clf, prod_title, prod_synopsis)
|
58 |
+
st.write(f"Guessed genre : {classifier_output}")
|
59 |
+
else:
|
60 |
+
st.write('Write a title and a synopsis for the classifier to work !')
|