shakespeare-demo / shakespeare_demo.py
skar0's picture
Initial commit
4c2c4e8
raw
history blame
2.79 kB
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
# True only when this file is executed directly rather than imported.
MAIN = (__name__ == '__main__')

# Use the GPU when torch reports one is available; otherwise use the CPU.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#%%
# Build the word-level dataset from the corpus file on the selected device.
# NOTE(review): `start` / `end` look like markers delimiting the portion of
# the file to use -- confirm against WordData.from_file.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)

# Report the vocabulary size only when running as a script.
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Load the hyperparameter file; entries are later read as
# yaml_cfg[name]['value'].
# The encoding is given explicitly so parsing does not depend on the
# platform's default locale encoding.
with open('config.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# Load the trained weights. torch.load receives the path directly, so no
# separate file handle is needed (the checkpoint is binary, and a text-mode
# handle would never be read from anyway).
# map_location=device allows a checkpoint that was saved on a GPU to be
# loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = t.load('model_state_dict.pt', map_location=device)
#%%
# Assemble the transformer configuration from the YAML hyperparameters; the
# vocabulary size comes from the loaded dataset rather than the config file.
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)

# Give the dataset/tokenizer the same maximum sequence length as the model.
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']

# Instantiate the model and restore the trained weights.
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)

# Keep the model on the same device the dataset was created with, and switch
# to inference mode so dropout is disabled while sampling.
# NOTE(review): assumes DecoderOnlyTransformer is a torch.nn.Module (it
# already exposes load_state_dict) -- confirm.
model.to(device)
model.eval()
#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    """Sample a continuation of `text` from the module-level model.

    Args:
        text: Prompt passed to the sampler as the initial text.
        max_tokens: Forwarded as `max_tokens_generated`.
        temperature: Forwarded unchanged as the sampling temperature.
        top_k: Forwarded unchanged as the top-k setting.

    Returns:
        Whatever `sampling.sample_tokens` returns for these arguments.
    """
    sampler_options = dict(
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return sampling.sample_tokens(model, shakespeare, text, **sampler_options)
#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Generate text, turning `KeyError` into a friendly message.

    If the first run of digits in the generated text is immediately followed
    by a newline, only the text before those digits is returned; otherwise
    the generated text is returned unchanged.
    NOTE(review): presumably this trims the output at the next numbered
    section heading of the corpus -- confirm.

    Args:
        text: Prompt to continue.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_k: Top-k setting for sampling.

    Returns:
        The (possibly truncated) generated text, or an apology string naming
        the offending key when `generate` raises `KeyError`.
    """
    try:
        raw = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"

    # Anchored at the start: non-digits, then digits, then a newline.
    numbered_break = re.match(r"(?P<start>\D*)\d+\n", raw)
    return raw if numbered_break is None else numbered_break.group('start')
#%%
# Example prompts for the demo UI; each example is a one-element list holding
# the value for the single text input.
examples = [
    [prompt]
    for prompt in (
        "I sang a beautiful song",
        "To be free is to",
        "How I love thee",
    )
]
#%%
# Smoke test: when run as a script, print one generation with the default
# sampling settings.
if MAIN:
    print(safe_generate('How I love thee'))
#%%
def make_demo():
    """Build the Gradio interface around `safe_generate` and launch it.

    The inputs are, in order: the prompt text box and sliders for the maximum
    number of generated tokens, the temperature and top_k. The output is a
    single text box. Returns nothing.
    """
    prompt_box = gr.components.Textbox(lines=5, label="Input Text")
    max_tokens_slider = gr.components.Slider(
        label='max tokens generated', minimum=1, maximum=1000,
        value=300, step=1,
    )
    temperature_slider = gr.components.Slider(
        label='temperature', minimum=0, maximum=2, value=1, step=0.1,
    )
    # NOTE(review): the slider default (10) differs from safe_generate's
    # top_k default (20) -- confirm which one is intended.
    top_k_slider = gr.components.Slider(
        label='top_k', minimum=1, maximum=100, value=10, step=1,
    )
    output_box = gr.components.Textbox(label="Generated Text")

    demo = gr.Interface(
        fn=safe_generate,
        inputs=[prompt_box, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_box,
        examples=examples,
    )
    demo.launch()
# %%
'''
FIXME:
* deploy to heroku
* link from github home
'''