yash161101 committed on
Commit
12b0f88
·
1 Parent(s): 88c8380

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app: generate text with GPT-2 (medium) from user-supplied seed words.

The user enters a seed phrase and a word budget; on Submit the app samples a
continuation with nucleus (top-p) sampling and displays it.
"""

import streamlit as st

# For reproducibility: fixed seed for TensorFlow's sampling RNG (set below).
SEED = 12

# Seed text for generation.
title = st.text_input('Enter the seed words', ' ')
input_sequence = title

# Maximum number of tokens in the generated output (includes the prompt).
# Cast defensively: st.number_input may return a float depending on its args.
number = st.number_input('Insert how many words', 1)
MAX_LEN = int(number)

if st.button('Submit'):

    # Heavy imports are deferred so the page renders before the ML stack loads.
    from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
    import tensorflow as tf

    # Cache the tokenizer/model so they are loaded once per server process,
    # not re-downloaded/re-built on every rerun or button press.
    # NOTE(review): st.cache_resource requires Streamlit >= 1.18; on older
    # versions substitute @st.cache(allow_output_mutation=True).
    @st.cache_resource
    def _load_gpt2():
        """Load and cache the GPT-2 medium tokenizer and TF model."""
        tok = GPT2Tokenizer.from_pretrained("gpt2-medium")
        model = TFGPT2LMHeadModel.from_pretrained(
            "gpt2-medium", pad_token_id=tok.eos_token_id
        )
        return tok, model

    tokenizer, GPT2 = _load_gpt2()

    tf.random.set_seed(SEED)
    input_ids = tokenizer.encode(input_sequence, return_tensors='tf')

    # Nucleus sampling: draw only from the smallest set of words whose
    # cumulative probability reaches 0.8; top_k=0 disables top-k filtering.
    sample_output = GPT2.generate(
        input_ids,
        do_sample=True,
        max_length=MAX_LEN,
        top_p=0.8,
        top_k=0,
    )

    st.write(tokenizer.decode(sample_output[0], skip_special_tokens=True))

else:
    # Placeholder so the output area exists before the first submission.
    st.write(' ')