File size: 2,792 Bytes
4c2c4e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85b344b
4c2c4e8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#%%
import yaml
import torch as t
import gradio as gr
import re
from word_data import WordData
import sampling
import transformer_replication
#%%
MAIN = __name__ == '__main__'
# Prefer GPU when available; everything downstream (data + model) follows this.
if t.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#%%
# Load the Shakespeare corpus, trimmed to the span between the first sonnet
# marker and the start of the plays (note the Unicode apostrophe in the title).
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))
#%%
#%%
# Load the training hyperparameters exported alongside the checkpoint.
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)
#%%
# BUG FIX: the previous code opened model_state_dict.pt in *text* mode,
# ignored the handle, and passed the path to t.load separately. Load the
# checkpoint directly, and map tensors onto the active device so weights
# saved from a CUDA run still load on CPU-only machines.
state_dict = t.load('model_state_dict.pt', map_location=device)
#%%
def _hp(key):
    # W&B-style exported config: each entry nests its value under 'value'.
    return yaml_cfg[key]['value']

# Rebuild the architecture exactly as trained so the state dict fits.
base_config = transformer_replication.TransformerConfig(
    num_layers=_hp('num_layers'),
    num_heads=_hp('num_heads'),
    vocab_size=len(shakespeare.vocab),
    hidden_size=_hp('hidden_size'),
    max_seq_len=_hp('max_seq_len'),
    dropout=_hp('dropout'),
)
# Cap tokenisation length to what the positional embeddings support.
shakespeare.model_max_length = _hp('max_seq_len')
model = transformer_replication.DecoderOnlyTransformer(base_config)

model.load_state_dict(state_dict)

#%%
def generate(
        text: str, max_tokens: int, temperature: float, 
        top_k: int,
) -> str:
    """Sample a continuation of `text` from the trained Shakespeare model."""
    sampler_kwargs = dict(
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return sampling.sample_tokens(model, shakespeare, text, **sampler_kwargs)

#%%
def safe_generate(
        text: str, max_tokens: int = 300, temperature: float = 1.0, 
        top_k: int = 20,
    ) -> str:
    """Generate text, trimming at sonnet numbers and surviving OOV words.

    Unknown input words raise KeyError inside the tokenizer; that is turned
    into a friendly message instead of crashing the UI.
    """
    try:
        raw = generate(
            text,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
    # The corpus numbers each sonnet; cut the sample at the first bare
    # "<digits>\n" line so output stops before a new sonnet header.
    match = re.match(r"(?P<start>\D*)\d+\n", raw)
    return raw if match is None else match.group('start')
#%%
# Example prompts shown in the Gradio interface (one-element rows: the
# interface's first input is the text box).
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]
#%%
if MAIN:
    # Quick smoke test of the full sampling pipeline when run as a script.
    print(safe_generate('How I love thee'))
#%%
def make_demo():
    """Build and launch the Gradio text-generation demo.

    Returns the Interface so callers and tests can inspect or close it
    (previously the built object was discarded).
    """
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            gr.components.Textbox(lines=5, label="Input Text"),
            gr.components.Slider(
                label='max tokens generated', minimum=1, maximum=1000, 
                value=300, step=1,
            ),
            gr.components.Slider(
                label='temperature', minimum=0, maximum=2, value=1, step=0.1,
            ),
            gr.components.Slider(
                # Default aligned with safe_generate's top_k=20 (was 10,
                # which silently disagreed with the function signature).
                label='top_k', minimum=1, maximum=100, value=20, step=1,
            ),
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
    )
    demo.launch()
    return demo
# %%
'''
FIXME:
* deploy to hugging face
* link from github home
'''