# mkmenta's picture
# Update app.py
# dd43eee verified
"""GPT-1 and GPT-2 Text Generation demo."""
import gradio as gr
from torch.cuda import is_available
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2Tokenizer, GPT2LMHeadModel
# Module-level cache of the currently loaded model. load_model() fills these
# in and generate() reads them, so a model is only reloaded when the
# selection changes.
tokenizer = None  # tokenizer object assigned by load_model()
model = None  # language-model object assigned by load_model()
loaded_model = None  # dropdown entry the cached objects belong to; None until the first load
def load_model(model_name):
    """Load the model and tokenizer for a dropdown entry from HuggingFace.

    The result is stored in the module-level ``tokenizer``, ``model`` and
    ``loaded_model`` variables.

    Args:
        model_name: Display name of the form ``'<label> (<hub id>)'``, e.g.
            ``'GPT-2 S (gpt2)'``. The text between the parentheses is used as
            the HuggingFace model identifier.

    Raises:
        IndexError: If ``model_name`` contains no ``'('``.
    """
    global tokenizer, model, loaded_model

    # Invalidate the cache while the state is in flux: if anything below
    # raises, the next generate() call retries instead of silently reusing
    # whatever model happened to be loaded before.
    loaded_model = None
    # Release the previous model first so old and new weights never have to
    # fit in memory at the same time.
    tokenizer = model = None

    huggingface_model_name = model_name.split('(')[1][:-1]
    if huggingface_model_name == 'openai-gpt':  # GPT-1
        tokenizer = OpenAIGPTTokenizer.from_pretrained(huggingface_model_name)
        model = OpenAIGPTLMHeadModel.from_pretrained(huggingface_model_name)
    else:  # GPT-2
        tokenizer = GPT2Tokenizer.from_pretrained(huggingface_model_name)
        model = GPT2LMHeadModel.from_pretrained(huggingface_model_name)

    # Load model in CUDA if available
    if is_available():
        model = model.cuda()

    # Only now is the cache consistent with `model_name`.
    loaded_model = model_name
def generate(inp, model_name, temperature, top_p, rep_pty, max_length):
    """Generate a continuation of ``inp`` using the given model and parameters.

    Args:
        inp: Prompt text to continue.
        model_name: Dropdown entry naming the model; the model is (re)loaded
            when it differs from the one currently cached.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        rep_pty: Repetition penalty.
        max_length: Maximum total sequence length in tokens, prompt included.
            Clamped to the number of positions the model supports.

    Returns:
        The newly generated text without the prompt, or ``""`` when ``inp`` is
        empty.
    """
    if not inp:
        # An empty prompt encodes to zero tokens, which cannot be extended.
        return ""

    if loaded_model != model_name:
        load_model(model_name)

    # Follow the model to whichever device load_model() placed it on.
    inputs = tokenizer.encode(inp, return_tensors='pt').to(model.device)

    # Never request more positions than the model has; the UI allows values
    # larger than the smaller models' context window.
    max_length = min(int(max_length), model.config.n_positions)

    outputs = model.generate(inputs,
                             max_length=max_length,
                             # Without sampling enabled, decoding is greedy and
                             # `temperature` / `top_p` are ignored.
                             do_sample=True,
                             temperature=temperature,
                             num_return_sequences=1,
                             top_p=top_p,
                             repetition_penalty=rep_pty)

    # The output starts with the prompt's token ids. Decoding only what comes
    # after them is exact, whereas removing the prompt by string replacement
    # breaks when the tokenizer normalises the text and also deletes any later
    # repetition of the prompt.
    new_tokens = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Default prompt used as the initial value of the input textbox.
SAMPLE_INPUT = (
    "In a shocking finding, scientists discovered a herd of unicorns living in a "
    "remote, previously unexplored valley, in the Andes Mountains. "
    "Even more surprising to the researchers was the fact that the unicorns "
    "spoke perfect English."
)
# Build the demo UI: prompt and generated text in the wide column, model
# selection and generation settings in the narrow one.
with gr.Blocks() as demo:
    gr.Markdown("# 🦄 Try GPT-1 and GPT-2")
    with gr.Row():
        with gr.Column(scale=4):
            inp = gr.Textbox(label="Input text:",
                             placeholder="Enter some text to get started.",
                             value=SAMPLE_INPUT,
                             lines=10)
            out = gr.Textbox(label="Generated text:", lines=25)
        with gr.Column(scale=1):
            # NOTE(review): the row is assumed to hold both the model picker
            # and the button — confirm against the intended layout.
            with gr.Row():
                # The text in parentheses is parsed by load_model() as the
                # HuggingFace model id, so every choice must keep that form.
                model_name = gr.Dropdown(label="Select a model:",
                                         choices=['GPT-2 XL (gpt2-xl)',
                                                  'GPT-2 L (gpt2-large)',
                                                  'GPT-2 M (gpt2-medium)',
                                                  'GPT-2 S (gpt2)',
                                                  'GPT-1 (openai-gpt)'],
                                         value='GPT-2 XL (gpt2-xl)')
                btn_run = gr.Button("Generate")
            temperature = gr.Slider(
                label="Temperature",
                info=("Degree of randomness in the output, where higher values make it more unpredictable"
                      " and creative, while lower values make it more deterministic and focused."),
                minimum=0.01, maximum=3.0, step=0.01, value=0.7)
            top_p = gr.Slider(
                label="Top-p",
                info=("If set to float < 1, only the most probable tokens with probabilities that add up"
                      " to `top_p` or higher are kept for generation."),
                minimum=0.01, maximum=1.0, step=0.01, value=.9)
            rep_pty = gr.Slider(label="Repetition Penalty",
                                info="Token repetition penalty. 1.0 means no penalty.",
                                minimum=1.0, maximum=2.0, step=0.01, value=1.2)
            max_length = gr.Number(label="Max Length",
                                   info="The maximum length of the sequence to be generated.",
                                   minimum=1, maximum=1024, value=256, precision=0)
    # Component values are passed positionally, so the order of `inputs` must
    # match the parameter order of generate().
    btn_run.click(fn=generate, inputs=[inp, model_name, temperature, top_p, rep_pty, max_length], outputs=out)

# Start the server only when executed as a script, so importing this module
# (e.g. for reuse or testing) does not launch the web app as a side effect.
if __name__ == "__main__":
    demo.launch()