import traceback

import torch
import torch.nn.functional as F
import tiktoken
import gradio as gr

from transformer import GPTConfig, GPT

def load_model(model_path):
    """Instantiate the GPT model and load weights from `model_path`."""
    config = GPTConfig()
    model = GPT(config)
    ckpt = torch.load(model_path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)  # `device` is defined at module level below
    model.eval()
    return model
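
# Note: this assumes the checkpoint was saved as a dict with a
# "model_state_dict" key, e.g.
#   torch.save({"model_state_dict": model.state_dict()}, model_path)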

def generate_text(text, max_length=64, num_return_sequences=2):
    """Autoregressively sample continuations of `text` with top-k sampling."""
    tokenizer = tiktoken.get_encoding("gpt2")
    x = tokenizer.encode(text)
    x = torch.tensor(x, dtype=torch.long)
    x = x.unsqueeze(0)                     # (1, T)
    x = x.repeat(num_return_sequences, 1)  # (B, T)
    x = x.to(device)
    # generate until each sequence reaches max_length tokens (prompt included)
    while x.size(1) < max_length:
        with torch.no_grad():
            logits, _ = model(x)
            # take the logits at the last position
            logits = logits[:, -1, :]  # (B, vocab_size)
            # get the probabilities
            probs = F.softmax(logits, dim=-1)
            # do top-k sampling of 50 (Hugging Face pipeline default);
            # topk_probs and topk_indices are both (B, 50)
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            # select a token from the top-k probabilities
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1)  # (B, 1)
            # gather the corresponding token ids
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            # append to the sequence
            x = torch.cat((x, xcol), dim=1)

    generated_text = []
    for i in range(num_return_sequences):
        tokens = x[i, :max_length].tolist()
        decoded = tokenizer.decode(tokens)
        generated_text.append("\n>>>\n" + decoded + "\n\n")
    return "".join(generated_text)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("GPT-Shakespeare.pth")
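
# Example usage (prompt taken from the examples below):
#   print(generate_text("My noble Coriolanus, temper thy rage. These men hold",
#                       max_length=32, num_return_sequences=2))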

# Define the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    title="Text Generation using GPT",
    description="<p style='text-align: center'>Decoder-only transformer trained on \"Coriolanus\" by William Shakespeare</p>",
    inputs=[
        gr.Textbox(
            label="Input Text",
            placeholder="Enter text in the style of Coriolanus",
            lines=5,
        ),
        gr.Slider(
            minimum=1,
            maximum=128,
            step=1,
            value=32,
            label="Max Sequence Length",
        ),
        gr.Slider(
            minimum=1,
            maximum=5,
            step=1,
            value=2,
            label="Number of Sequences to Return",
        ),
    ],
    outputs=gr.Textbox(
        lines=5,
        placeholder="Generated text will be shown here",
        label="Generated Text",
    ),
    article="""<a href='https://github.com/KD1994/session-12-Transformer-from-scratch-pt2' target='_blank'> \
        <i class='fab fa-github' style='font-size: 24px;'></i></a> \
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">""",
    examples=[
        ["My noble Coriolanus, temper thy rage. These men hold", 32, 3],
        ["Wisdom, say’st thou? Counsel, and truth? Nay, Menenius, they are", 64, 2],
        ["What speaks this man of war and violence?", 50, 1],
        ["Enough, Coriolanus! Thy words grow wild. What wouldst thou have?", 32, 5],
    ],
)
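
# Tip: when running locally (outside Hugging Face Spaces),
# demo.launch(share=True) can additionally create a temporary public URL.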

# Add error handling to launch
try:
    demo.launch()
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    print(traceback.format_exc())