"""Gradio demo that samples text from a small GPT checkpoint.

The script loads ``GPT-Shakespeare.pth`` once at start-up, exposes
``generate_text`` through a ``gr.Interface`` and launches the web UI.
"""

import os
import traceback

import torch
import torch.nn.functional as F
import tiktoken
import gradio as gr

from transformer import GPTConfig, GPT

# Device used for inference; read by both load_model() and generate_text().
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(model_path):
    """Build a GPT with the default ``GPTConfig`` and load weights into it.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load``.
            The checkpoint must contain a ``"model_state_dict"`` entry.

    Returns:
        The model, moved to the module-level ``device`` and put in eval mode.
    """
    config = GPTConfig()
    model = GPT(config)
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. Tensors are mapped to CPU first and moved afterwards.
    ckpt = torch.load(model_path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model


def generate_text(text, max_length=64, num_return_sequences=2, top_k=50):
    """Sample continuations of ``text`` from the module-level ``model``.

    Args:
        text: Prompt to continue.
        max_length: Maximum length, in tokens, of each returned sequence.
            The prompt counts towards this limit; a prompt that is already
            longer is truncated to ``max_length`` tokens and nothing is
            sampled.
        num_return_sequences: Number of independent samples to draw.
        top_k: Number of highest-probability tokens to sample from at each
            step (clamped to the vocabulary size).

    Returns:
        All samples concatenated into one string, each one prefixed with
        ``"\\n>>>\\n"`` and followed by a blank line.

    Raises:
        gr.Error: If the prompt encodes to zero tokens.
    """
    # Slider values can arrive as floats; range()/repeat() need ints.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)

    tokenizer = tiktoken.get_encoding("gpt2")
    prompt_tokens = tokenizer.encode(text or "")
    if not prompt_tokens:
        # With no tokens there is no "last position" to read logits from.
        raise gr.Error("Please enter some input text.")

    x = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
    x = x.unsqueeze(0).repeat(num_return_sequences, 1)  # (B, T)

    with torch.no_grad():
        # Only the first max_length tokens are returned, so stop as soon as
        # the sequence is that long instead of sampling tokens that would be
        # thrown away (and needlessly growing the model's context).
        while x.size(1) < max_length:
            logits, _ = model(x)
            # keep only the logits at the last position
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # top-k sampling; k=50 is the huggingface pipeline default
            k = max(1, min(int(top_k), probs.size(-1)))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # (B, k)
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
            # map the sampled top-k slot back to a vocabulary index
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            x = torch.cat((x, xcol), dim=1)

    generated_text = [
        "\n>>>\n" + tokenizer.decode(row[:max_length].tolist()) + "\n\n"
        for row in x
    ]
    return "".join(generated_text)


model = load_model(R"GPT-Shakespeare.pth")

# Define the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    title="Text Generation using GPT",
    description="Decoder only transformer trained on \"Coriolanus\" by William Shakespeare",
    inputs=[
        gr.Textbox(
            label="Input Text",
            placeholder="Enter the text in style of Coriolanus",
            lines=5,
        ),
        gr.Slider(minimum=1, maximum=128, step=1, value=32, label="Max Sequence Length"),
        gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Number of Sequences to Return"),
    ],
    outputs=gr.Textbox(
        lines=5,
        placeholder="Generated text will be shown here",
        label="Generated Text",
    ),
    # NOTE(review): the article body looks like it lost its content; the
    # literal is kept as found (raw, so the backslashes are not treated as
    # escape sequences) -- confirm what was meant to be shown here.
    article=r""" \ \ """,
    examples=[
        ["My noble Coriolanus, temper thy rage. These men hold", 32, 3],
        ["Wisdom, say’st thou? Counsel, and truth? Nay, Menenius, they are", 64, 2],
        ["What speaks this man of war and violence?", 50, 1],
        ["Enough, Coriolanus! Thy words grow wild. What wouldst thou have?", 32, 5],
    ],
)

# Top-level boundary: report any launch failure with its traceback.
try:
    demo.launch()
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    print(traceback.format_exc())