# app.py — commit c3fc384 "Update app.py" (verified), by tranquilkd
import os
import torch
import traceback
import tiktoken
import gradio as gr
import torch.nn.functional as F
from transformer import GPTConfig, GPT
def load_model(model_path):
    """Create a GPT with the default ``GPTConfig`` and restore its weights.

    The checkpoint at ``model_path`` is deserialized onto the CPU, its
    ``"model_state_dict"`` entry is loaded into the network, and the network
    is then moved to the module-level ``device`` and put in eval mode.

    Args:
        model_path: Path of the checkpoint file to read.

    Returns:
        The restored model, on ``device`` and in eval mode.
    """
    net = GPT(GPTConfig())

    # NOTE: torch.load unpickles the file, so only trusted checkpoints
    # should be passed here.
    checkpoint = torch.load(os.path.join(model_path), map_location="cpu")
    net.load_state_dict(checkpoint["model_state_dict"])

    net.to(device)
    net.eval()
    return net
def generate_text(text, max_length=64, num_return_sequences=2):
    """Sample continuations of ``text`` from the module-level ``model``.

    The prompt is encoded with tiktoken's GPT-2 encoding, replicated once per
    requested sequence, and extended one token at a time using top-k (k=50)
    sampling over the model's next-token distribution.

    Args:
        text: Prompt to continue.
        max_length: Maximum number of tokens in each returned sequence,
            counting the prompt. A prompt that is already longer is simply
            truncated to this many tokens.
        num_return_sequences: Number of independently sampled sequences.

    Returns:
        A single string holding every sequence, each one introduced by a
        ">>>" marker line and followed by a blank line.

    Raises:
        gr.Error: If the prompt encodes to zero tokens.
    """
    # UI sliders may deliver floats; range-style logic, Tensor.repeat and
    # slicing all require real ints.
    max_length = int(max_length)
    num_return_sequences = int(num_return_sequences)
    if num_return_sequences < 1:
        return ""

    tokenizer = tiktoken.get_encoding('gpt2')
    # disallowed_special=() makes literals such as "<|endoftext|>" in user
    # input encode as ordinary text instead of raising ValueError.
    prompt_tokens = tokenizer.encode(text or "", disallowed_special=())
    if not prompt_tokens:
        # With no tokens there is no "last position" to read logits from.
        raise gr.Error("Please enter some input text.")

    x = torch.tensor(prompt_tokens, dtype=torch.long)
    x = x.unsqueeze(0)
    x = x.repeat(num_return_sequences, 1)  # (B, T)
    x = x.to(device)

    with torch.no_grad():
        # Only the first max_length tokens are returned, so stop as soon as
        # the sequence is that long rather than sampling tokens that would
        # be thrown away.
        while x.size(1) < max_length:
            logits, _ = model(x)
            # take the logits at the last position
            logits = logits[:, -1, :]  # (B, vocab_size)
            # get the probabilities
            probs = F.softmax(logits, dim=-1)
            # top-k sampling with k=50, clamped so a small vocabulary
            # cannot make topk fail; both results are (B, k)
            k = min(50, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # select a token from the top-k probabilities
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
            # map the sampled slot back to its vocabulary index
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            # append to the sequence
            x = torch.cat((x, xcol), dim=1)

    rows = x[:, :max_length].tolist()
    return "".join(
        "\n>>>\n" + tokenizer.decode(tokens) + "\n\n" for tokens in rows
    )
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the checkpoint once at import time; generate_text reads this global.
model = load_model("GPT-Shakespeare.pth")
# Gradio UI definition. The three input components are passed to
# generate_text in order: (text, max_length, num_return_sequences).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Input Text",
            placeholder="Enter the text in style of Coriolanus",
            lines=5,
        ),
        gr.Slider(
            label="Max Sequence Length",
            minimum=1,
            maximum=128,
            step=1,
            value=32,
        ),
        gr.Slider(
            label="Number of Sequences to Return",
            minimum=1,
            maximum=5,
            step=1,
            value=2,
        ),
    ],
    outputs=gr.Textbox(
        label="Generated Text",
        placeholder="Generated text will be shown here",
        lines=5,
    ),
    # Each example row follows the same order as `inputs`.
    examples=[
        ["My noble Coriolanus, temper thy rage. These men hold", 32, 3],
        ["Wisdom, say’st thou? Counsel, and truth? Nay, Menenius, they are", 64, 2],
        ["What speaks this man of war and violence?", 50, 1],
        ["Enough, Coriolanus! Thy words grow wild. What wouldst thou have?", 32, 5],
    ],
    title="Text Generation using GPT",
    description=(
        "<p style='text-align: center'> "
        'Decoder only transformer trained on "Coriolanus" by William Shakespeare'
        " </p>"
    ),
    # Footer: GitHub link rendered with a Font Awesome icon, plus the
    # stylesheet that provides the icon.
    article=(
        "<a href='https://github.com/KD1994/session-12-Transformer-from-scratch-pt2' target='_blank'> "
        "<i class='fab fa-github' style='font-size: 24px;'></i></a> "
        '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">'
    ),
)
# Start the web UI. A startup failure is reported on stdout with its full
# traceback instead of propagating as an unhandled exception.
try:
    demo.launch()
except Exception as launch_error:
    details = traceback.format_exc()
    print("Error launching interface: " + str(launch_error))
    print(details)