# Gradio demo: sample Sonnet completions from a small Shakespeare transformer.
# (Removed "Spaces: / Runtime error" page-scrape artifacts — they were not code.)
#%%
# Standard library
import re

# Third-party
import gradio as gr
import torch as t
import yaml

# Local
import sampling
import transformer_replication
from word_data import WordData
#%%
# True only when this file is run as a script (guards the smoke tests below).
MAIN = __name__ == '__main__'
# Run inference on the GPU when one is available, else fall back to CPU.
device = 'cuda' if t.cuda.is_available() else 'cpu'
#%%
# Word-level training corpus: Project Gutenberg's complete Shakespeare
# ('100-0.txt'), with tensors placed on the active device.
# NOTE(review): start="1\n" / end='ALL’S WELL THAT ENDS WELL' presumably trim
# the text to the Sonnets only — confirm against WordData.from_file.
shakespeare = WordData.from_file(
    '100-0.txt', device=device, start="1\n", end='ALL’S WELL THAT ENDS WELL'
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
# Hyperparameters saved by the training run. The indexing below expects the
# shape {name: {'value': ...}} (Weights & Biases config-export style).
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# Trained weights, mapped onto the active device. The checkpoint is binary,
# so open it in 'rb' mode and hand the file object to t.load — the original
# opened the file in text mode and then ignored the handle, loading by path.
with open('model_state_dict.pt', 'rb') as f:
    state_dict = t.load(f, map_location=device)
#%%
def _hp(name: str):
    """Read one hyperparameter value from the loaded training config."""
    return yaml_cfg[name]['value']

# Rebuild the architecture exactly as it was trained, then restore weights.
base_config = transformer_replication.TransformerConfig(
    num_layers=_hp('num_layers'),
    num_heads=_hp('num_heads'),
    vocab_size=len(shakespeare.vocab),
    hidden_size=_hp('hidden_size'),
    max_seq_len=_hp('max_seq_len'),
    dropout=_hp('dropout'),
)
# The tokenizer must truncate prompts to what the model can attend over.
shakespeare.model_max_length = _hp('max_seq_len')
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    """Continue `text` with up to `max_tokens` sampled tokens.

    Thin wrapper around sampling.sample_tokens using the module-level
    model and Shakespeare tokenizer; `temperature` and `top_k` tune the
    sampling distribution.
    """
    return sampling.sample_tokens(
        model,
        shakespeare,
        text,
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Generate a completion, trimmed and made safe for the UI.

    If the raw sample contains a digit run followed by a newline (i.e. the
    model started the next numbered sonnet), only the text before it is
    returned. A prompt word missing from the vocabulary surfaces as a
    KeyError, which is converted into a friendly message.
    """
    # Keep the try body minimal: only the sampling call can raise KeyError.
    try:
        raw = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
    match = re.match(r"(?P<start>\D*)\d+\n", raw)
    if match is None:
        return raw
    return match.group('start')
#%%
# One-click example prompts shown in the Gradio UI: one inner list per
# interface input (only the text box is pre-filled; sliders keep defaults).
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]
#%%
# Quick smoke test when the file is run directly (skipped on import).
if MAIN:
    print(safe_generate('How I love thee'))
#%%
# Markdown blurb rendered under the title in the Gradio interface.
description = """
Provide a prompt in the "Input Text" window below and then click "Submit".
The small Shakespeare transformer model trained on my laptop will attempt to
complete the Sonnet that you started.
Thanks to Project Gutenberg for providing the training corpus.
"""
#%%
def make_demo():
    """Build the Gradio interface around safe_generate and launch it."""
    # NOTE(review): the top_k slider defaults to 10 while safe_generate's
    # keyword default is 20 — the slider value wins in the UI; confirm which
    # default is intended.
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            gr.components.Textbox(lines=5, label="Input Text"),
            gr.components.Slider(
                label='max tokens generated', minimum=1, maximum=1000,
                value=300, step=1,
            ),
            gr.components.Slider(
                label='temperature', minimum=0, maximum=2, value=1, step=0.1,
            ),
            gr.components.Slider(
                label='top_k', minimum=1, maximum=100, value=10, step=1,
            ),
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    demo.launch()
#%%