first commit
Files changed:
- .gitattributes +1 -0
- .gitignore +4 -0
- README.md +74 -1
- app.py +63 -0
- input.txt +0 -0
- logs.txt +7 -0
- model.pt +3 -0
- requirements.txt +3 -0
- resources/encoder-2.png +0 -0
- resources/interface.png +0 -0
- train.py +246 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
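For reference, the attribute above is exactly what `git lfs track "*.pt"` writes, so the ~650 MB model.pt checkpoint below is stored as a Git LFS pointer rather than directly in the repository.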
.gitignore
ADDED
@@ -0,0 +1,4 @@
+virtual/
+__pycache__/
+
+transformer.py
README.md
CHANGED
@@ -1 +1,74 @@
-
+---
+title: GPT Text Generator
+emoji: 🤖
+colorFrom: blue
+colorTo: red
+sdk: gradio
+sdk_version: 3.50.2
+app_file: app.py
+pinned: false
+---
+
+# GPT Text Generator
+
+A simple text generation interface for a custom-trained GPT model.
+
+## Project Overview
+
+This project implements a GPT (Generative Pre-trained Transformer) model from scratch using PyTorch. The model is trained on text data and can generate human-like text based on given prompts.
+
+### Model Architecture
+
+<div align="center">
+<img src="resources/encoder-2.png" alt="Transformer Architecture" width="600"/>
+<p><i>Transformer Architecture (Note: Our GPT implementation uses only the Decoder part, right side of the image)<br/>Source: Attention Is All You Need paper</i></p>
+</div>
+
+Our implementation includes:
+- Transformer-based architecture with 12 layers (decoder-only)
+- 12 attention heads
+- 768-dimensional embeddings
+- Context window of 128 tokens
+- ~163M parameters
+- Uses GPT-2 tokenizer (50,257 vocabulary size)
+
+### Features
+- Custom implementation of Multi-Head Attention
+- Position embeddings
+- Layer normalization
+- Residual connections
+- Dropout for regularization
+
+### Training
+- Trained using AdamW optimizer
+- Learning rate: 3e-4
+- Batch size: 4
+- Maximum sequence length: 128 tokens
+- Training continues until loss < 0.1 or max epochs reached
+- Generates sample text every 10 epochs to monitor progress
+
+### Interface
+The model is served through a Gradio interface that allows users to:
+- Input custom prompts
+- Adjust generation parameters:
+  - Maximum length of generated text
+  - Temperature (controls randomness)
+  - Top-k sampling parameter
+
+<div align="center">
+<img src="resources/interface.png" alt="Interface Demo" width="800"/>
+<p><i>Text Generation Interface</i></p>
+</div>
+
+## Usage
+Enter your prompt in the text box and adjust the generation parameters to control the output. Higher temperature values (>1.0) make the output more random, while lower values (<1.0) make it more focused and deterministic.
+
+### Example Generations
+
+- "I tell because besides because tell because tell cob Lic tell because because tell why tell tell tell tell because tell tell because tell tell because tell tell tell Tro why because"
+- "First Citizen:
+Hello
+Face titles::
+jer gentleman:::ello:: Peter:::
+P heed
+stuff: upon:: Had furniture imp"
app.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+import torch.nn.functional as F
+import gradio as gr
+import tiktoken
+from train import GPT, GPTConfig # Import the model architecture from train.py
+
+# Initialize model and tokenizer
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+config = GPTConfig()
+model = GPT(config).to(device)
+model.load_state_dict(torch.load('model.pt', map_location=device))
+model.eval()
+
+# Initialize tokenizer
+enc = tiktoken.get_encoding('gpt2')
+
+def generate_text(prompt, max_length=30, temperature=0.8, top_k=50):
+    # Encode the prompt
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
+
+    # Generate text
+    with torch.no_grad():
+        for _ in range(max_length):
+            # Get logits from the model
+            logits = model(input_ids)
+            logits = logits[:, -1, :] / temperature
+
+            # Apply top-k sampling
+            top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
+            probs = F.softmax(top_k_logits, dim=-1)
+
+            # Sample from the distribution
+            ix = torch.multinomial(probs, num_samples=1)
+            next_token = torch.gather(top_k_indices, -1, ix)
+
+            # Append to the sequence
+            input_ids = torch.cat((input_ids, next_token), dim=1)
+
+            # Stop if we generate an end of text token
+            if next_token.item() == enc.eot_token:
+                break
+
+    # Decode the generated text
+    generated_text = enc.decode(input_ids[0].tolist())
+    return generated_text
+
+# Create the Gradio interface
+iface = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.Textbox(label="Enter your prompt", lines=2),
+        gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Max Length"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
+        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
+    ],
+    outputs=gr.Textbox(label="Generated Text"),
+    title="Text Generation with GPT",
+    description="Enter a prompt and adjust the parameters to generate text using the trained GPT model."
+)
+
+# Launch the app
+if __name__ == "__main__":
+    iface.launch()
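For local testing outside the hosted Space, `generate_text` can also be called directly. The sketch below is hypothetical and assumes the packages from requirements.txt are installed and that model.pt sits next to app.py, since importing `app` loads the checkpoint immediately; running `python app.py` instead launches the Gradio UI.

```python
# Hypothetical smoke test for app.py (not part of the commit).
# Importing app builds the GPT, loads model.pt, and creates the GPT-2 tokenizer.
from app import generate_text

sample = generate_text("First Citizen:", max_length=40, temperature=0.8, top_k=50)
print(sample)
```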
input.txt
ADDED
The diff for this file is too large to render. See raw diff.
logs.txt
ADDED
@@ -0,0 +1,7 @@
+using device: cuda
+Model parameters: 163087441
+loaded 338025 tokens
+1 epoch = 660 batches
+Stopping training...
+Checkpoint saved...
+Training completed...
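The `Model parameters: 163087441` line follows directly from the GPTConfig defaults in train.py. A back-of-the-envelope check (a standalone sketch, not part of the commit) reproduces the exact count without building the model; note that `lm_head` is a separate Linear layer, not weight-tied to the token embedding, which is why the total is higher than GPT-2's ~124M.

```python
# Parameter count implied by GPTConfig: block_size=1024, vocab_size=50257,
# n_layer=12, n_embd=768; every Linear has a bias, every LayerNorm has weight+bias.
n_embd, n_layer, vocab, block = 768, 12, 50257, 1024

wte = vocab * n_embd                                                        # token embeddings
wpe = block * n_embd                                                        # position embeddings
ln = 2 * n_embd                                                             # one LayerNorm
attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)      # c_attn + c_proj
mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)   # c_fc + c_proj
per_block = attn + mlp + 2 * ln                                             # ln_1 + ln_2 included

total = wte + wpe + n_layer * per_block + ln + (n_embd * vocab + vocab)     # final ln_f + lm_head
print(total)  # 163087441
```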
model.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2762cc738c94e5d458932bce8f314dec8e9cfaa955597ce085cda2b9db08e2bc
+size 652393042
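That size is consistent with the parameter count above: 163,087,441 float32 weights occupy 163,087,441 × 4 ≈ 652.3 MB, and the remaining ~43 KB of the 652,393,042-byte file is presumably state_dict serialization overhead.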
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+torch
+gradio
+tiktoken
resources/encoder-2.png
ADDED
resources/interface.png
ADDED
train.py
ADDED
@@ -0,0 +1,246 @@
+from dataclasses import dataclass
+import tiktoken
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from tqdm import tqdm
+
+
+@dataclass
+class GPTConfig:
+    block_size: int = 1024 # max sequence length
+    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
+    n_layer: int = 12 # number of layers
+    n_head: int = 12 # number of heads
+    n_embd: int = 768 # embedding dimension
+    dropout: float = 0.1
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config: GPTConfig):
+        super().__init__()
+        self.config = config
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+
+        self.c_attn = nn.Linear(self.n_embd, 3*self.n_embd)
+        self.c_proj = nn.Linear(self.n_embd, self.n_embd)
+
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.res_dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        B, T, C = x.size() # [B, T, n_embd]
+
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # [B, T, n_embd] each
+
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd//n_head]
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd//n_head]
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd//n_head]
+
+        attn = (q @ k.transpose(-2, -1)) * 1.0 / (k.size(-1)**0.5) # [B, n_head, T, T]
+        attn = F.softmax(attn, dim=-1) # [B, n_head, T, T]
+        attn = self.attn_dropout(attn) # [B, n_head, T, T]
+
+        y = attn @ v # [B, n_head, T, n_embd//n_head]
+
+        y = y.transpose(1,2).contiguous().view(B, T, C) # [B, T, n_embd]
+        y = self.c_proj(y) # [B, T, n_embd]
+        y = self.res_dropout(y) # [B, T, n_embd]
+
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: GPTConfig):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4*config.n_embd)
+        self.c_proj = nn.Linear(4*config.n_embd, config.n_embd)
+        self.drop = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        x = self.c_fc(x) # [B, T, 4*n_embd]
+        x = F.gelu(x) # [B, T, 4*n_embd]
+        x = self.c_proj(x) # [B, T, n_embd]
+        x = self.drop(x) # [B, T, n_embd]
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, config: GPTConfig):
+        super().__init__()
+        self.ln_1 = nn.LayerNorm(config.n_embd)
+        self.attn = MultiHeadAttention(config)
+        self.ln_2 = nn.LayerNorm(config.n_embd)
+        self.mlp = FeedForward(config)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(self, config: GPTConfig):
+        super().__init__()
+        self.config = config
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.wpe = nn.Embedding(config.block_size, config.n_embd)
+        self.dropout = nn.Dropout(config.dropout)
+        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
+        self.ln_f = nn.LayerNorm(config.n_embd)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(self, idx):
+        B, T = idx.size()
+
+        pos = torch.arange(T, dtype=torch.long, device=idx.device).unsqueeze(0) # [1, T]
+
+        wte = self.wte(idx) # [B, T, n_embd]
+        wpe = self.wpe(pos) # [1, T, n_embd]
+
+        x = self.dropout(wte+wpe) # [B, T, n_embd]
+
+        # Transformer blocks
+        for block in self.blocks:
+            x = block(x) # [B, T, n_embd]
+
+        x = self.ln_f(x) # [B, T, n_embd]
+        logits = self.lm_head(x) # [B, T, vocab_size]
+
+        return logits
+
+
+class DataLoaderLite:
+    def __init__(self, B, T):
+        self.B = B
+        self.T = T
+
+        # at init load tokens from disk and store them in memory
+        with open('input.txt', 'r') as f:
+            text = f.read()
+        enc = tiktoken.get_encoding('gpt2')
+        tokens = enc.encode(text)
+        self.tokens = torch.tensor(tokens)
+        print(f'loaded {len(self.tokens)} tokens')
+        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+        # state
+        self.current_position = 0
+        self.current_epoch = 0 # Track the current epoch
+
+    def next_batch(self):
+        B, T = self.B, self.T
+        buf = self.tokens[self.current_position: self.current_position + B * T + 1]
+        x = (buf[:-1]).view(B, T) # inputs
+        y = (buf[1:]).view(B, T) # targets
+        # advance the position in the tensor
+        self.current_position += B*T
+
+        # if loading the next batch would be out of bounds, reset
+        if self.current_position + (B * T + 1) > len(self.tokens):
+            self.current_position = 0
+            self.current_epoch += 1 # Increment the epoch count
+
+        return x, y
+
+
+def generate_sequences(model, enc, num_return_sequences=5, max_length=30, device='cpu'):
+    # x = torch.zeros((num_return_sequences, 1), dtype=torch.long, device=device) # Initialize with a start token or zeros
+    prompt = "I tell "
+    tokens = enc.encode(prompt)
+    x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
+
+    while x.size(1) < max_length:
+        # forward the model to get the logits
+        with torch.no_grad():
+            logits = model(x) # (B, T, vocab_size)
+            # take the logits at the last position
+            logits = logits[:, -1, :] # (B, vocab_size)
+            # get the probabilities
+            probs = F.softmax(logits, dim=-1)
+            # do top-k sampling of 50 (huggingface pipeline default)
+            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
+            # select a token from the top-k probabilities
+            ix = torch.multinomial(topk_probs, 1) # (B, 1)
+            # gather the corresponding indices
+            xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
+            # append to the sequence
+            x = torch.cat((x, xcol), dim=1)
+
+    # print the generated text
+    for i in range(num_return_sequences):
+        tokens = x[i, :max_length].tolist()
+        decoded = enc.decode(tokens)
+        print(">", decoded)
+
+
+if __name__ == "__main__":
+    # SEED
+    torch.manual_seed(1337)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(1337)
+
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    print(f"using device: {device}")
+
+    config = GPTConfig()
+    model = GPT(config).to(device)
+    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
+
+    train_loader = DataLoaderLite(B = 4, T = 128)
+    optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)
+
+    num_epochs = 140 # Set the number of epochs
+    for epoch in range(num_epochs): # Loop over the specified number of epochs
+        stop = False
+        num_batches = len(train_loader.tokens) // (train_loader.B * train_loader.T) # Total number of batches
+
+        # Initialize tqdm instance
+        progress_bar = tqdm(range(num_batches), desc=f'Epoch {epoch + 1}, Loss: 0.0000') # Initial description
+
+        for _ in progress_bar: # Use the tqdm instance
+            x, y = train_loader.next_batch()
+            x, y = x.to(device), y.to(device)
+            optimizer.zero_grad()
+            logits = model(x) # Forward pass
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) # Compute loss
+            loss.backward() # Backward pass
+            optimizer.step() # Update weights
+
+            # Update tqdm description with the current loss
+            progress_bar.set_description(f'Epoch {epoch + 1}, loss/seq: {loss.item():.4f}') # Update the tqdm description
+
+            if loss.item() < 0.099999: # coz i raised the seq_len in dataloader by 4x
+                stop = True
+                break
+
+        torch.cuda.empty_cache()
+        del x, y, logits, loss
+
+        if stop:
+            print('Stopping training...')
+            torch.save(model.state_dict(), f'model.pt')
+            print("Checkpoint saved...")
+            break
+
+        # Generate sequences, save model every 10 epochs
+        if (epoch + 1) % 10 == 0:
+            print(f"Generating sequences after epoch {epoch + 1}:")
+            generate_sequences(model, tiktoken.get_encoding('gpt2'), device=device)
+            torch.save(model.state_dict(), f'model.pt')
+            print("Checkpoint saved...")
+            print("#"*30)
+
+    print("Training completed...")
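One caveat about the attention code above: `MultiHeadAttention.forward` softmaxes the raw scores without a causal mask, so every position can attend to future tokens even though the README describes a decoder-only GPT, and this plausibly contributes to the incoherent sample generations. The standalone sketch below (toy tensors, not a patch to the committed file) shows the usual fix of masking the upper triangle before the softmax; note that model.pt was trained without such a mask, so adding it would require retraining. Separately, `generate_sequences` decodes `num_return_sequences` rows from a batch that only ever holds one prompt, so the default of 5 would raise an IndexError if the every-10-epochs branch were reached.

```python
import torch
import torch.nn.functional as F

# Hypothetical causal masking on a toy [B, n_head, T, T] score tensor; the committed
# forward() goes straight from the q @ k^T / sqrt(d) scores to F.softmax.
B, n_head, T = 1, 2, 4
att = torch.randn(B, n_head, T, T)                       # attention scores
causal = torch.tril(torch.ones(T, T)).view(1, 1, T, T)   # lower-triangular mask
att = att.masked_fill(causal == 0, float('-inf'))        # hide future positions
att = F.softmax(att, dim=-1)
print(att[0, 0])  # upper triangle is exactly zero: each token ignores tokens to its right
```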