File size: 3,121 Bytes
4c2c4e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eafeaef
 
4c2c4e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8753fd
 
 
 
 
 
 
 
4c2c4e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8753fd
 
 
4c2c4e8
 
b8753fd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
# True only when this file is executed as a script; guards demo side effects.
MAIN = (__name__ == '__main__')
# Map the model and all corpus tensors onto the GPU when one is present.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#%%
# Load the training corpus (Project Gutenberg's complete Shakespeare),
# trimmed to the span between the first sonnet heading and the start of
# the next work. NOTE(review): assumes WordData.from_file slices on these
# start/end markers — confirm against word_data.py.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Hyperparameters are stored wandb-style: {param_name: {'value': ...}}.
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# Load the trained weights onto the active device.
# Fix: the old code wrapped this in `with open('model_state_dict.pt') as f:`,
# opening a binary checkpoint in text mode and never using the handle —
# t.load takes the path directly, so the open() was both useless and unsafe.
state_dict = t.load(
    'model_state_dict.pt',
    map_location=device,
)
#%%
def _hyperparam(name: str):
    """Read one scalar hyperparameter out of the wandb-style YAML dump."""
    return yaml_cfg[name]['value']

# Rebuild the architecture exactly as trained; only vocab_size is derived
# from the corpus rather than the saved config.
base_config = transformer_replication.TransformerConfig(
    num_layers=_hyperparam('num_layers'),
    num_heads=_hyperparam('num_heads'),
    vocab_size=len(shakespeare.vocab),
    hidden_size=_hyperparam('hidden_size'),
    max_seq_len=_hyperparam('max_seq_len'),
    dropout=_hyperparam('dropout'),
)
# Keep the tokenizer's window in sync with the model's context length.
shakespeare.model_max_length = _hyperparam('max_seq_len')
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)

#%%
def generate(
        text: str, max_tokens: int, temperature: float,
        top_k: int,
) -> str:
    """Continue `text` with the Shakespeare model.

    Thin wrapper binding the module-level `model` and `shakespeare`
    word data into sampling.sample_tokens.
    """
    sampling_kwargs = {
        'max_tokens_generated': max_tokens,
        'temperature': temperature,
        'top_k': top_k,
    }
    return sampling.sample_tokens(model, shakespeare, text, **sampling_kwargs)

#%%
def safe_generate(
        text: str, max_tokens: int = 300, temperature: float = 1.0,
        top_k: int = 20,
    ) -> str:
    """Generate a completion, degrading gracefully on unknown words.

    The raw sample is truncated just before the next sonnet number
    (a digit run followed by a newline) so only the current sonnet
    is returned.
    """
    try:
        completion = generate(
            text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
        )
    except KeyError as e:
        # presumably the tokenizer raises KeyError on out-of-vocab words
        # — verify against word_data.py.
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
    boundary = re.match(r"(?P<start>\D*)\d+\n", completion)
    if boundary is None:
        return completion
    return boundary.group('start')
#%%
# Example prompts shown in the Gradio UI; each is a one-element row
# because the interface's first input is the text box.
examples = [
    [prompt]
    for prompt in (
        "I sang a beautiful song",
        "To be free is to",
        "How I love thee",
    )
]
#%%
if MAIN:
    # Quick end-to-end smoke test of the sampling pipeline when run
    # as a script.
    sample_prompt = 'How I love thee'
    print(safe_generate(sample_prompt))
#%%
# User-facing blurb rendered above the Gradio interface. Runtime UI text —
# do not edit casually; it is displayed verbatim.
description = """
Provide a prompt in the "Input Text" window below and then click "Submit".
The small Shakespeare transformer model trained on my laptop will attempt to
complete the Sonnet that you started.

Thanks to Project Gutenberg for providing the training corpus.
"""
#%%
def make_demo():
    """Build and launch the Gradio interface for the Shakespeare sampler.

    Inputs mirror safe_generate's signature: prompt text plus the three
    sampling knobs (max tokens, temperature, top_k).
    """
    prompt_box = gr.components.Textbox(lines=5, label="Input Text")
    max_tokens_slider = gr.components.Slider(
        label='max tokens generated', minimum=1, maximum=1000,
        value=300, step=1,
    )
    temperature_slider = gr.components.Slider(
        label='temperature', minimum=0, maximum=2, value=1, step=0.1,
    )
    top_k_slider = gr.components.Slider(
        label='top_k', minimum=1, maximum=100, value=10, step=1,
    )
    output_box = gr.components.Textbox(label="Generated Text")
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[prompt_box, max_tokens_slider, temperature_slider, top_k_slider],
        outputs=output_box,
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    demo.launch()
#%%