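"""Gradio demo: generate text in the style of "Coriolanus" by William
Shakespeare with a decoder-only GPT (weights loaded from GPT-Shakespeare.pth)."""
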
import torch
import traceback
import tiktoken
import gradio as gr
import torch.nn.functional as F
from transformer import GPTConfig, GPT


def load_model(model_path):
    """Build the GPT model and load trained weights from a checkpoint."""
    config = GPTConfig()
    model = GPT(config)

    # the checkpoint is expected to hold the weights under "model_state_dict"
    ckpt = torch.load(model_path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)  # module-level `device`, defined below before the first call
    model.eval()

    return model


def generate_text(text, max_length=64, num_return_sequences=2):
    """Sample `max_length` new tokens for each of `num_return_sequences`
    copies of the prompt, using top-k (k=50) sampling."""
    tokenizer = tiktoken.get_encoding("gpt2")
    x = tokenizer.encode(text)
    x = torch.tensor(x, dtype=torch.long)
    x = x.unsqueeze(0)                     # (1, T)
    x = x.repeat(num_return_sequences, 1)  # (B, T)
    x = x.to(device)

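    # NOTE (assumption): the whole growing sequence is fed back through the
    # model at every step, so prompt length + max_length must fit within the
    # model's context window (the GPTConfig block size). For longer runs one
    # would condition on a cropped window instead, e.g.:
    #   x_cond = x[:, -block_size:]  # hypothetical; block_size is not in scope here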
    for _ in range(max_length):

        with torch.no_grad():
            logits, _ = model(x)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs and topk_indices both become (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)
    
    generated_text = []
    for i in range(num_return_sequences):
        # decode the full row: prompt tokens plus all newly generated tokens
        # (slicing to :max_length here would drop the generated tokens whenever
        # the prompt is at least max_length tokens long)
        tokens = x[i].tolist()
        decoded = tokenizer.decode(tokens)
        generated_text.append("\n>>>\n" + decoded + "\n\n")

    return "".join(generated_text)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("GPT-Shakespeare.pth")
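
# Quick sanity check (hypothetical usage; uncomment to try from a shell):
#   print(generate_text("My noble Coriolanus, temper thy rage.",
#                       max_length=32, num_return_sequences=2))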

# Define the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    title="Text Generation using GPT",
    description="<p style='text-align: center'> Decoder-only transformer trained on \"Coriolanus\" by William Shakespeare </p>",
    inputs=[
        gr.Textbox(
            label="Input Text",
            placeholder="Enter text in the style of Coriolanus",
            lines=5
        ),
        gr.Slider(minimum=1, 
                  maximum=128, 
                  step=1, 
                  value=32, 
                  label="Max Sequence Length"),
        gr.Slider(minimum=1, 
                  maximum=5, 
                  step=1, 
                  value=2, 
                  label="Number of Sequences to Return")
    ],
    outputs=gr.Textbox(lines=5, 
                       placeholder="Generated text will be shown here", 
                       label="Generated Text"),
    article="""<a href='https://github.com/KD1994/session-12-Transformer-from-scratch-pt2' target='_blank'> \
            <i class='fab fa-github' style='font-size: 24px;'></i></a> \
            <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">""",
    examples=[
        ["My noble Coriolanus, temper thy rage. These men hold", 32, 3],
        ["Wisdom, say’st thou? Counsel, and truth? Nay, Menenius, they are", 64, 2],
        ["What speaks this man of war and violence?", 50, 1],
        ["Enough, Coriolanus! Thy words grow wild. What wouldst thou have?", 32, 5]
    ]
)

# Launch the interface (guarded so importing this module doesn't auto-launch),
# with basic error handling
if __name__ == "__main__":
    try:
        demo.launch()
    except Exception as e:
        print(f"Error launching interface: {e}")
        print(traceback.format_exc())
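
# To run locally (assuming `torch`, `tiktoken`, and `gradio` are installed and
# GPT-Shakespeare.pth sits next to this script):
#   python app.py   # "app.py" is an assumed filename for this script
# then open the local URL that Gradio prints.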