# Hugging Face Space: Gradio demo that completes Shakespeare-style text with a
# small decoder-only transformer trained on the Project Gutenberg corpus.
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
# True when executed as a script / notebook cell; guards prints and the demo.
MAIN = __name__ == '__main__'
# Run on GPU when available; the checkpoint is mapped onto this device below.
device = 'cuda' if t.cuda.is_available() else 'cpu'
#%%
# Load the Shakespeare corpus sliced between the start/end markers; WordData
# builds the vocabulary used for encoding/decoding (project-local helper).
shakespeare = WordData.from_file(
    '100-0.txt', device=device, start="1\n", end='ALL’S WELL THAT ENDS WELL'
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Training hyperparameters; each key maps to a {'value': ...} dict (see the
# ['value'] lookups when building the TransformerConfig below).
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# Load the trained weights, mapping tensors onto the active device.
# Fix: the original opened the checkpoint in *text* mode and never used the
# handle — torch.load re-opened the file by path. Checkpoints are binary, so
# open in 'rb' and pass the file object directly.
with open('model_state_dict.pt', 'rb') as f:
    state_dict = t.load(f, map_location=device)
#%%
# Rebuild the architecture from the saved hyperparameters so the checkpoint's
# shapes line up with the freshly constructed model.
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)
# Cap tokenizer-side length to the model's positional range — presumably read
# by sampling.sample_tokens; TODO confirm against WordData/sampling.
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
#%%
def generate(
    text: str, max_tokens: int, temperature: float,
    top_k: int,
) -> str:
    """Continue `text` with up to `max_tokens` sampled tokens.

    Thin wrapper over sampling.sample_tokens using the module-level
    model and Shakespeare vocabulary; top-k sampling at `temperature`.
    """
    sampler_options = dict(
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return sampling.sample_tokens(model, shakespeare, text, **sampler_options)
#%%
def safe_generate(
    text: str, max_tokens: int = 300, temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Sample a continuation, with a friendly message on unknown words.

    The raw sample is truncated just before the first line that is a bare
    number (a sonnet number in the training text), so the output ends at a
    natural poem boundary. Words outside the vocabulary surface as KeyError
    from the sampler and are reported instead of raised.
    """
    try:
        completion = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
    # Only generate() can raise KeyError, so the regex lives outside the try.
    boundary = re.match(r"(?P<start>\D*)\d+\n", completion)
    return completion if boundary is None else boundary.group('start')
#%%
# Prompts shown as clickable examples in the Gradio UI.
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]
#%%
# Quick smoke test of the full sampling path when run directly.
if MAIN:
    print(safe_generate('How I love thee'))
#%%
# Shown verbatim above the Gradio interface (runtime string — do not edit
# casually).
description = """
Provide a prompt in the "Input Text" window below and then click "Submit".
The small Shakespeare transformer model trained on my laptop will attempt to
complete the Sonnet that you started.
Thanks to Project Gutenberg for providing the training corpus.
"""
#%%
def make_demo():
    """Build the Gradio interface around safe_generate and launch it.

    The input components map positionally onto safe_generate's parameters
    (text, max_tokens, temperature, top_k), so their order matters.
    """
    controls = [
        gr.components.Textbox(lines=5, label="Input Text"),
        gr.components.Slider(
            label='max tokens generated', minimum=1, maximum=1000,
            value=300, step=1,
        ),
        gr.components.Slider(
            label='temperature', minimum=0, maximum=2, value=1, step=0.1,
        ),
        gr.components.Slider(
            label='top_k', minimum=1, maximum=100, value=10, step=1,
        ),
    ]
    interface = gr.Interface(
        fn=safe_generate,
        inputs=controls,
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    interface.launch()
#%%
|