shakespeare-demo / shakespeare_demo.py
skar0's picture
Set map_location in torch.load call
eafeaef
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
# True only when this file is executed as a script rather than imported.
MAIN = (__name__ == '__main__')

# Use the GPU when PyTorch reports that CUDA is available, otherwise the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#%%
# Build the word-level dataset from the corpus file '100-0.txt'.
# NOTE(review): `start` / `end` presumably mark the slice of the file that is
# kept -- confirm against WordData.from_file.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)

# Report the vocabulary size only when run as a script.
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Read the model hyperparameters; entries are accessed elsewhere in this file
# as ``yaml_cfg[<name>]['value']``.
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# Hand the checkpoint path straight to torch.load, which opens the file itself
# in binary mode. The previous text-mode ``open`` around this call produced a
# handle that was never used.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source.
state_dict = t.load(
    'model_state_dict.pt',
    map_location=device,  # remap tensors saved on another device
)
#%%
# Assemble the transformer configuration from the YAML hyperparameters; the
# vocabulary size comes from the loaded corpus rather than from the config.
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)
# NOTE(review): presumably consulted by the sampler to cap the context
# length -- confirm where `model_max_length` is read.
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']

# Instantiate the network and restore the trained weights.
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
# The model is only used for inference in this file, so switch to evaluation
# mode to make sure dropout is disabled while sampling.
model.eval()
#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    """Sample a continuation of ``text`` from the module-level model.

    Thin wrapper around ``sampling.sample_tokens`` that binds the global
    ``model`` and ``shakespeare`` dataset.

    Args:
        text: Prompt to continue.
        max_tokens: Forwarded as ``max_tokens_generated``.
        temperature: Forwarded unchanged as ``temperature``.
        top_k: Forwarded unchanged as ``top_k``.

    Returns:
        Whatever ``sampling.sample_tokens`` returns (annotated as ``str``).
    """
    completion = sampling.sample_tokens(
        model,
        shakespeare,
        text,
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return completion
#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Generate text, post-process it, and turn ``KeyError`` into a message.

    Args:
        text: Prompt to continue.
        max_tokens: Passed through to ``generate``.
        temperature: Passed through to ``generate``.
        top_k: Passed through to ``generate``.

    Returns:
        * An apology string naming the offending key when ``generate`` raises
          ``KeyError``.
        * Otherwise, if the output starts with a run of non-digit characters
          immediately followed by digits and a newline, only that non-digit
          prefix.
        * Otherwise the raw output unchanged.
    """
    try:
        completion = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as err:
        return f"I'm sorry, {str(err)} is not in Shakespeare's vocabulary"

    # NOTE(review): the "digits + newline" marker presumably is the numeric
    # heading of the next sonnet in the corpus -- confirm.
    heading = re.match(r"(?P<start>\D*)\d+\n", completion)
    return completion if heading is None else heading.group('start')
#%%
# Example prompts for the demo UI: one single-element row per example.
examples = [
    [prompt]
    for prompt in (
        "I sang a beautiful song",
        "To be free is to",
        "How I love thee",
    )
]
#%%
# Smoke test: when run as a script, print one completion using the default
# sampling settings of safe_generate.
if MAIN:
    print(safe_generate(text='How I love thee'))
#%%
# User-facing blurb passed to the Gradio interface as its `description`.
# The text is shown to users at runtime, so edit it with care.
description = """
Provide a prompt in the "Input Text" window below and then click "Submit".
The small Shakespeare transformer model trained on my laptop will attempt to
complete the Sonnet that you started.
Thanks to Project Gutenberg for providing the training corpus.
"""
#%%
def make_demo():
    """Build the Gradio interface around ``safe_generate`` and launch it.

    The four input widgets are listed in the same order as the parameters of
    ``safe_generate`` (text, max_tokens, temperature, top_k). Returns ``None``.
    """
    prompt_box = gr.components.Textbox(lines=5, label="Input Text")
    max_tokens_slider = gr.components.Slider(
        label='max tokens generated', minimum=1, maximum=1000,
        value=300, step=1,
    )
    temperature_slider = gr.components.Slider(
        label='temperature', minimum=0, maximum=2, value=1, step=0.1,
    )
    # NOTE(review): the slider starts at 10, whereas safe_generate's own
    # default for top_k is 20.
    top_k_slider = gr.components.Slider(
        label='top_k', minimum=1, maximum=100, value=10, step=1,
    )
    output_box = gr.components.Textbox(label="Generated Text")

    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            prompt_box,
            max_tokens_slider,
            temperature_slider,
            top_k_slider,
        ],
        outputs=output_box,
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    demo.launch()
#%%