Spaces:
Running
Running
"""GPT-1 and GPT-2 Text Generation demo.""" | |
import gradio as gr | |
from torch.cuda import is_available | |
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2Tokenizer, GPT2LMHeadModel | |
# Module-level cache of the currently loaded tokenizer/model pair.
# `loaded_model` holds the label of the model the pair belongs to; all three
# start out empty until a model is loaded.
tokenizer = model = loaded_model = None
def load_model(model_name):
    """Load the tokenizer and model for ``model_name`` into the module globals.

    Args:
        model_name: Label of the form ``'<display name> (<checkpoint id>)'``,
            e.g. ``'GPT-2 S (gpt2)'``. The text between the parentheses is
            passed to ``from_pretrained``.

    Raises:
        IndexError: If ``model_name`` contains no ``'('``.

    Side effects:
        Rebinds the globals ``tokenizer``, ``model`` and ``loaded_model``.
        ``loaded_model`` is set to ``model_name`` only after everything has
        loaded successfully; if loading fails it is left as ``None`` so the
        next call to ``generate`` retries instead of using a stale or
        half-loaded tokenizer/model pair.
    """
    global tokenizer, model, loaded_model

    # Parse first: a malformed label should not invalidate what is loaded.
    huggingface_model_name = model_name.split('(')[1][:-1]

    # Invalidate the cache marker while the globals are being replaced.
    loaded_model = None

    if huggingface_model_name == 'openai-gpt':  # GPT-1
        tokenizer = OpenAIGPTTokenizer.from_pretrained(huggingface_model_name)
        model = OpenAIGPTLMHeadModel.from_pretrained(huggingface_model_name)
    else:  # GPT-2
        tokenizer = GPT2Tokenizer.from_pretrained(huggingface_model_name)
        model = GPT2LMHeadModel.from_pretrained(huggingface_model_name)

    # Move the model to the GPU when CUDA is available.
    if is_available():
        model = model.cuda()

    # Publish the marker last, once tokenizer and model are both ready.
    loaded_model = model_name
def generate(inp, model_name, temperature, top_p, rep_pty, max_length):
    """Generate a continuation of ``inp`` with the selected model.

    Args:
        inp: Prompt text.
        model_name: Model label as understood by ``load_model``; the model is
            (re)loaded when it differs from the currently loaded one.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        rep_pty: Repetition penalty passed to ``model.generate``.
        max_length: Maximum total sequence length (prompt + generated tokens).

    Returns:
        The newly generated text only (the prompt is not repeated). An empty
        string is returned when the prompt is empty or tokenizes to nothing.
    """
    # Nothing to continue; avoid loading a model or feeding an empty tensor.
    if not inp:
        return ""

    if loaded_model != model_name:
        load_model(model_name)

    inputs = tokenizer.encode(inp, return_tensors='pt')
    prompt_len = inputs.shape[-1]
    if prompt_len == 0:
        return ""
    # Keep the input on the same device as the model.
    inputs = inputs.to(model.device)

    # Do not request more positions than the model config declares
    # (GPT-1 has a shorter context window than GPT-2).
    context_size = getattr(model.config, 'n_positions', None)
    if max_length is not None and context_size is not None:
        max_length = min(max_length, context_size)

    # do_sample=True is required for temperature/top_p to have any effect;
    # without it generation is greedy and those settings are ignored.
    outputs = model.generate(inputs,
                             max_length=max_length,
                             do_sample=True,
                             temperature=temperature,
                             num_return_sequences=1,
                             top_p=top_p,
                             repetition_penalty=rep_pty)

    # Drop the prompt by token position rather than by string replacement:
    # decoded text does not always reproduce the prompt verbatim, and
    # str.replace would also delete later repetitions of the prompt.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Default prompt text.
SAMPLE_INPUT = " ".join([
    "In a shocking finding, scientists discovered a herd of unicorns living in a remote,",
    "previously unexplored valley, in the Andes Mountains. Even more surprising to the",
    "researchers was the fact that the unicorns spoke perfect English.",
])
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed; in particular, confirm which components belong inside the
# inner gr.Row() of the right-hand column.
with gr.Blocks() as demo:
    gr.Markdown("# 🦄 Try GPT-1 and GPT-2")

    with gr.Row():
        # Left column: prompt and generated text.
        with gr.Column(scale=4):
            inp = gr.Textbox(
                label="Input text:",
                placeholder="Enter some text to get started.",
                value=SAMPLE_INPUT,
                lines=10,
            )
            out = gr.Textbox(label="Generated text:", lines=25)

        # Right column: model selection and generation settings.
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Dropdown(
                    label="Select a model:",
                    choices=[
                        'GPT-2 XL (gpt2-xl)',
                        'GPT-2 L (gpt2-large)',
                        'GPT-2 M (gpt2-medium)',
                        'GPT-2 S (gpt2)',
                        'GPT-1 (openai-gpt)',
                    ],
                    value='GPT-2 XL (gpt2-xl)',
                )
            btn_run = gr.Button("Generate")
            temperature = gr.Slider(
                label="Temperature",
                info=(
                    "Degree of randomness in the output, where higher values make it more unpredictable"
                    " and creative, while lower values make it more deterministic and focused."
                ),
                minimum=0.01,
                maximum=3.0,
                step=0.01,
                value=0.7,
            )
            top_p = gr.Slider(
                label="Top-p",
                info=(
                    "If set to float < 1, only the most probable tokens with probabilities that add up"
                    " to `top_p` or higher are kept for generation."
                ),
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=.9,
            )
            rep_pty = gr.Slider(
                label="Repetition Penalty",
                info="Token repetition penalty. 1.0 means no penalty.",
                minimum=1.0,
                maximum=2.0,
                step=0.01,
                value=1.2,
            )
            max_length = gr.Number(
                label="Max Length",
                info="The maximum length of the sequence to be generated.",
                minimum=1,
                maximum=1024,
                value=256,
                precision=0,
            )

    # Clicking the button runs generate() on the current control values and
    # writes the result into the output textbox.
    btn_run.click(
        fn=generate,
        inputs=[inp, model_name, temperature, top_p, rep_pty, max_length],
        outputs=out,
    )

demo.launch()