yash161101 committed on
Commit ca94011 · 1 Parent(s): f871782

Update app.py

Files changed (1):
  1. app.py +40 -8
app.py CHANGED
@@ -1,8 +1,40 @@
- from transformers import AutoModel, AutoTokenizer
-
- model_name = "ml6team/gpt-2-small-conditional-quote-generator"
- model = AutoModel.from_pretrained(model_name)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- inputs = tokenizer("Hello world!", return_tensors="pt")
- outputs = model(**inputs)
+ import streamlit as st
+
+ import tensorflow as tf
+
+ # cap TensorFlow's thread and device usage for the hosted environment
+ config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=3,
+                                   inter_op_parallelism_threads=2,
+                                   allow_soft_placement=True,
+                                   device_count={'GPU': 1, 'CPU': 4})
+ session = tf.compat.v1.Session(config=config)
+
+ # for reproducibility
+ SEED = 64
+
+ title = st.text_input('Enter the seed words', ' ')
+ input_sequence = title
+
+ # maximum number of words in the output text
+ number = st.number_input('Insert how many words', 1)
+ MAX_LEN = number
+
+ if st.button('Submit'):
+     from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
+
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     GPT2 = TFGPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
+
+     tf.random.set_seed(SEED)
+     input_ids = tokenizer.encode(input_sequence, return_tensors='tf')
+
+     # generate greedily until the output (which includes the seed words) reaches MAX_LEN tokens
+     greedy_output = GPT2.generate(input_ids, max_length=MAX_LEN)
+
+     # display the result in the app rather than on the server console
+     st.write(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
+ else:
+     st.write(' ')
+
+ # print("Output:\n" + 100 * '-')
+ # print(tokenizer.decode(sample_output[0], skip_special_tokens = True), '...')
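
Note: the trailing commented-out lines refer to a `sample_output` that this commit never defines. A minimal sketch of how it might be produced with the same `generate` API, using sampling instead of greedy decoding; the `do_sample` and `top_k` settings below are illustrative assumptions, not part of this commit:

# Hypothetical sketch (not part of this commit): producing the `sample_output`
# referenced in the trailing comments by sampling instead of greedy decoding.
import tensorflow as tf
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
GPT2 = TFGPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

tf.random.set_seed(64)  # same SEED as app.py, for reproducible sampling
input_ids = tokenizer.encode("Hello world", return_tensors='tf')

# do_sample=True draws each next token from the model's distribution;
# top_k=50 restricts the draw to the 50 most likely tokens (assumed value).
sample_output = GPT2.generate(input_ids, max_length=30, do_sample=True, top_k=50)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True), '...')

Inside the Streamlit app itself, the decoded text would go through `st.write` instead of `print`, as in the greedy path above.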