import streamlit as st
import tensorflow as tf
from transformers import AutoTokenizer, TFGPT2LMHeadModel

# Limit TensorFlow thread and device usage for the Space's hardware
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=3,
                                  inter_op_parallelism_threads=2,
                                  allow_soft_placement=True,
                                  device_count={'GPU': 1, 'CPU': 4})
session = tf.compat.v1.Session(config=config)

# For reproducibility
SEED = 64
tf.random.set_seed(SEED)

# Seed words and maximum length (in tokens) of the generated text
title = st.text_input('Enter the seed words', ' ')
input_sequence = title
number = st.number_input('Insert how many words', 1)
MAX_LEN = int(number)

if st.button('Submit'):
    # Load the conditional quote generator; from_pt=True converts the
    # PyTorch checkpoint to TensorFlow weights so it matches the TF
    # tensors used below
    tokenizer = AutoTokenizer.from_pretrained("ml6team/gpt-2-medium-conditional-quote-generator")
    GPT2 = TFGPT2LMHeadModel.from_pretrained("ml6team/gpt-2-medium-conditional-quote-generator",
                                             from_pt=True)

    input_ids = tokenizer.encode(input_sequence, return_tensors='tf')

    # Greedy decoding: generate until the output length (which includes
    # the context length) reaches MAX_LEN
    greedy_output = GPT2.generate(input_ids, max_length=MAX_LEN)

    st.write("Output:\n" + 100 * '-')
    st.write(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
else:
    st.write(' ')
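# A minimal sketch of how this app would typically be launched locally,
# assuming the file is saved as app.py (the filename is an assumption):
#
#   pip install streamlit tensorflow transformers
#   streamlit run app.py
#
# Streamlit then serves the text input, number input, and Submit button
# defined above as a browser page; pressing Submit runs the generation
# branch and writes the decoded output back to the page.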