#%%
import re

import gradio as gr
import torch as t
import yaml

import sampling
import transformer_replication
from word_data import WordData

#%%
MAIN = __name__ == '__main__'
device = 'cuda' if t.cuda.is_available() else 'cpu'

#%%
# Word-level dataset/tokenizer over the slice of '100-0.txt' that lies
# between the `start` and `end` markers.
shakespeare = WordData.from_file(
    '100-0.txt',
    device=device,
    start="1\n",
    end='ALL’S WELL THAT ENDS WELL',
)
if MAIN:
    print('Vocab size: ', len(shakespeare.vocab))

#%%

#%%
# Hyperparameters of the trained model. Every entry is read through a
# nested 'value' key (NOTE(review): looks like a wandb-exported config —
# confirm).
with open('config.yaml', 'r', encoding='utf-8') as f:
    yaml_cfg = yaml.safe_load(f)

#%%
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only host (e.g. a CPU Hugging Face Space).
state_dict = t.load('model_state_dict.pt', map_location=device)

#%%
base_config = transformer_replication.TransformerConfig(
    num_layers=yaml_cfg['num_layers']['value'],
    num_heads=yaml_cfg['num_heads']['value'],
    vocab_size=len(shakespeare.vocab),
    hidden_size=yaml_cfg['hidden_size']['value'],
    max_seq_len=yaml_cfg['max_seq_len']['value'],
    dropout=yaml_cfg['dropout']['value'],
)
# Presumably used by the sampler to cap the context length fed to the
# model — TODO confirm against sampling.sample_tokens.
shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']

model = transformer_replication.DecoderOnlyTransformer(base_config)
model.load_state_dict(state_dict)
# Run on the same device the dataset was built for, and disable dropout
# for inference (harmless if the sampler also calls eval()).
model.to(device)
model.eval()


#%%
def generate(
    text: str,
    max_tokens: int,
    temperature: float,
    top_k: int,
) -> str:
    """Continue `text` with the loaded model via `sampling.sample_tokens`.

    Args:
        text: Prompt to continue.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to the sampler.
        top_k: Top-k cutoff forwarded to the sampler.

    Returns:
        Whatever `sampling.sample_tokens` returns for this prompt.
    """
    return sampling.sample_tokens(
        model,
        shakespeare,
        text,
        max_tokens_generated=max_tokens,
        temperature=temperature,
        top_k=top_k,
    )


#%%
# Captures everything before the first run of digits that is immediately
# followed by a newline (presumably the number heading the next
# sonnet/section, since the corpus itself starts at "1\n").
_SECTION_NUMBER_RE = re.compile(r"(?P<start>\D*)\d+\n")


def safe_generate(
    text: str,
    max_tokens: int = 300,
    temperature: float = 1.0,
    top_k: int = 20,
) -> str:
    """Generate text for the UI, trimming trailing sections and unknown words.

    Args:
        text: Prompt to continue.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_k: Top-k cutoff for sampling.

    Returns:
        The generated text. If the first run of digits in it is directly
        followed by a newline, only the text before that number is
        returned. If generation raises ``KeyError`` (presumably a prompt
        word missing from the vocabulary — confirm in WordData), an
        apology message naming the missing key is returned instead.
    """
    try:
        raw = generate(
            text,
            # UI sliders may deliver numeric values as floats; the sampler
            # needs integer counts.
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_k=int(top_k),
        )
    except KeyError as e:
        return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"

    match = _SECTION_NUMBER_RE.match(raw)
    if match is None:
        return raw
    return match.group('start')


#%%
examples = [
    ["I sang a beautiful song"],
    ["To be free is to"],
    ["How I love thee"],
]

#%%
if MAIN:
    print(safe_generate('How I love thee'))


#%%
def make_demo():
    """Build a Gradio Interface around `safe_generate` and launch it.

    The four input components map positionally onto `safe_generate`'s
    parameters: text, max_tokens, temperature, top_k.
    """
    demo = gr.Interface(
        fn=safe_generate,
        inputs=[
            gr.components.Textbox(lines=5, label="Input Text"),
            gr.components.Slider(
                label='max tokens generated',
                minimum=1,
                maximum=1000,
                value=300,
                step=1,
            ),
            gr.components.Slider(
                label='temperature',
                minimum=0,
                maximum=2,
                value=1,
                step=0.1,
            ),
            # NOTE(review): UI default top_k=10 differs from
            # safe_generate's default of 20 — confirm which is intended.
            gr.components.Slider(
                label='top_k',
                minimum=1,
                maximum=100,
                value=10,
                step=1,
            ),
        ],
        outputs=gr.components.Textbox(label="Generated Text"),
        examples=examples,
    )
    demo.launch()

# %%
# FIXME:
#   * deploy to hugging face
#   * link from github home