alirani commited on
Commit
9f4fb0e
·
1 Parent(s): 56e76a4

add classifier

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -1,11 +1,15 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, TFAutoModelForCausalLM
3
 
4
  # MODEL TO CALL
5
 
6
- model_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = TFAutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
9
 
10
  def generate_synopsis(model, tokenizer, title):
11
  input_ids = tokenizer(title, return_tensors="tf")
@@ -14,6 +18,11 @@ def generate_synopsis(model, tokenizer, title):
14
  processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
15
  return processed_synopsis
16
 
 
 
 
 
 
17
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
18
 
19
  st.set_page_config(page_title="LoreFinder-demo", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
@@ -23,6 +32,7 @@ st.title('Demo LoreFinder')
23
  st.header('Generate a story')
24
 
25
  prod_title = st.text_input('Type a title to generate a synopsis')
 
26
 
27
  option_genres = st.selectbox(
28
  'Select a genre to tailor your synopsis',
@@ -33,9 +43,18 @@ option_genres = st.selectbox(
33
 
34
  button_synopsis = st.button('Get synopsis')
35
 
 
 
36
  if button_synopsis:
37
  if len(prod_title.split(' ')) > 0:
38
- gen_synopsis = generate_synopsis(model, tokenizer, f"{prod_title} | {option_genres} | ")
39
  st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
40
  else:
41
  st.write('Write a title for the generator to work !')
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification
3
 
4
  # MODEL TO CALL
5
 
6
+ generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
7
+ tokenizer_gen = AutoTokenizer.from_pretrained(generator_name)
8
+ model_gen = TFAutoModelForCausalLM.from_pretrained(generator_name)
9
+
10
+ classifier_name = "Alirani/overview_classifier_final"
11
+ tokenizer_clf = AutoTokenizer.from_pretrained(classifier_name)
12
+ model_clf = TFAutoModelForSequenceClassification.from_pretrained(classifier_name)
13
 
14
  def generate_synopsis(model, tokenizer, title):
15
  input_ids = tokenizer(title, return_tensors="tf")
 
18
  processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
19
  return processed_synopsis
20
 
21
+ def generate_classification(model, tokenizer, title, overview):
22
+ tokens = tokenizer(f"{title} | {overview}", padding=True, truncation=True, return_tensors="tf")
23
+ output = model(**tokens)
24
+ return output
25
+
26
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
27
 
28
  st.set_page_config(page_title="LoreFinder-demo", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
 
32
  st.header('Generate a story')
33
 
34
  prod_title = st.text_input('Type a title to generate a synopsis')
35
+ prod_synopsis = st.text_input('Type a synopsis to guess the genre')
36
 
37
  option_genres = st.selectbox(
38
  'Select a genre to tailor your synopsis',
 
43
 
44
  button_synopsis = st.button('Get synopsis')
45
 
46
+ button_genres = st.button('Get genres')
47
+
48
  if button_synopsis:
49
  if len(prod_title.split(' ')) > 0:
50
+ gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | ")
51
  st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
52
  else:
53
  st.write('Write a title for the generator to work !')
54
+
55
+ if button_genres:
56
+ if len(prod_title.split(' ')) > 0:
57
+ classifier_output = generate_classification(model_clf, tokenizer_clf, prod_title, prod_synopsis)
58
+ st.write(f"Guessed genre : {classifier_output}")
59
+ else:
60
+ st.write('Write a title and a synopsis for the classifier to work !')