#%%
"""Gradio demo that samples Sonnet completions from a small Shakespeare transformer.

Loads a word-level vocabulary from the Project Gutenberg Shakespeare corpus,
restores a trained DecoderOnlyTransformer from a checkpoint, and serves an
interactive sampling UI via Gradio.
"""
import re

import gradio as gr
import torch as t
import yaml

import sampling
import transformer_replication
from word_data import WordData

#%%
MAIN = __name__ == '__main__'
device = 'cuda' if t.cuda.is_available() else 'cpu'

#%%
# Slice the corpus between the first Sonnet ("1\n") and the start of the
# plays, so the model only ever sees Sonnet text.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))

#%%
with open('config.yaml', 'r') as f:
    yaml_cfg = yaml.safe_load(f)

#%%
# BUG FIX: the original opened the checkpoint in *text* mode and then ignored
# the handle, passing the path to t.load a second time. Open in binary mode
# and load directly from the file object.
with open('model_state_dict.pt', 'rb') as f:
    state_dict = t.load(f, map_location=device)

#%%
# Hyperparameters come from the (wandb-style) yaml config, where each entry
# is a {'value': ...} mapping; vocab size is tied to the loaded corpus.
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']
model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
# BUG FIX: switch to inference mode — the original left dropout active
# while sampling, degrading generation quality.
model.eval()

#%%
def generate(
    text: str,
    max_tokens: int,
    temperature: float,
    top_k: int,
) -> str:
    """Sample a continuation of ``text`` from the trained model.

    Args:
        text: The prompt to continue.
        max_tokens: Maximum number of tokens to generate.
        temperature: Softmax temperature for sampling.
        top_k: Restrict sampling to the k most likely tokens.

    Returns:
        The generated text (behaviour delegated to ``sampling.sample_tokens``).
    """
    return sampling.sample_tokens(
        model,
        shakespeare,
        text,
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )

#%%
def safe_generate(
    text: str,
    max_tokens: int = 300,
    temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Generate a completion, returning a friendly message on unknown words.

    The raw sample is truncated just before the first Sonnet number the
    model emits (a run of digits followed by a newline), so the output reads
    as a single poem.

    Returns:
        The (possibly truncated) sample, or an apology naming the word that
        is missing from the vocabulary.
    """
    try:
        raw = generate(
            text,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        # BUG FIX: the original pattern r"(?P\D*)\d+\n" is invalid regex —
        # the named group was missing its name, so re.match raised re.error
        # on every call. The intended group is 'start' (referenced below).
        match = re.match(r"(?P<start>\D*)\d+\n", raw)
        if match is None:
            return raw
        return match.group('start')
    except KeyError as e:
        # WordData raises KeyError for words outside the training vocabulary.
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"

#%%
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]

#%%
if MAIN:
    print(safe_generate('How I love thee'))

#%%
description = """ Provide a prompt in the "Input Text" window below and then click "Submit". 
The small Shakespeare transformer model trained on my laptop will attempt to complete the Sonnet that you started. Thanks to Project Gutenberg for providing the training corpus. """

#%%
def make_demo():
    """Build and launch the Gradio sampling interface (blocks until closed)."""
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            gr.components.Textbox(lines=5, label="Input Text"),
            gr.components.Slider(
                label='max tokens generated',
                minimum=1,
                maximum=1000,
                value=300,
                step=1,
            ),
            gr.components.Slider(
                label='temperature',
                minimum=0,
                maximum=2,
                value=1,
                step=0.1,
            ),
            gr.components.Slider(
                label='top_k',
                minimum=1,
                maximum=100,
                value=10,
                step=1,
            ),
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
        title='Shakespeare transformer sampling',
        description=description,
    )
    demo.launch()
#%%